import random as rd
import numpy as np
import matplotlib.pyplot as plt


#----------outils de test et visualisation, dimension 2----------------

def donneesSimulees(nbPts=500,nbCluster=5,sigma=0.01):
    """crée nbPts observations 
    réparties dans nbCluster clusters 
    de "taille" sigma - travail dans le plan"""
    #détermination de centres aléatoires pour les clusters
    centres=[[rd.random(),rd.random()] for i in range(nbCluster)]
    #création des observations, 
    observations=[]
    groupe=[]
    for n in range(nbPts):
        #chacune est rattachée à un centre aléatoire
        numCentre=rd.randint(0,nbCluster-1)
        groupe.append(numCentre)
        posx,posy=centres[numCentre]
        #et située sur une gaussienne centrée en ce point
        x=rd.gauss(posx,sigma)
        y=rd.gauss(posy,sigma)
        observations.append([x,y])
    return centres,observations,groupe

def afficheCentres(points,symbole):
    X=[c[0] for c in points]
    Y=[c[1] for c in points]
    plt.plot(X,Y,symbole+"k")

def afficheClasses(centres,observations,groupe):
    """affiche jusqu'à 0 nuages de points sur la fenêtre courante"""
    n=len(observations)
    plt.axis("equal")
    for i in range(len(centres)):
        X=[]
        Y=[]
        for k in range(n):
            if groupe[k]==i:
                X.append(observations[k][0])
                Y.append(observations[k][1])
        plt.plot(X,Y,".",color="C"+str(i))
    afficheCentres(centres,"+")

#création de données simulées
centres,observations,groupe=donneesSimulees()

#------------programmation des k moyennes (dim quelconque)-----


def distCarre(X1,X2):
    """carré de la distance euclidienne entre deux vecteurs de R^d"""
    pass

def initialisationForgy(observations,k):
    #voir la documentation de rd.sample
    pass
    
def association(observations,centres):
    pass
    
def newCentres(observations,references,k):
    pass

def kMoyAfficheEtapes(observations,k):
    plt.subplot(231)
    plt.title("Vraies données")
    afficheClasses(observations,groupe)
    plt.subplot(232)
    references=[0 for i in range(len(observations))]
    #initialiser la variable centres..............
    #centres=....
    plt.title("Départ")
    afficheClasses(observations,references)
    afficheCentres(centres,"+") 
    #étapes    
    for i in range(5):
        #Mise à jour de references et centres
        #
        #
        plt.subplot(233+i)
        plt.title("Etape "+str(i+1))
        afficheClasses(observations,references)
        afficheCentres(centres,"+")
    plt.show()
    return references
    
def listesIdentiques(l1,l2):
    pass

def kMoy(observations,k,groupe,initialisation=initialisationForgy,affichage=True):
    #initialiser centres avec la fonction initialisation fournie
    if affichage:
        plt.subplot(131)
        plt.title("Vraies données")
        afficheClasses(observations,groupe)
        plt.subplot(132)
        references=[0 for i in range(len(observations))]
        plt.title("Départ")
        afficheClasses(observations,references)
        afficheCentres(centres,"+")   
    #calculer la première association avant de lancer la boucle
    #references=.....
    # puis la boucle......      
    if affichage:
        plt.subplot(133)
        plt.title("Résultat après "+str(i)+ "étapes")
        afficheClasses(centres,observations,references)
        afficheCentres(centres,"+")     
    return references

def fonctionRepartition(poids):
    pass
    
def tirageHasard(poids):
    pass
    
def initialisationPlusPlus(observations,k):
    pass

def kMoyPlusPlus(observations,k,groupe,affichage=True):
    pass

#------ fonctions d'évaluation ---------

def distorsion(observations,groupes,k):
    pass

def distMoyACluster(observations,i0,groupes,indG):
    pass
    
def coeffSilhouettePoint(observations,i0,groupes,k):
    pass
    
def coeffSilhouetteMoy(observations,groupes,k):
    pass

def graphiquePerformance(observations,fonctionEval):
    valeurs=[]
    for k in range(2,10):
        reponse3=kMoyPlusPlus(observations,k,False)
        val=fonctionEval(observations,reponse3,k)
        print("Pour k=",k," coefficient ",val)
        valeurs.append(val)
    plt.plot([k for k in range(2,10)],valeurs)
    plt.show()

#------mise en oeuvre, visualisation (d=2)---------

# affichage des étapes
#reponse1=kMoyAfficheEtapes(observations,5,groupesDepart)

#algorithme des k-Moyennes    
#reponse2=kMoy(observations,5,groupesDepart)

#algorithme des k-Moyennes  ++  
#reponse3=kMoyPlusPlus(observations,5,groupesDepart)

#affichage final
plt.show()

#------performance, choix de k---------

graphiquePerformance(observations,distorsion)

#graphiquePerformance(observations,coeffSilhouetteMoy)




