#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 30 14:05:37 2021

@author: ojacques
"""
from scipy.spatial import distance
import math
import random
import numpy as np
import statistics
import matplotlib.pyplot as plt
import time as t

# Imprime informações sobre os dados de iris.data
def info_dataset(amostras, verbose=True):
	if verbose:
		print('Total de amostras: %d' % len(amostras))
	rotulo1, rotulo2, rotulo3 = 0, 0, 0
	for amostra in amostras:
		if amostra[-1] == 'Iris-setosa':
			rotulo1 += 1
		elif amostra[-1] == 'Iris-versicolor':
			rotulo2 += 1
		elif amostra[-1] == 'Iris-virginica':
			rotulo3 += 1
		else:
			print("nenhuma classificação!")
	if verbose:
		print('Total rotulo 1: %d' % rotulo1)
		print('Total rotulo 2: %d' % rotulo2)
		print('Total rotulo 3: %d' % rotulo3)
	return [len(amostras), rotulo1, rotulo2, rotulo3]


def classifica(amostras,C):
    dist = [np.sum((a-C)**2,axis=1)**0.5 for a in amostras] #axis=0, soma das colunas (| | .. |), axis=1 soma das linhas
    return np.argmin(dist,1) # pega o índice do menor valor em cada linha da matriz (0-coluna, 1-linha). Indica o centroide mais próximo de a
    
######################### PROGRAMA #########################

#     # Inicialização da lista de amostras
#     amostras = []
    
#     # Carregar dados de arquivo CSV com amostras
#     with open('iris.data', 'r') as f:
#     	#ler arquivo linha por linha
#     	for linha in f.readlines():
#     		# obter os atributos da amostra
#     		atrib = linha.replace('\n','').split(',') #separa os dados conforme a vírgula
#     		amostras.append([float(atrib[0]), float(atrib[1]), float(atrib[2]), float(atrib[3]), atrib[4]])
    	
#     # Guardar totais de cada rótulo/classe
# tamAmostras, rotulo1, rotulo2, rotulo3 = info_dataset(amostras, verbose=False)
K=2
amostras=np.array([[8,5],
          [22,14],
          [18,8],
          [9,7],
          [20,15],
          [13,13],
          [6,23],
          [29,27],
          [11,29],
          [13,2]], float)
print(amostras)

#(4,12) (17,16)

mini=np.min(amostras,axis=0)
maxi=np.max(amostras,axis=0)

l,d = amostras.shape # d é o número de variáveis ou número de campos, a dimensão

#chuta os K centróides, cada centróide com d colunas
#Cada centroide é uma linha com d colunas

C=np.array([[np.random.randint(mini[col],maxi[col]) for col in range(d)] for k in range(K)],dtype=float) #C[K,d]
print('Centróide: ')
print(C)
continua=True
it = 0
cores = ['b','g','r','c','m','y','k','w']
print('Total de iterações: %d' %it)
plt.subplot(1,2,1)
plt.scatter(amostras[:,0],amostras[:,1])


#Centroides C0, .. CK com d colunas, d é a dimensao
for k in range(K):
    plt.scatter(C[k,0],C[k,1], marker='x',linewidth=3, color = cores[k],)
    plt.axis([0,np.max(maxi)+2,0,np.max(maxi)+2])    
  
while continua:
    it +=1
    classes=classifica(amostras,C)
    qtdeClasse = statistics.Counter(classes) #conta elementos em cada classe
    cl = classes.reshape(l,1) #transforma em vetor algébrico
    oldC = C
    #Encontra o novo centróide
    C =np.array([np.sum(amostras*(cl==c),0)/qtdeClasse[c] for c in range(K)])
    print("\nIteração: %d " %it)
    print("Centróide: ")
    print(C)
    #input('<Pressione qualquer tecla>')
    plt.subplot(1,2,2)
    plt.cla() #limpa o axe(2) (grafico 2)
    for k in range(K):
        
        plt.scatter(C[k,0],C[k,1], marker='x',linewidth=3, color = cores[k])
        plt.axis([0,np.max(maxi)+2,0,np.max(maxi)+2])
        am = amostras[classes==k]
        plt.scatter(am[:,0],am[:,1], color=cores[k])
        plt.draw()
        
    plt.pause(6) #permite pausar para mostrar plotagem    
    if np.array_equal(C,oldC):
        continua = False

cores = ['b','g','r','c','m','y','k','w']
print('Total de iterações: %d' %it)
plt.subplot(1,2,2)
for k in range(K):
    p=plt.scatter(C[k,0],C[k,1], marker='x',linewidth=3, color = cores[k])
    am = amostras[classes==k]
    plt.scatter(am[:,0],am[:,1], color=cores[k])
    plt.axis([0,np.max(maxi)+2,0,np.max(maxi)+2])
    plt.draw()
#plt.show()

input('<Pressione qualquer tecla>')    

        

    


   
