import numpy as np
import numpy.random as rd
import matplotlib.pyplot as plt

## distance 
def euc(X : [float], Y : [float]) -> float :
    """Prend en argument deux listes de flottant de même taille et 
    renvoie la distance euclidienne entre ces deux vecteurs.
    """
    S = 0
    for i in range(len(X)) :
        S += (X[i]-Y[i])**2
    return S

## classe
def classe(L : [[float]], B : [[float]]) -> [[int]] :
    """Prend en argument une liste de listes L et une liste de 
    barycentre B et renvoie une liste de listes C de même taille de
    que B telle que C[i] contient toutes les indices de vecteur L 
    qui sont le plus proche du barycentre B[i].
    """
    C = [[] for i in range(len(B))]
    for i in range(len(L)) :
        dmin = float('inf')
        for j in range(len(B)) :
            d = euc(L[i],B[j])
            if d < dmin :
                dmin = d
                jmin = j
        C[jmin].append(i)
    return C

## Barycentre
def bary(c : [int], L : [[float]]) -> [float] :
    """Prend en argument la liste c des indices d'une classe et l
    a liste L des vecteurs et renvoie le barycentre des vecteurs de 
    L dont les indices sont dans c.
    """
    n = len(L[0])
    b = [0]*n
    for i in c :
        for j in range(n) :
            b[j] += L[i][j]
    for i in range(n) :
        b[i] = b[i]/len(c)
    return b

## convertion
def MatToLigne(M : np.ndarray) -> [[float]] :
    """Prend en argument un tableau de taille (n,m,3) et renvoie 
    la liste de listes de taille (n*m,3) correspondante.
    """
    L = []
    for m in M :
        L += [[int(y) for y in x] for x in m]
    return L

def LigneToMat(L : [[float]], n : int, m : int) -> np.ndarray :
    """Prend en argument une liste de listes de taille (n*m,3) et 
    deux entiers n et m et renvoie le tableau de taille (n,m,3) 
    correspondant.
    """
    M = np.zeros((n, m, 3), dtype = 'uint8')
    for i in range(n) :
        for j in range(m) :
            M[i,j] = L[i*m+j][:3]
    return M

## Kmeans
def Kmeans(L : [[float]], k : int) -> ([[int]], [[float]]) :
    """Prend en argument une liste de listes L et un entier k 
    applique l'algorithme kmeans et renvoie la liste de listes C et 
    une liste de listes B telles que C[i] contient les indices des 
    vecteurs étant le plus proche du barycentre B[i]
    """
    B = [L[rd.randint(0,len(L))] for _ in range(k)]
    C = classe(L, B)
    Bp = [bary(c, L) for c in C]
    while B != Bp :
        B = Bp.copy()
        C = classe(L, B)
        Bp = [bary(c, L) for c in C]
    return C, B

###################################################################
## réduction du nombre de couleurs
def reduc(M : np.ndarray, k : int) -> np.ndarray :
    """Prend en argument un tableau correspondant à une photo et 
    k un entier et revoie le tableau de la même photo en k couleurs
    """
    n, m = len(M), len(M[0])
    L = MatToLigne(M)
    C, B = Kmeans(L, k)
    Lp = [[0]*3 for i in range(len(L))]
    for k in range(len(C)) :
        for i in C[k] :
            Lp[i] = B[k]
    return LigneToMat(Lp, n, m)
## Pour tester sur une image
"""
Mim = plt.imread("perroquet.jpg")
Mimp = reduc(Mim, 8)
plt.imshow(Mimp)
plt.show()
"""

## Inertie
def inertie(B : [[float]], X : [float]) -> float :
    """Prend en argument une liste de listes B et une liste X et 
    calcule l'inertie de ce vecteur par rapport à cette famille.
    """
    S = 0
    for b in B :
        S += euc(b, X)
    return S
###### Pour aller plus loin
## Farthest
def farthest(L : [[float]], k : int) -> [[int]] :
    B = [L[rd.randint(0,len(L))]]
    for i in range(k) :
        jmax, Imax = 0, 0
        for j in range(len(L)) :
            I = inertie(B, L[j])
            if (L[j] not in B) and I > Imax :
                jmax, Imax = j, I
        B.append(L[jmax])
    return B

def Kmeansf(L : [[float]], k : int) -> ([[int]], [[float]]) :
    """Prend en argument une liste de listes L et un entier k 
    applique l'algorithme kmeans et renvoie la liste de listes C 
    et une liste de listes B telles que C[i] contient les indices 
    des vecteurs étant le plus proche du barycentre B[i]
    """
    B = farthest(L, k)
    C = classe(L, B)
    Bp = [bary(C[i], L) for i in range(k)]
    n = 1
    while B != Bp :
        B = Bp.copy()
        C = classe(L, B)
        Bp = [bary(C[i], L) for i in range(k)]
        n += 1
    print(n)
    return C, B

def reducf(M : np.ndarray, k : int) -> np.ndarray :
    """Prend en argument un tableau correspondant à une photo et k 
    un entier et renvoie le tableau de la même photo en k couleurs
    """
    n, m = len(M), len(M[0])
    L = MatToLigne(M)
    C, B = Kmeansf(L, k)
    Lp = [[0]*3 for i in range(len(L))]
    for k in range(len(C)) :
        for i in C[k] :
            Lp[i] = B[k]
    return LigneToMat(Lp, n, m)

"""
Mimp = plt.imread("perroquet.jpg")
Mimp = reducf(Mimp, 4)

plt.imshow(Mimp)
plt.show()
"""