import matplotlib.pyplot as plt
import os
import random

##
#Question 1
#Lire l'image, obtenir une liste de listes de listes

original=plt.imread("S:\PSI ET\Espace d'échange\INFORMATIQUE\TP4 compression/sceaux.jpg").tolist()
##
#Question 2
# faire la liste des triplets/pixels
data0=[]
for i in range(len(original)):
    for j in range(len(original[i])):
        data0.append(original[i][j])

##
#Question 3
#arrondi:
data=[]
for pixel in data0:
    for j in range(3):
        pixel[j]=5*round(pixel[j]/5)
    data.append(pixel)
##
#Question 4
def init(data,k): # initialisation naive de K means
    return random.sample(data,k)

##
#Question 5
def dist2(x,y):
    s=0
    for i in range(len(x)):
        s=s+(x[i]-y[i])**2
    return s

def indice_plus_proche(x,L):
    imin=0
    dmin=dist2(x,L[0])
    for i in range(1,len(L)):
        d=dist2(x,L[i])
        if d < dmin:
            dmin=d
            imin=i
    return imin

# Phase d'affectation de l'algo K moyennes
def affectation(data,M):
    A=[]
    for x in data:
        A.append(indice_plus_proche(x,M))
    return A
##
#Question 6

# addition de vecteurs
def addition(x,y):
    z=[]
    for i in range(len(x)):
        z.append(x[i]+y[i])
    return z
# multiplication d'un vecteur par un scalaire
def multiplication(a,x):
    z=[]
    for i in range(len(x)):
        z.append(a*x[i])
    return z
#Arrondi d'un vecteur :
def arrondi(x):
    z=[]
    for j in range(len(x)):
        z.append(5*round(x[j]/5))
    return z


##
#Question 7

# phase de mise à jour de l'algo K moyennes
def mise_a_jour(data,C,k):
    # Initialisation des compteurs
    M=[] # moyennes
    E=[0]*k # effectifs
    for i in range(k):
        M.append( [0]*len(data[0]) )
    # calcul des sommes
    for i in range(len(data)):
        ci=C[i] # la classe de l'exemple i
        E[ci]=E[ci]+1
        M[ci]=addition(M[ci], data[i])
    # division des sommes
    for j in range(k):
        M[j]=multiplication(1/E[j] , M[j])
        M[j]=arrondi(M[j])
    return M

##
#Question 8
# Algo des K moyennes
def kmeans(data,k):
    M=initpp(data,k) # utiliser init (cf ci dessus) ou initpp (voir plus bas)
    fini=False
    i=0
    while not fini:
        print('iteration N°' + str(i) )
        print('  ')
        print(M)
        i=i+1
        A=affectation(data,M)
        print('***********************************')
        M2=mise_a_jour(data,A,k)
        if M2==M:
            fini=True
        else:
            M=M2
    return M,A

##
#Question 9
# reconstruction de l'image «compressée»
def reconstruction(M,A,h,w):
    im2=[]
    for i in range(h):
        im2.append([])
        for j in range(w):
            im2[i].append(M[A[i*w+j]])
    return im2

##
#Question 10
h=len(original)
w=len(original[0])
M,A=kmeans(data,32)
im2=reconstruction(M,A,h,w)
for i in range(h):
    for j in range(w):
        for k in range(3):
           im2[i][j][k]=int(im2[i][j][k])
plt.imshow(im2)
plt.show()

##
#Raffinements

# Sommes des moments d'inertie des clusters (pas testé)
def inertie(data,M,A):
    s=0
    for i in range(len(data)):
        s=s+dist2(data[i],M[A[i]])
    return s

# Initialisation par la methode «kmeans++»
def initpp(data,k):
    L=[random.choice(data)]
    print("init n°",1,"fait")
    for i in range(k-1):
        L.append( next_choice(data,L) )
        print("init n°",i+2,"fait")
    return L

# dichotomie : premier i tel que L[i] dépasse t (avec L triée)
def indice_premier_depassement(L,t):
    a=0
    b=len(L)-1
    while a<b:
        c=(a+b)//2
        if L[c]>t:
            b=c
        else:
            a=c+1
    return a

# choix d'un point de données, avec proba pondérée par le
# carré de la distance au point de L le plus proche
def next_choice(data,L):
    proba=[]
    cumul=0
    for x in data:
        j=indice_plus_proche(x,L)
        y=L[j]
        cumul=cumul+dist2(x,y)
        proba.append(cumul)
    t=random.uniform(0,cumul)
    i=indice_premier_depassement(proba,t)
    return data[i]

