#########################################################################################
#         Isolement de chiffres dans une image au format png (COuleur ou n&b)           #
#       Puis construction d'une base de données pour tester l'algorithme des KNN        #
#########################################################################################

import matplotlib.pyplot as plt
import numpy as np


def convertir_image(img:np.ndarray) -> np.ndarray :  # Pour passer l'image en niveaux de gris
    img_g = np.zeros((len(img),len(img[0])))
    for i in range(len(img_g)):
        for j in range(len(img_g[0])):
            niveau_gris = 0.2126*img[i,j,0]+0.7152*img[i,j,1]+0.0722*img[i,j,2]
            img_g[i,j] = np.float32(niveau_gris)
    return img_g

image_c = plt.imread('chiffres_gris.png')
image_g= convertir_image(image_c)

def affiche(img):                   # Fonction pour afficher l'image
    plt.close()
    plt.figure(1, figsize=(2, 2))
    plt.imshow(img,cmap='gray')
    plt.show()

def dim(img:np.ndarray) -> tuple :   # Renvoie les dimensions d'une image
    nl,nc=len(img),len(img[0])
    return nl,nc


def img_blanche(nl:int,nc:int) -> np.ndarray :  # Pour créer une image blanche
        liste = [[1.]*nc for i in range(nl)]
        return np.array(liste)


def convertir_N_B(img:np.ndarray,seuil:float) -> np.array : # Pour passer une image en
    nl,nc=dim(img)                                          #  niveau de gris en une
    img_nb=img_blanche(nl,nc)                               #  image en noir ou blanc
    for i in range(nl):
        for j in range(nc):
            if img[i,j] < seuil:
                img_nb[i,j]= 0.0
            else:
                img_nb[i,j]= 1.0
    return img_nb


def graphe_vide(nl:int,nc:int)->np.ndarray :    # Crée un graphe vide
        graph = [[None]*nc for i in range(nl)]
        return np.array(graph)

def creer_graph(img:np.ndarray,seuil:float)->np.ndarray :   # Crée un graphe à partir
    nl,nc = dim(img)                                        #  de l'image en noir ou blanc
    imgnb = convertir_N_B(img,seuil)
    graph = graphe_vide(nl,nc)
    for i in range(nl):
        for j in range(nc):
            if imgnb[i,j] == 0 :
                list_v = []
                if i==0:
                    ih=0
                else : ih=i-1
                if i==nl-1 :
                    ib=i
                else : ib=i+1
                if j==0:
                    jg=0
                else : jg=j-1
                if j==nc-1:
                    jd=j
                else : jd=j+1
                for k in range(ih,ib+1):
                    for l in range(jg,jd+1):
                        if imgnb[k,l]==0.0 and (k,l)!=(i,j):
                            list_v.append((k,l))
                graph[i,j] = (None,list_v)
    return graph



def coloriage_pixels(graph:np.ndarray,i:int,j:int,couleur:int)->None:
    couple = graph[i,j]     # Attribue un indice (couleur) à tous les sommets (pixels)
    Lvoisins = couple[1]    #  D'une même composante connexe du graphe
    if couple[0] == None :
        graph[i,j] = (couleur,Lvoisins)
        for pixel in Lvoisins :
            coloriage_pixels(graph,pixel[0],pixel[1],couleur)

def coloriage_graph(graph:np.ndarray)->None:
    couleur = 0                         # Attribue un indice (couleur) différents
    nl,nc=dim(graph)                    # à chaque composante connexe du graphe
    for i in range(nl):                 # pour toutes les composantes connexes
        for j in range(nc):             # On parcourt tous les pixels de la ligne
            if graph[i,j]!=None and graph[i,j][0]==None:
                coloriage_pixels(graph,i,j,couleur)
                couleur += 1
    return couleur


def trace_chiffre(graph:np.ndarray,couleur:int)->np.ndarray:
    nl,nc=dim(graph)                    # On mémorise les dimensions de l'image (ou du graphe)
    image_chiffre=img_blanche(nl,nc)    # On initialise une image blanche
    for i in range(nl):                 # On parcourt toutes les lignes de l'image
        for j in range(nc):             # On parcourt tous les pixels de la ligne
            if graph[i,j]!= None and graph[i,j][0]==couleur:    # Si le pixel du graphe est de
                                                                #  la même couleur que "couleur"
                image_chiffre[i,j]=0.0  # On noircit le pixel de l'image blanche
    return image_chiffre                # On retourne l'image

