################################################################################
################################################################################
#   K-MEAN
################################################################################
################################################################################

import numpy as np, numpy.random as rd, numpy.linalg as alg
import matplotlib.pyplot as plt

################################################################################
#   GENERATION DE DONNEES
################################################################################

C=3
tm=[(0,0),(2,3),(2,0)]
ts=[(.4,.4),(.6,.7),(.5,.5)]


def gener_pop(N,tm,ts):
    C=len(tm) # C : nombre de classes
    pop=[[] for k in range(C)]
    for c in range(C):
        mx,my=tm[c]
        sx,sy=ts[c]
        for i in range(N):
            pt=np.array([rd.normal(mx,sx),rd.normal(my,sy)])
            pop[c].append(pt)
    return pop

def melange(pop):
    # On perd la classification
    res=[]
    C=len(pop)
    for c in range(C):
        for pt in pop[c]:
            res.append(pt)
    return res
    

pop_init=gener_pop(20,tm,ts)
pop=melange(pop_init)

tx=[pt[0] for pt in pop]
ty=[pt[1] for pt in pop]
plt.plot(tx,ty,'o')
plt.show()


################################################################################
#   K-MEAN
################################################################################

k=3

# INITIALISATION

# À modifier!
decrease=False
V=float('inf')

# Initialisation de la partition
S=[[] for i in range(k)]
#
# À compléter
#

# Initialisation des centroïdes
centr=[0]*k


while decrease:

    # Affichage des k classes de points au cours de la boucle
    for i in range(k):
        tx=[pt[0] for pt in S[i]]
        ty=[pt[1] for pt in S[i]]
        plt.plot(tx,ty,'o')
    plt.show()

    # Calcul des centroïdes
    #
    # À compléter
    #
    
    # Mise à jour de la partition
    S=[[] for i in range(k)]
    #
    # À compléter
    #
    
    # Gestion du booléen decrease
    #
    # À compléter
    # 

# Affiche des k classes en sortie des k-mean
for i in range(k):
    tx=[pt[0] for pt in S[i]]
    ty=[pt[1] for pt in S[i]]
    plt.plot(tx,ty,'o')
plt.show()
