import numpy as np
import numpy.random as rd
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt

## distance
def euc(X : [float], Y : [float]) -> float :
    """Prend en argument deux listes de flottant de même taille et 
    renvoie la distance euclidienne entre ces deux vecteurs.
    """
    S = 0
    for i in range(len(X)) :
        S += (X[i]-Y[i])**2
    return S

## Plus proches voisin
def liste_voisin(p : [float], k : int, Ltrain : list) -> list :
    """Prend en argument un point p, un entier k et une liste de 
    couples (point, étiquette) Ltrain et qui renvoiela liste des k
    vecteur les plus proches de p pour al distance euclidienne.
    """
    D = [euc(p, L[0]) for L in Ltrain]
    I = []
    for i in range(k) :
        jmin = 0
        for j in range(len(D)) :
            if j not in I and D[j]<D[jmin] :
                jmin = j
        I.append(jmin)
    return [Ltrain[i] for i in I]

## étiquette majoritaire dans la liste des k plus proches voisins
def etiquette_maj(L : list) -> int :
    """Prend en agument une liste de couples (point, étiquette) L et
    renvoie l'étiquette majoritaire dans cette liste
    """
    d = {}
    for l in L :
        if l[1] in d :
            d[l[1]] += 1
        else :
            d[l[1]] =1
    m = 0
    for x in d :
        if d[x] > m :
            xm = x
            m = d[x]
    return xm

## l'algorithme knn
def knn(p : [float], k : int , Ltrain : list) -> int :
    """Prend en argument un vecteur p, un entier k et liste Ltrain de
    couples (point, étiquette) et qui renvoie l'étiquette majoritaire 
    des k plus proches voisins de p dans la liste Ltrain.
    """
    L =liste_voisin(p, k, Ltrain)
    return etiquette_maj(L)

## Création de liste des couples (point, étiquette)
digits = load_digits()
X = digits.data
Y = digits.target
Z = digits.images
L = []
for k in range(len(X)) :
    L.append([X[k],Y[k]])

## Mélange de cette liste
rd.shuffle(L)

## Séparation de la liste en deux parties, une pour l'entrainement et l'autre 
## pour le test
Ltrain = L[:int(len(L)*.8)]
Ltest = L[int(len(L)*.8):]

## Test pour différente valeur de k et calcul de la matrice de confusion

for k in range(1, 8) :
    M = np.zeros((10,10), dtype=int)
    p = 0
    for l in Ltest :
        et = knn(l[0], k, Ltrain)
        if et == l[1] :
            p += 1
        M[et,l[1]] += 1
    print(k, p/len(Ltest))
    print(M)