def animation_zoom(img:np.ndarray,seuil:float)->None:
    nl,nc=dim(img)                      # On mémorise les dimensions de l'image
    graph=creer_graph(img,seuil)        # On crée le graphe de l'image
    couleur_max=coloriage_graph(graph)  # On le colorie
    L_img_zoom=[]
    #L_targets=[8,1,2,9,5,2,3,1,0,3,1,4,0,1,6,7,8,7,5,9,4,5,2,0,7,0,4,9] # 'chiffres_gris.png'
    L_targets=[]
    for i in range(couleur_max):        # On parcours toutes les couleurs (tous les chiffres)
        diapo=trace_chiffre(graph,i)    # On crée l'image du chiffre
        l_min,l_max=nl-1,0                  # On initialise les limites du chiffres lignes
        c_min,c_max=nc-1,0                  #   et colonnes (mini=maxi et maxi=mini)
        for k in range(nl):                 # On parcours toutes les lignes de l'image
            for l in range(nc):             # On parcours tous les pixels de l'image
                if diapo[k,l]==0.0:         # Si les pixels sont noirs
                    if l_min>k : l_min=k    # On ajuste la limite supérieure
                    if l_max<k : l_max=k    # On ajuste la limite inférieure
                    if c_min>l : c_min=l    # On ajuste la limite gauche
                    if c_max<l : c_max=l    # On ajuste la limite droite
        diapo_zoom=img_blanche(l_max+1-l_min,c_max+1-c_min) # On crée une image blanche aux
                                                            #  dimensions du chiffre (petite)
        for k in range(l_max+1-l_min):                  # On parcours tous les pixels de la
            for l in range(c_max+1-c_min):              #   grande image dans les limites
                diapo_zoom[k,l]=diapo[k+l_min,l+c_min]  # On noirci les pixels correspondant
                                                        #   dans la petite image
        affiche(diapo_zoom)                  # On affiche cette image
        plt.pause(0.5)                       #   pendant 0,5 secondes
        chiffre=int(input('Entrez le chiffre affiché : '))  # On saisie manuellement le label
        diapo_zoom=redim(diapo_zoom)        # On redimensionne l'image en 8x8 niveau de gris
        affiche(diapo_zoom)                 # On affiche cette image
        plt.pause(0.5)                      #   pendant 0,5 secondes
        diapo_applatie=[]                   # On crée une liste vide pour une image applatie
        for ligne in diapo_zoom :       # Balayage verticla de l'image 8x8
            for niveau_pixel in ligne:  # Balayage horizontal de l'image 8x8
                diapo_applatie.append(int(16*(1-niveau_pixel)))     # ajout du pixel en négatif
        L_img_zoom.append([chiffre,np.array(diapo_applatie)])       # Stockage label et image
        # L_img_zoom.append([L_targets[i],np.array(diapo_applatie)])# Dans la liste des données
    return L_img_zoom

def redim(img:np.array)->np.array:  # Redimensionner une image noir ou blanc
    h,l=np.shape(img)               #  de hxl en image en niveau de gris 8x8
    if l>h :                                    ##
        V=np.ones((1,l))                        #
        if (l-h)%2==1:                          # Elargissement de l'image en
            img=np.concatenate((V,img),axis=0)  #  (en blanc) en haut et en bas
        for i in range((l-h)//2):               #  si largeur > hauteur
            img=np.concatenate((V,img),axis=0)  #
            img=np.concatenate((img,V),axis=0)  ##
    if h>l :                                    ##
        V=np.ones((h,1))                        #
        if (h-l)%2==1:                          # Elargissement de l'image en
            img=np.concatenate((V,img),axis=1)  #  (en blanc) à gauche et à droite
        for i in range((h-l)//2):               #  si largeur < hauteur
            img=np.concatenate((V,img),axis=1)  #
            img=np.concatenate((img,V),axis=1)  ##
    result_int=np.zeros((len(img),8))   # Creation d'un image noire hx8
    for i in range(len(img)):           # Balayage verticale de l'image hxl
        for k in range(8):              # Balayge horizontal de l'image 8x8
            niveau=0                # Niveau de gris du 1/8 de ligne = 0
            nb_pix=len(img[i])      # largeur de l'image hxl
            for j in range(nb_pix):                             #
                if j/nb_pix>=k/8 and (j+1)/nb_pix<=(k+1)/8 :    # Portions de
                    niveau+=img[i,j]                            # pixels inclus
                elif (k+1)/8>j/nb_pix and (k+1)/8<(j+1)/nb_pix :# dans le 1/8 de
                    niveau+=img[i,j]*((k+1)/8-j/nb_pix)*nb_pix  # ligne x1 (blanc)
                elif k/8>j/nb_pix and k/8<(j+1)/nb_pix :        # ou x0 (noir)
                    niveau+=img[i,j]*((j+1)/nb_pix-k/8)*nb_pix  #
            niveau=niveau*8/nb_pix      # Niveau de gris du 1/8 de ligne
            result_int[i,k]=niveau      # Niveau de gris du pixel de l'image hx8
    result=np.zeros((8,8))          # Creation d'un image noire 8x8
    for j in range(8):              # Balayage horizontal de l'image 8x8
        for k in range(8):          # Balayage vertical de l'image 8x8
            niveau=0                # Niveau de gris du 1/8 de colonne = 0
            nb_pix=len(img)         # hauteur de l'image hxl
            for i in range(nb_pix): # balaye verticla de l'image hx8
                if i/nb_pix>=k/8 and (i+1)/nb_pix<=(k+1)/8 :            #
                    niveau+=result_int[i,j]                             # Portions de
                elif (k+1)/8>i/nb_pix and (k+1)/8<(i+1)/nb_pix :        # pixels inclus
                    niveau+=result_int[i,j]*((k+1)/8-i/nb_pix)*nb_pix   # dans le 1/8 de
                elif k/8>i/nb_pix and k/8<(i+1)/nb_pix :                # colonne niveau
                    niveau+=result_int[i,j]*((i+1)/nb_pix-k/8)*nb_pix   # de gris
            niveau=niveau*8/nb_pix      # Niveau de gris du 1/8 de colonne
            result[k,j]=niveau          # Niveau de gris du pixel de l'image hx8
    return result

donnees_image=animation_zoom(image_g,0.8)

