from math import *
from random import *
import  matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy

from sklearn.datasets import load_digits

# Chargement du jeu de données des chiffres
chiffres=load_digits()

X=chiffres.data
#print(X)
Y=chiffres.target # recuperation des classes
#print(Y)

# Affichage des chiffres
def afficherChiffres( X, Y, limit_max=10 ):
    """   afficherChiffres( X: list, Y: list, limit_max=10 : int ):
          entrees : X, liste du jeu de données
                  : Y, liste des etiquettes associees
                  : limit_max, entier, nombre de caractères affichés de chaque type
    """
    classes, nombres = np.unique( Y, return_counts = True )
    nombre_max = min( np.max(nombres), limit_max )
    img = np.zeros( ( 100, nombre_max*10 ) )
    for i in range( 10 ) :
        index_classe = np.where(Y == i)[0][:limit_max]
        for j, echantillon in enumerate( index_classe ):
            img[i*10+1:i*10+9,j*10+1:j*10+9] = X[echantillon].reshape((8, 8))
    plt.imshow( img, cmap='binary' )
    plt.xticks([])
    plt.yticks( 5 + 10*np.arange(10), np.arange(10) )
    plt.show()
# afficherChiffres( X, Y, 14)

def obtenir_echantillons( X , Y, nbEchantillons = 10 ):
    '''Donne une sélection aléatoire d indices d'échantillons pour chaque classe.'''
    index = []
    for classe in np.unique(Y):
        index_classe = np.where(Y == classe)[0]
        replace = len(index_classe) > nbEchantillons
        index += list( np.random.choice(index_classe, size=nbEchantillons, replace=replace) )
    return index

index = obtenir_echantillons( X, Y )
afficherChiffres( X[index], Y[index] )


""" _________________________________________________
      k plus proches voisins
_____________________________________________________"""
def distance( im1 , im2 ):
	d=0
	for i in range( len( im1 ) ):
		d = d + ( im2[i] - im1[i] )**2
	d = np.sqrt( d )
	return d

def PlusProchesVoisins( X , Y , im , k ):
    # calcul des distances entre chaque image du jeu de données et l'image im
    lesDistances=[]
    nbClasses = 10
    for i in range( len(X) ):
        d = distance( X[i] , im )
        lesDistances.append( [d,Y[i]] )

    # tri par ordre croissant de distance
    lesDistances.sort()

    # on cherche la classe majoritaire parmi les voisins les plus présents :
    # on compte le nombre de voisins plus proches dans chaque catégorie
    voisins=[0]*nbClasses
    for i in range( k ): #on regarde les k plus proches voisins de im
        voisins[ lesDistances[i][1] ]+=1
    # on cherche la ou les catégories plus représentées
    rep = [] # liste des chiffres les plus représentés
    maxi = 0 # compte le maximum
    for i in range( nbClasses ):
        if voisins[ i ] > maxi:
            maxi = voisins[ i ]
            rep = [ i ]
        elif voisins[ i ]== maxi :
            maxi = voisins[ i ]
            rep.append( i )
    return rep


# Séparation des données d'apprentissage et de tests
AX,AY=[],[]
TX,TY=[],[]
dA={} # dictionnaire des clés déjà choisies comme donnée d'apprentissage
dT={}
i=0 # compteur
while i<= len( X )-1:
	j=randint( 0, len( X )-1 )
	if j not in dA:
		AX.append( X[j] )
		AY.append( Y[j] )
		dA[j]=1
		i=i+1
	elif j not in dT:
		TX.append( X[j] )
		TY.append( Y[j] )
		dT[ j ]=1
		i=i+1
# apreès separation puis echantillonnage
index = obtenir_echantillons(AX, AY)
#afficherChiffres(AX, AY)
afficherChiffres( X[index], Y[index] )

# Précision
nb = 0 # compteur des bonnes prédictions
for i in range(len(TX)):
	prediction=PlusProchesVoisins( AX, AY, TX[i], 3)
	if len( prediction ) == 1: # pas d'indécision
		if prediction[0] == TY[i]:
			nb = nb + 1
nb = nb / len( TX )

"""___________________________________
      k-moyennes
______________________________________"""
def barycentre(s):
    bary=[0]*64
    for k in range(len(s)):
        for i in range(64):
            bary[i]=bary[i]+s[k][i]#/len(s)
        bary[i] = bary[i] / len( s )
    return bary

def dist( X1, X2 ):
    d=0
    for i in range(len(X1)):
        d=d+(X1[i]-X2[i])**2
    d= np.sqrt(d)
    return d

def kmoyennes( X, k ):
    # choix au hasard de k images
    changement = True
    lesCentres=np.array([ X[randint(0,len(X)-1)] for i in range(k)] )

    while changement: # on boucle tant que les centres bougent
        lesClasses=[ [] for i in range(k) ] # k classes
        for i in range(len(X)):
            # recherche du centre le plus proche de Xi
            dmin = np.inf
            jmin = -1
            for j in range(k):
                d = dist(X[i],lesCentres[j])
                if d < dmin:
                    dmin = d
                    jmin = j

            # on ajoute Xi à la classe lesClasses[jmin]
            lesClasses[ jmin ].append( X[i] )

      # calcul des nouveaux barycentres pour chaque classe
        nouveauxCentres=np.array([barycentre(lesClasses[i]) for i in range(k)])
        if nouveauxCentres.all() == lesCentres.all():
            changement = False
        else :
            lesCentres=deepcopy( nouveauxCentres )
    return lesClasses , lesCentres

lesClasses , lesCentres = kmoyennes(X,10)
index=[i for i in range(10)]

afficherChiffres( lesCentres, Y[index] )