# -*- coding: utf-8 -*-
"""
Éditeur de Spyder

Ceci est un script temporaire.
"""


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 23 13:22:38 2021

@author: adomps
"""

import numpy as np
import matplotlib.pyplot as plt
import random as rd


###########################################################################
##### Fonctions fournies par le professeur
###########################################################################

def represente_donnees_2d(X, Y, ind1, ind2, dico_couleurs_classes, noms_attributs) :
    ''' X est un tableau de N lignes et au moins 2 colonnes contenant N observations
    en dimension p >= 2. Chaque colonne correspond à un attribut et les noms de
    ces colonnes sont fournis dans la liste noms_attributs.
    
    Y est le tableau des étiquettes correspondantes : l'observation X[i] a pour
    étiquette Y[i] (c'est une classe).
    
    Les paramètres ind1 et ind2 sont deux indices à fournir entre O et p-1. La
    fonction représente les N observations par des points dans le plan en 
    utilisant pour chaque     observation i, les coordonées d'indice X[i][ind1] 
    et X[i][ind2]. On obtient donc la projection de observations, au départ 
    fournies en dimension p, dans le plan (ind1, ind2). 
    
    On attribue à tous les points de la même classe une même couleur. La 
    correspondances entre les classes et les couleurs est exprimée dans le dictionnaire 
    dico_couleurs_classes.'''
       
    plt.figure()
    plt.xlabel(noms_attributs[ind1])
    plt.ylabel(noms_attributs[ind2])            
    liste_markers = ['o', 'v', '+', '^', 'v' ] # pour distinguer les différents groupes
    i_marker = 0
    
    for classe in dico_couleurs_classes :
        masque =  (Y == classe) # on repère les éléments de Y égaux à une classe
        X1selection = X[masque, ind1] # on sélectionne ces éléments
        X2selection = X[masque, ind2]
        couleur = dico_couleurs_classes[classe]        
        mark = liste_markers[i_marker % 5]
        plt.scatter(X1selection, X2selection, color = couleur, label = classe, 
                    marker = mark)
        i_marker +=1
    plt.legend()  
    
def applique_sur_jeu_complet(parametre, ma_fonction, X_test, X_app, Y_app) :
    ''' Applique ma_fonction sur toutes les observations de X_test et renvoie
    le tableau numpy des prédictions. 
    Pour les k_PPV, parametre = k.
    Pour les epsilon_PPV, parametre = epsilon. '''
    Y_pred = []
    for i in range(len(X_test)) :
        y_pred = ma_fonction(parametre, X_test[i], X_app, Y_app)
        Y_pred.append(y_pred)    
    return np.array(Y_pred)    
    
###########################################################################
####  Partie I : données sur les iris
###########################################################################

noms_attributs_iris = ["long_sepale", "larg_sepale", "long_petale",
                       "largr_petale", "espece"]

classes_iris = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']

# Dictionnaire associant une couleur à chaque classe (pour les graphiques)
dico_especes_iris_couleurs = { 'Iris-setosa' : 'blue', 'Iris-versicolor' : 'red',
                         'Iris-virginica' : 'green'}

X_iris = np.array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])

Y_iris = np.array(['Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
       'Iris-setosa', 'Iris-setosa', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
       'Iris-virginica', 'Iris-virginica'], dtype='<U15')


# Représentation graphique. On choisit deux attributs sur 4 en précisant leurs indices
represente_donnees_2d(X_iris, Y_iris, 0, 1, dico_especes_iris_couleurs, 
                      noms_attributs_iris)   

represente_donnees_2d(X_iris, Y_iris, 0, 3, dico_especes_iris_couleurs, 
                      noms_attributs_iris)


def separe_donnees(X, Y, fraction_test) :
    ''' Sépare les données représentées par les observations X et les étiquettes
    Y en deux nouveaux jeux, l'un pour l'apprentissage et l'autre pour la
    validation. On renvoie X_app,  X_test, Y_app, Y_test. Le nombre d'éléments
    tests est défini par sa fraction par rapport au total. Les éléments de chaque
    groupe sont tirés aléatoirement. C'est l'équivalent de train_test_split du
    module sklearn.'''
    N, dim = X.shape
    typeY = Y.dtype
    ind = [i for i in range(N)]
    rd.shuffle(ind) # mélange aléatoire de ces indices
    N_test = int(N * fraction_test)
    N_app = N - N_test
    X_test = np.zeros((N_test, dim))
    X_app = np.zeros((N_app, dim))
    Y_app = np.zeros(N_app, dtype = typeY)
    Y_test = np.zeros(N_test, dtype = typeY)   
    for i, j in enumerate(ind[: N_app]) :
        X_app[i] = X[j]
        Y_app[i] = Y[j]
    for i, j in enumerate(ind[N_app:]) :
        #print(i, j)
        X_test[i] = X[j]
        Y_test[i] = Y[j]
    return X_app, X_test, Y_app, Y_test


X_app, X_test, Y_app, Y_test = separe_donnees(X_iris, Y_iris, 0.2)


# Si on veut le refaire soi-même
def moyenne(T) :
    S = 0
    for a in T :
        S += a
    return S / len(T)

def ecart_type(T) :
    m = moyenne(T)
    D = T - m
    S = 0
    for x in D :
        S += x**2
    return np.sqrt(S/len(D))


def parametres_normalisation(X) :
    ''' Renvoie la moyenne et l'écart type des observations X.
    Ces nombres sont calculés séparément pour chaque attribut, 
    c'est à dire pour chaque colonne de X. On renvoie deux vecteurs qui
    seront utiles ensuite pour appliquer la normalisation. 
    C'est l'équivalent de la méthode fit de StandardScaler dans 
    le module sklearn. '''
    moyennes = np.average(X, axis = 0) # ou moyenne(X_iris)
    ecart_types = np.std(X, axis = 0) # ou ecart_type(X_iris)
    return moyennes, ecart_types

moyennes_iris, ecarts_types_iris = parametres_normalisation(X_iris)


def normalise(X, moyennes, ecarts_types) :
    ''' Normalise les observations X. Pour chque  colonne, on soustrait
    une moyenne contenue dans le vecteur moyennes et un écart-type contenu
    dans le vecteur ecarts_types. '''
    X_norm = (X - moyennes) / ecarts_types
    return X_norm
    

normalise(X_iris, moyennes_iris, ecarts_types_iris)



def d_euclid(V, W) :
    ''' Distance euclidienne entre deux vecteurs représentés par des 
    tableaux numpy V et W. '''
    return np.sqrt(np.sum((V-W)**2)) 

d_euclid(X_iris[0], X_iris[100])


def comptage(liste) :
    ''' Renvoie un dictionnaire dont les clés sont les éléments existant dans
    la liste, sans répétition, et dont les valeurs sont les nombres d'occurence
    de chacun de ces éléments. 
    [Classe1 : nb1, Classe2 : nb2, etc]
    C'est une réécriture de la fonction Counter du module collections.'''
    # On initialise les compteurs. Pour un dictionnaire, l'initialisation
    # par description ne répète pas les éléments qui seraient éventuellement
    # répétés dans $x$. C'est ici très pratique !
    Compteurs = {x : 0 for x in liste} 
    # on pourrait aussi créer un dico vide et faire à chaque fois le test
    # if x in Compteurs pour savoir s'il faut créer une nouvelle clé.
    for x in liste :
        Compteurs[x] += 1
    return Compteurs

comptage(['a', 'a', 'c', 'b', 'c'])

def vote_majoritaire(liste_etiquettes) :  
    ''' etiquette est une liste d'étiquettes (ou classes) des k PPV. Par exemple,
    [A, A, B, C, A, C, A, B, B ]. On renvoie
    l'étiquette la plus représentée. En cas d'égalité, on choisit arbitrairement
    la première étiquette rencontrée dans la liste liste_etiquettes. '''
    assert len(liste_etiquettes) > 0, " Pas de votant ! "
    depouillement = comptage(liste_etiquettes) # on dépouille les votes
    meilleur_score = 0
    gagnant = None
    for e in depouillement :
        if depouillement[e] > meilleur_score :
            meilleur_score = depouillement[e]
            gagnant = e
    return gagnant


vote_majoritaire(['a', 'a', 'c', 'b', 'c'])

# Idée en cas d'égalité : livre Joel Grus
def vote_majoritaire_departage(liste_etiquettes) :  
    ''' etiquette est la liste des étiquettes (ou classes) des k PPV. On renvoie
    l'étiquette la plus représentée. Les etiqueets sont triées par distances croissantes.
    En cas d'égalité, on élimine la dernière étiquette  (la plus éloignée) et on recommence.
    On poursuit ainsi récursivement jusqu'à ce qu'il ne reste qu'un seul vainqueur.'''
    
    assert len(liste_etiquettes) > 0, " Pas de votant ! "    
    depouillement = comptage(liste_etiquettes) # on dépouille les votes
    meilleur_score = 0
    gagnants = []
    for e in depouillement :
        if depouillement[e] > meilleur_score :
            gagnants = [e] # e est pour l'instant l'unique gagnant
            meilleur_score = depouillement[e]
        elif depouillement[e] == meilleur_score :
            gagnants.append(e) # on a alors un nouveau gagnant dans la liste
    if len(gagnants) == 1 :
        return gagnants[0]
    else : # on recommence en supprimant le plus lointain
        # On pourrait faire plus astucieux .....
        return vote_majoritaire_departage(liste_etiquettes[:-1])


vote_majoritaire_departage(['a', 'a', 'a', 'c', 'b', 'c', 'c', 'b'])


# J'ai dans un autre fichier une autre version avec l'option << feignasse >> 
# où on trie  toute la liste avec la méthode sort.
def liste_k_PPV(k, observation, X_app) :
    ''' Renvoie la liste des indices des k PPV de observation parmi
    ceux de X_app. << distances >> est la fonction qui calcule la distance '''

    distances = []
 
    for x in X_app :
        d = d_euclid(observation, x)
        distances.append(d)

    indices_kPPV = []    
    for u in range(k) :
        i_min, d_min = 0, distances[0]        
        for i, d in enumerate(distances) :
            if d < d_min :
                i_min, d_min = i, d
        distances[i_min] = float('inf') # pour l'éliminer des passages suivants
        indices_kPPV.append(i_min)
            
    return indices_kPPV


liste_k_PPV(8, [5, 3., 1.5, 1.], X_iris)


def kPPV_classification(k, observation, X_app, Y_app) :
    ''' X_app sont les données d'apprentissage et Y_app les classes
    correspondantes. La liste << classes >> contient toutes les classes
    possibles. Pour une observation donnée, on cherche dans X_app
    les k plus proches selon la distance choisie et on renvoie la
    classe la plus représentée. '''
       
    # On cherche les indices des k plus petites distances.   
    indices_k_PPV = liste_k_PPV(k, observation, X_app) 
    #print(indices_k_PPV)
    
    # Étiquettes des k plus proches voisins
    Y_kPPV = Y_app[indices_k_PPV] # on manipule en bloc un groupe d'indices dans une liste
    #print(Y_kPPV)
    
    # vote majoritaire
    return vote_majoritaire_departage(Y_kPPV)  



#######################################################################
###  Application ########################################
X_app, X_test, Y_app, Y_test = separe_donnees(X_iris, Y_iris, 0.2)

# On normalise les données
moy, ecarts = parametres_normalisation(X_app) # Seulement avec le jeu d'apprentissage
X_app = normalise(X_app, moy, ecarts)
X_test = normalise(X_test, moy, ecarts)

i_essai = 6
k_voisins = 8
obs_essai = X_test[i_essai]

reponse1 = kPPV_classification(k_voisins, obs_essai, X_app, Y_app)
print('On obtient ', reponse1, ' et on doit obtenir', Y_test[i_essai])


##############################################################################
#### Qualité de l'apprentissage
##############################################################################

Y_pred = applique_sur_jeu_complet(k_voisins, kPPV_classification, X_test, X_app, Y_app)
Y_pred == Y_test





def matrice_confusion(Y_connus, Y_pred, classes) :
    ''' Renvoie la matrice de confutions. M[i, j] est le nombre d'observations
    connues pour être dans la classe d'indice i et que l'on a prédites dans
    la classe j. La variable << classes >> est la liste des classes du problème
    étudié. '''    
    N_obs = len(Y_connus)
    assert N_obs == len(Y_pred), "Les nombres d'étiquettes ne sont pas égaux ! "

    Nc = len(classes) # nombre de classes possibles    
    Mat_conf = np.zeros([Nc, Nc], dtype = int)

    # Dictionnaire pour associer un indice à chaque classe
    dico_classes_indices = {classe : i for (i, classe) in enumerate(classes)}    
        
    for indice in range(N_obs) :
        classe_connue = Y_connus[indice]
        classe_predite = Y_pred[indice]
        i = dico_classes_indices[classe_connue]
        j = dico_classes_indices[classe_predite]
        Mat_conf[i, j] += 1
        
    return Mat_conf


### Essais sur tout le jeu test
Y_pred = applique_sur_jeu_complet(k_voisins, kPPV_classification, X_test, X_app, Y_app)
print(matrice_confusion(Y_pred, Y_test, classes_iris))    




########################################################################
####  Partie 5. Régression et  application aux stations de ski
########################################################################

def kPPV_regression(k, observation, X_app, Y_app) :
    ''' Les étiquettes Y_app sont des nombres. On renvoie la moyenne des
    étiquettes des k plus proches voisins.'''
    # On cherche les indices des k plus petites distances.   
    indices_k_PPV = liste_k_PPV(k, observation, X_app) #
    #print(indices_k_PPV)
    
    # Étiquettes des k plus proches voisins
    Y_kPPV = Y_app[indices_k_PPV] # on manipule en bloc un groupe d'indices dans une liste
    #print(Y_kPPV)
    
    return np.average(Y_kPPV)  



X_stations = np.array([[9.500e+02, 1.730e+03, 5.600e+02, 1.500e+01, 1.000e+02],
       [1.000e+03, 1.900e+03, 9.000e+02, 4.000e+01, 1.150e+02],
       [1.500e+03, 2.100e+03, 7.000e+02, 2.800e+01, 1.050e+02],
       [1.860e+03, 3.300e+03, 1.440e+03, 1.110e+02, 1.400e+02],
       [1.367e+03, 2.184e+03, 8.170e+02, 3.800e+01, 9.300e+01],
       [1.080e+03, 2.320e+03, 1.220e+03, 3.000e+01, 1.070e+02],
       [1.600e+03, 3.300e+03, 1.730e+03, 2.400e+01, 1.400e+02],
       [1.500e+03, 2.750e+03, 1.250e+03, 2.100e+01, 1.190e+02],
       [1.050e+03, 1.710e+03, 6.600e+02, 2.100e+01, 1.030e+02],
       [1.100e+03, 1.610e+03, 5.100e+02, 2.600e+01, 1.000e+02],
       [1.000e+03, 2.000e+03, 1.000e+03, 2.100e+01, 1.050e+02],
       [1.800e+03, 3.000e+03, 1.200e+03, 2.600e+01, 1.260e+02],
       [6.000e+02, 2.950e+03, 2.350e+03, 6.600e+01, 1.200e+02],
       [1.042e+03, 3.275e+03, 2.233e+03, 1.190e+02, 1.640e+02],
       [1.250e+03, 3.250e+03, 2.000e+03, 1.360e+02, 1.260e+02],
       [1.400e+03, 2.250e+03, 8.500e+02, 4.300e+01, 1.660e+02],
       [1.100e+03, 2.200e+03, 1.100e+03, 4.500e+01, 1.350e+02],
       [1.350e+03, 1.900e+03, 5.500e+02, 8.000e+00, 9.000e+01],
       [1.060e+03, 1.350e+03, 2.900e+02, 6.000e+00, 7.900e+01],
       [1.200e+03, 1.550e+03, 3.500e+02, 5.000e+00, 6.500e+01],
       [1.183e+03, 1.930e+03, 7.470e+02, 6.500e+01, 9.300e+01],
       [1.000e+03, 1.600e+03, 6.000e+02, 1.000e+01, 8.000e+01],
       [1.111e+03, 2.050e+03, 9.390e+02, 5.100e+01, 1.050e+02],
       [1.100e+03, 2.738e+03, 1.388e+03, 1.090e+02, 1.400e+02],
       [1.230e+03, 2.069e+03, 4.200e+02, 1.570e+02, 1.120e+02],
       [1.400e+03, 2.550e+03, 1.150e+03, 6.100e+01, 1.100e+02],
       [1.600e+03, 2.500e+03, 9.000e+02, 6.400e+01, 1.300e+02],
       [1.000e+03, 1.650e+03, 6.500e+02, 1.600e+02, 1.000e+02],
       [1.245e+03, 1.751e+03, 5.060e+02, 2.700e+01, 8.500e+01],
       [1.000e+03, 1.800e+03, 8.000e+02, 2.400e+01, 1.000e+02],
       [1.040e+03, 2.500e+03, 1.500e+03, 8.400e+01, 1.280e+02],
       [1.200e+03, 1.930e+03, 7.300e+02, 6.600e+01, 1.000e+02],
       [1.350e+03, 2.750e+03, 1.400e+03, 2.700e+01, 1.110e+02],
       [1.250e+03, 3.250e+03, 2.000e+03, 1.280e+02, 1.380e+02],
       [1.850e+03, 2.800e+03, 9.500e+02, 8.100e+01, 1.280e+02],
       [1.150e+03, 1.850e+03, 7.000e+02, 1.200e+01, 1.000e+02],
       [1.350e+03, 2.732e+03, 1.382e+03, 1.090e+02, 1.200e+02],
       [1.100e+03, 2.620e+03, 1.520e+03, 4.200e+01, 1.180e+02],
       [1.420e+03, 1.807e+03, 3.870e+02, 3.000e+01, 1.000e+02],
       [1.100e+03, 2.620e+03, 1.520e+03, 4.100e+01, 1.190e+02],
       [1.200e+03, 1.400e+03, 2.000e+02, 4.000e+00, 8.700e+01],
       [1.000e+03, 2.100e+03, 1.100e+03, 4.600e+01, 1.350e+02],
       [9.800e+02, 1.600e+03, 6.200e+02, 1.000e+01, 1.100e+02],
       [9.600e+02, 1.350e+03, 3.900e+02, 6.000e+00, 9.900e+01],
       [1.450e+03, 1.704e+03, 2.540e+02, 2.000e+01, 1.100e+02],
       [1.350e+03, 3.600e+03, 1.950e+03, 9.000e+01, 1.480e+02],
       [1.350e+03, 2.400e+03, 1.050e+03, 5.100e+01, 1.050e+02],
       [1.200e+03, 3.226e+03, 2.026e+03, 1.230e+02, 1.300e+02],
       [1.280e+03, 1.900e+03, 6.200e+02, 6.000e+00, 1.180e+02],
       [9.000e+02, 1.500e+03, 6.000e+02, 2.000e+01, 1.000e+02],
       [1.140e+03, 2.100e+03, 1.360e+03, 2.800e+01, 1.220e+02],
       [1.160e+03, 2.450e+03, 1.290e+03, 4.800e+01, 1.300e+02],
       [1.064e+03, 1.463e+03, 3.670e+02, 8.000e+00, 9.800e+01],
       [1.172e+03, 2.002e+03, 8.300e+02, 7.200e+01, 1.250e+02],
       [9.900e+02, 1.900e+03, 9.100e+02, 3.000e+01, 1.500e+02],
       [1.600e+03, 2.520e+03, 9.200e+02, 2.800e+01, 1.190e+02],
       [1.450e+03, 2.850e+03, 1.400e+03, 8.500e+01, 1.300e+02],
       [1.140e+03, 1.510e+03, 3.700e+02, 7.000e+00, 8.500e+01],
       [1.460e+03, 1.780e+03, 3.150e+02, 2.900e+01, 1.180e+02],
       [1.050e+03, 2.350e+03, 1.300e+03, 1.850e+02, 1.200e+02],
       [1.100e+03, 2.952e+03, 1.852e+03, 7.000e+01, 1.400e+02],
       [1.000e+03, 1.500e+03, 5.000e+02, 1.100e+01, 8.700e+01],
       [1.250e+03, 3.250e+03, 2.000e+03, 1.350e+02, 1.250e+02],
       [1.090e+03, 1.180e+03, 9.000e+01, 5.000e+00, 7.500e+01],
       [9.700e+02, 2.350e+03, 1.380e+03, 5.000e+01, 1.300e+02],
       [7.000e+02, 2.500e+03, 1.400e+03, 2.000e+01, 1.200e+02],
       [1.000e+03, 2.466e+03, 1.466e+03, 0.000e+00, 1.120e+02],
       [1.012e+03, 1.600e+03, 5.880e+02, 1.600e+01, 9.000e+01],
       [9.000e+02, 3.200e+03, 2.300e+03, 8.600e+01, 1.600e+02],
       [1.100e+03, 2.800e+03, 1.700e+03, 3.300e+01, 1.400e+02],
       [1.350e+03, 1.750e+03, 4.000e+02, 9.000e+00, 9.200e+01],
       [1.410e+03, 2.355e+03, 9.450e+02, 2.400e+01, 1.100e+02],
       [1.250e+03, 2.000e+03, 7.500e+02, 5.300e+01, 1.200e+02],
       [1.035e+03, 2.069e+03, 1.034e+03, 2.500e+01, 1.050e+02],
       [1.200e+03, 1.580e+03, 3.800e+02, 7.000e+00, 9.000e+01],
       [7.000e+02, 2.500e+03, 1.800e+03, 1.390e+02, 1.260e+02],
       [8.000e+02, 1.200e+03, 8.000e+02, 6.000e+00, 8.000e+01],
       [1.100e+03, 2.620e+03, 1.520e+03, 1.900e+01, 9.700e+01],
       [1.450e+03, 2.550e+03, 1.150e+03, 4.200e+01, 1.100e+02],
       [1.150e+03, 2.525e+03, 1.375e+03, 1.200e+02, 1.200e+02],
       [1.000e+03, 1.400e+03, 4.000e+02, 1.000e+01, 9.500e+01],
       [1.100e+03, 2.620e+03, 1.520e+03, 4.100e+01, 1.190e+02],
       [9.000e+02, 1.800e+03, 9.000e+02, 3.000e+01, 1.100e+02],
       [9.400e+02, 1.000e+03, 6.000e+01, 1.000e+00, 9.000e+01],
       [1.450e+03, 2.850e+03, 1.400e+03, 8.500e+01, 1.330e+02],
       [9.000e+02, 1.800e+03, 9.000e+02, 2.800e+01, 8.000e+01],
       [1.500e+03, 2.620e+03, 1.120e+03, 3.900e+01, 1.150e+02],
       [1.550e+03, 2.620e+03, 1.070e+03, 2.500e+01, 1.200e+02],
       [1.550e+03, 3.450e+03, 1.900e+03, 7.400e+01, 2.150e+02],
       [2.300e+03, 3.230e+03, 9.300e+02, 8.800e+01, 1.680e+02],
       [1.850e+03, 3.456e+03, 1.606e+03, 8.000e+01, 1.550e+02],
       [1.430e+03, 2.750e+03, 1.320e+03, 8.900e+01, 1.310e+02],
       [1.430e+03, 2.750e+03, 1.320e+03, 8.900e+01, 1.310e+02],
       [1.400e+03, 2.550e+03, 1.150e+03, 6.100e+01, 1.200e+02],
       [1.650e+03, 2.800e+03, 1.150e+03, 3.300e+01, 1.400e+02],
       [1.480e+03, 2.115e+03, 6.350e+02, 1.400e+01, 1.400e+02]])

Y_stations = np.array([25.5, 26. , 27. , 54.5, 28.5, 32. , 37. , 32. , 20.5, 21.5, 24.5,
       29.8, 53. , 55. , 54. , 36. , 46. , 12. , 13.4, 15. , 37.5, 19.8,
       38. , 56. , 35. , 44.3, 48.5, 34. , 20. , 29.5, 39.8, 38. , 32. ,
       55. , 47. , 19.5, 56. , 37.9, 17.6, 37.9, 11. , 38.5, 15. , 13. ,
       10.7, 53. , 36.5, 54. , 37.9, 26.9, 44. , 41.7,  9. , 43. , 45. ,
       29. , 52. , 12. , 22.4, 49.5, 53. , 18. , 55. ,  9.5, 47. , 44. ,
       43. , 20.5, 56. , 38. , 20. , 31.5, 29.5, 34. , 12. , 48.5, 12.5,
       27. , 41. , 50.5, 14. , 37.9, 25.9,  9.5, 52. , 26. , 38. , 32.7,
       62. , 57.5, 61. , 43. , 41.3, 44.3, 38. , 34.5])

noms_attributs_stations = ['Alt_min',  'Alt_max',  'Denivele',   'Nb_pistes',
                           'Nb_jours_ouverts',  'Prix_forfait']

plt.figure()
# Ici, il n'y a pas de classes donc ma fonction represente_donnees ne convient pas.
plt.scatter(X_stations[:,2], X_stations[:,3], c= Y_stations)
plt.xlabel(noms_attributs_stations[2])
plt.ylabel(noms_attributs_stations[3])
plt.colorbar()



X_app, X_test, Y_app, Y_test = separe_donnees(X_stations, Y_stations, 0.2)

moy, ecarts = parametres_normalisation(X_app) # Seulement avec le jeu d'apprentissage
X_app = normalise(X_app, moy, ecarts)
X_test = normalise(X_test, moy, ecarts)

X_essai = [1700., 2700, 1000, 50., 104]
X_essai = normalise(X_essai, moy, ecarts)
k = 5
prix_predit = kPPV_classification(k, X_essai, X_app, Y_app)
print('On prédit un prix de', prix_predit, '€')



def erreur_qm(Y_pred, Y_test) :
    E2 = np.average((Y_pred - Y_test)**2)
    return np.sqrt(E2)

# Essais sur tout le jeu de validation
Y_pred = applique_sur_jeu_complet(k, kPPV_regression, X_test, X_app, Y_app)



print('erreur quadratique moyenne', erreur_qm(Y_test, Y_pred), '€')

# Un graphique pour mieux voir si ça marche.
plt.figure()
plt.plot(Y_test, Y_pred, '+', color = 'red', markersize = 10)
plt.plot(Y_test, Y_test, linestyle = '-')
plt.xlabel('prix réel')
plt.ylabel('prix prédit')



##########################################################################""
#########  Partie 6 :  epsilon-voisins ###########################
##########################################################################""
    
def liste_eps_voisins(epsilon, observation, X_app) :
    ''' Renvoie les indices des éléments de X_app qui sont proches
    de observation, à une distance inférieure à epsilon. '''
    liste_indices_voisins = []
    for i, x in enumerate(X_app) :
        d = d_euclid(observation, x)
        if d <= epsilon :
            liste_indices_voisins.append(i) 
    if len(liste_indices_voisins) >= 1 : # si on a trouvé au moins un voisin
            return liste_indices_voisins # on renvoie la réponse
    else :  # sinon, on cherche à nouveau en augmentant epsilon
        return liste_eps_voisins(1.5 * epsilon, observation, X_app)


def epsPPV_classification(epsilon, observation, X_app, Y_app) :
    ''' X_app sont les données d'apprentissage et Y_app les classes
    correspondantes. Pour une observation donnée, on cherche dans X_app
    les observations dont la distance est inférieur à epsilon
    classe la plus représentée. '''    

    liste_indices_voisins = liste_eps_voisins(epsilon, observation, X_app)
    
    assert(liste_indices_voisins != []), 'pas de voisin trouvé'
    #print(liste_indices_voisins)        
    
    # Étiquettes des k plus proches voisins
    Y_PPV = Y_app[liste_indices_voisins]    
    
    # vote majoritaire
    return vote_majoritaire_departage(Y_PPV)   



def epsPPV_regression(epsilon, observation, X_app, Y_app) :
    ''' X_app sont les données d'apprentissage et Y_app les classes
    correspondantes. La liste << classes >> contient toutes les classes
    possibles. Pour une observation donnée, on cherche dans X_app
    les observations dont la distance est inférieur à epsilon
    classe la plus représentée. '''    

    liste_indices_voisins = liste_eps_voisins(epsilon, observation, X_app)
   
    assert(liste_indices_voisins != []), 'pas de voisin trouvé'      

    # Étiquettes des k plus proches voisins
    Y_PPV = Y_app[liste_indices_voisins]    
    
    return np.average(Y_PPV)    

### Application aux iris
X_app, X_test, Y_app, Y_test = separe_donnees(X_iris, Y_iris, 0.2)
moy, ecarts = parametres_normalisation(X_app) # Seulement avec le jeu d'apprentissage
X_app = normalise(X_app, moy, ecarts)
X_test = normalise(X_test, moy, ecarts)    

epsilon = 0.3
Y_pred = applique_sur_jeu_complet(epsilon, epsPPV_classification, X_test, X_app, Y_app)  
print(matrice_confusion(Y_pred, Y_test, classes_iris))
    


### Application aux stations de ski
X_app, X_test, Y_app, Y_test = separe_donnees(X_stations, Y_stations, 0.2)
moy, ecarts = parametres_normalisation(X_app) # Seulement avec le jeu d'apprentissage
X_app = normalise(X_app, moy, ecarts)
X_test = normalise(X_test, moy, ecarts)

epsilon = 0.5
Y_pred = applique_sur_jeu_complet(epsilon, epsPPV_regression, X_test, X_app, Y_app)


print('erreur quadratique moyenne', erreur_qm(Y_test, Y_pred), '€')





