######################################################################################
#                    Régression multiple et polynomiale                              #
######################################################################################
import random as rd
#import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

## Importation des données

# chemin ="Donnees_mystere.csv"   # Importation des données du fichier ".csv" dans la DataFrame : "dataset"
# dataset = pd.read_csv(chemin)
#         # Création de la matrice des entrées "X" et du vecteur des sorties "Y"
# Entrees = dataset[["Entrées1 = Xa", "Entrées2 = Xb"]].to_numpy(dtype=float)
# Sorties = dataset[["Sorties = Y"]].to_numpy(dtype=float)

## Affichage des données dans un graphique 3D et Création de la matrice de la régression

def Trace_donnees(X,Y,Montrer):
    ax=plt.axes(projection='3d')
    ax.scatter(X[:,0],X[:,1],Y[:,0])
    ax.set_xlabel('Entrées1 = Xa')
    ax.set_ylabel('Entrées2 = Xb')
    ax.set_zlabel('Sorties = Y')
    if Montrer==True:
        plt.show()

def Mat_donnees(entrees_brutes:np.array,L_degres:list)->np.array:
    Ndonnees,Nparam=np.shape(entrees_brutes)
    donnees=np.ones((Ndonnees,1+sum(L_degres)))
    colonne=1
    for ip in range(Nparam):
        degre=1
        for id in range(L_degres[ip]):
            donnees[:,colonne]=(entrees_brutes[:,ip])**degre
            degre+=1
            colonne+=1
    return donnees

def Init_vecteur_parametres(X,Y):
    Nbr_para=np.shape(X)[1]
    Winit=[[np.mean(Y)]]
    Winit=Winit+[[rd.random()] for i in range(1,Nbr_para)]
    return np.array(Winit)

def normalisation_donnes(donnees):
    mu = np.ones(donnees.shape[1])          #  Moyennes des entrées
    sigma = np.zeros(donnees.shape[1])      # Ecart type des entrées
    donnees_normal=np.ones(donnees.shape)   # Matrice des entrées normalisées
    for k in range(1,donnees.shape[1]):     # Pour chaque colonne sauf celle des "1"
        mu[k] = np.mean(donnees[:,k])                   # Calcul de la moyenne
        sigma[k] = np.std(donnees[:,k])                 # Calcul de l'écart type
        donnees_normal[:,k]=(donnees[:,k]-mu[k])/sigma[k]                                             # Normalisation des données
    return donnees_normal,mu,sigma


## Algorithme de la régression

def Cout(X,W,Y):
    m=np.shape(X)[0]
    return sum((X.dot(W)-Y)**2)/(2*m)



def gradient(X,W,Y):
    m=len(X)
    return (X.T.dot(X.dot(W)-Y))/m




def Regression(Entrees,Sorties,L_degre,Nbr_iter,alpha,Epsilon):
        # Construction de la matrice de données
    Xd=Mat_donnees(Entrees,L_degre)
        # Normalisation des données et initialisation du vecteur paramètres W
    Xn,mu,sigma=normalisation_donnes(Xd)
    Wn=Init_vecteur_parametres(Xn,Sorties)
    print('Parametres initiaux de la régression :')
    print(Wn)
        # Initialisation du Delta_coût et du compteur
    cout=Cout(Xn,Wn,Sorties)
    HC=[cout]                   # Historique des coûts
    CP=np.inf                   # Coût précédent (avant la 1iere itération il est infini)
    k=0                         # Compteur d'itérations
    Delta_Cout=abs(CP-HC[-1])   # Variation du coût
    print(cout,'\t',k,'\t','\t',Delta_Cout)
        # Boucle d’apprentissage par la descente de gradiant
    while (k<Nbr_iter and Delta_Cout>Epsilon):
        Wn=Wn-alpha*gradient(Xn,Wn,Sorties)               # Actualisation du vecteur paramètres
            # Calcul du nouveau Delta_coût
        CP=HC[-1]
        cout=Cout(Xn,Wn,Sorties)
        HC.append(cout)
        Delta_Cout=abs(CP-HC[-1])
            # Affichage du Coût, du compteur et du Delta_coüt
            #  toutes les 1000 itérations
        k+=1    # Incrémentation du compteur d'itération
        if k%1000==0:
            print(cout,'\t',k,'\t',Delta_Cout)
                # Le vecteur paramètres est celui des données normalisées
                # Caclul des paramètres des données avant normalisation
    W=np.zeros((len(Wn),1))
    W[0,0]=Wn[0,0]
    for i in range(1,len(Wn)):
            W[0,0]-=Wn[i,0]*mu[i]/sigma[i]                           # Pour le calcul du paramètre Omega 0
            W[i,0]=Wn[i,0]/sigma[i]                        # pour le calcul du paramètre Omage i
            # Affiche graphique des données et de la régression
    Ecriture_regression(W,k,L_degre,HC)
    graphique(Entrees,Sorties,W,L_degre)
    # return W,HC,k

## Affichage des résultats et des données dans un graphique 3D

def Ecriture_regression(W,k,L_degre,Liste_Couts):
    print("La régression a été construite en",k,"itérations")
    print("Le modèle de prédiction est :")
    equation='y = w0 + w1.xa'
    for j in range(2,L_degre[0]+1):
        equation +=' + w'+str(j)+'.xa^'+str(j)
    equation+=' + w'+str(L_degre[0]+1)+'.xb'
    for j in range(2,L_degre[1]+1):
        equation +=' + w'+str(L_degre[0]+2)+'.xb^'+str(j)
    print(equation)
    print('Les paramètres wi de cette équation sont (de w0 à w'+str(sum(L_degre))+') :')
    print(W)
    print("L'erreur moyenne est :",np.sqrt(Liste_Couts[-1])*2)

def graphique(X,Y,W,L_degres):
    Trace_donnees(X,Y,False)
    ax=plt.axes(projection='3d')
    ax.scatter(X[:,0],X[:,1],Y[:,0])
    ax.set_xlabel('Entrées1 = Xa')
    ax.set_ylabel('Entrées2 = Xb')
    ax.set_zlabel('Sorties = Y')
    f = lambda sx, sy : W[0,0] + sum([W[i+1,0]*sx**(i+1) for i in range(L_degres[0])]) + sum([W[i+1+L_degres[0],0]*sy**(i+1) for i in range(L_degres[1])])
    GX = np.linspace(10,40,31)
    GY = np.linspace(10,40,31)
    GX, GY = np.meshgrid(GX, GY)
    Z = f(GX, GY)
    ax.plot_surface(GX, GY, Z, cmap='rainbow', alpha=0.8)
    plt.show()




Entrees = np.array([[10., 34.],
       [13., 16.],
       [39., 33.],
       [24., 28.],
       [14., 40.],
       [17., 36.],
       [31., 22.],
       [15., 23.],
       [33., 33.],
       [38., 37.],
       [28., 37.],
       [36., 30.],
       [20., 25.],
       [13., 25.],
       [22., 17.],
       [20., 21.],
       [29., 29.],
       [25., 12.],
       [17., 16.],
       [22., 25.],
       [26., 30.],
       [16., 15.],
       [24., 19.],
       [20., 38.],
       [27., 28.],
       [35., 13.],
       [21., 23.],
       [33., 17.],
       [17., 28.],
       [35., 11.],
       [37., 37.],
       [11., 25.],
       [10., 33.],
       [19., 24.],
       [38., 19.],
       [24., 27.],
       [36., 27.],
       [20., 24.],
       [14., 36.],
       [27., 20.],
       [34., 29.],
       [17., 22.],
       [27., 12.],
       [34., 23.],
       [25., 37.],
       [31., 23.],
       [32., 40.],
       [39., 28.],
       [15., 21.],
       [37., 28.],
       [40., 14.],
       [24., 18.],
       [25., 13.],
       [30., 37.],
       [18., 12.],
       [39., 39.],
       [16., 36.],
       [21., 12.],
       [26., 21.],
       [28., 17.],
       [18., 38.],
       [21., 30.],
       [30., 19.],
       [20., 25.],
       [19., 12.],
       [29., 13.],
       [20., 22.],
       [32., 15.],
       [20., 28.],
       [38., 22.],
       [25., 37.],
       [18., 29.],
       [36., 24.],
       [18., 32.],
       [13., 20.],
       [24., 10.],
       [22., 10.],
       [20., 11.],
       [15., 21.],
       [34., 35.],
       [39., 36.],
       [32., 38.],
       [31., 19.],
       [37., 21.],
       [14., 10.],
       [25., 17.],
       [10., 35.],
       [37., 19.],
       [10., 15.],
       [38., 29.],
       [19., 18.],
       [27., 34.],
       [23., 34.],
       [12., 32.],
       [35., 12.],
       [34., 22.],
       [19., 37.],
       [20., 37.],
       [40., 25.],
       [28., 38.],
       [20., 37.],
       [23., 13.],
       [21., 36.],
       [29., 11.],
       [27., 14.],
       [32., 21.],
       [30., 11.],
       [21., 33.],
       [31., 17.],
       [32., 40.],
       [15., 16.],
       [15., 29.],
       [32., 34.],
       [28., 22.],
       [35., 31.],
       [18., 12.],
       [33., 34.],
       [12., 33.],
       [12., 11.],
       [16., 36.],
       [34., 15.],
       [34., 18.],
       [19., 29.],
       [33., 38.],
       [13., 14.],
       [11., 17.],
       [19., 11.],
       [33., 20.],
       [32., 37.],
       [19., 20.],
       [38., 26.],
       [18., 39.],
       [12., 36.],
       [33., 16.],
       [13., 16.],
       [24., 37.],
       [34., 27.],
       [15., 13.],
       [13., 27.],
       [32., 23.],
       [11., 32.],
       [27., 30.],
       [13., 28.],
       [31., 34.],
       [27., 25.],
       [12., 29.],
       [17., 32.],
       [16., 38.],
       [32., 19.],
       [20., 37.],
       [37., 19.],
       [28., 16.],
       [32., 12.],
       [15., 23.],
       [30., 25.],
       [29., 26.],
       [18., 30.],
       [13., 15.],
       [30., 14.],
       [19., 35.],
       [40., 27.],
       [35., 19.],
       [27., 23.],
       [24., 21.],
       [22., 18.],
       [37., 36.],
       [16., 13.],
       [14., 32.],
       [32., 30.],
       [34., 13.],
       [14., 21.],
       [35., 22.],
       [32., 38.],
       [23., 37.],
       [33., 22.],
       [14., 10.],
       [11., 24.],
       [22., 37.],
       [21., 33.],
       [26., 39.],
       [36., 37.],
       [39., 18.],
       [27., 17.],
       [22., 34.],
       [34., 36.],
       [10., 19.],
       [34., 35.],
       [16., 10.],
       [28., 15.],
       [39., 26.],
       [14., 32.],
       [19., 32.],
       [21., 34.],
       [33., 12.],
       [23., 24.],
       [15., 13.],
       [18., 39.],
       [29., 31.],
       [39., 38.],
       [31., 38.],
       [15., 24.],
       [13., 39.],
       [11., 25.],
       [12., 36.],
       [40., 22.],
       [36., 22.],
       [29., 30.],
       [16., 26.],
       [31., 38.],
       [15., 39.],
       [12., 25.],
       [23., 37.],
       [37., 10.],
       [28., 30.],
       [25., 22.],
       [12., 37.],
       [20., 26.],
       [30., 34.],
       [38., 12.],
       [32., 39.],
       [38., 31.],
       [38., 17.],
       [20., 35.],
       [19., 18.],
       [20., 13.],
       [30., 39.],
       [20., 20.],
       [10., 40.],
       [19., 29.],
       [12., 16.],
       [21., 17.],
       [39., 24.],
       [19., 15.],
       [35., 37.],
       [25., 15.],
       [16., 37.],
       [34., 39.],
       [40., 24.],
       [17., 29.],
       [40., 34.],
       [30., 18.],
       [30., 33.],
       [30., 33.],
       [39., 20.],
       [34., 13.],
       [17., 19.],
       [26., 26.],
       [36., 21.],
       [37., 16.],
       [25., 39.],
       [28., 29.],
       [10., 16.],
       [36., 26.],
       [26., 26.],
       [13., 31.],
       [28., 35.],
       [37., 14.],
       [11., 23.],
       [20., 31.],
       [22., 28.],
       [20., 10.],
       [18., 40.],
       [29., 40.],
       [20., 38.],
       [16., 25.],
       [15., 29.],
       [15., 23.],
       [30., 39.],
       [12., 37.],
       [28., 37.],
       [21., 21.],
       [22., 11.],
       [29., 20.],
       [13., 14.],
       [32., 22.],
       [30., 14.],
       [26., 40.],
       [24., 21.],
       [20., 38.],
       [40., 29.],
       [24., 23.],
       [32., 23.],
       [31., 38.],
       [18., 19.],
       [17., 21.],
       [14., 18.],
       [12., 18.],
       [26., 37.],
       [17., 22.],
       [40., 11.],
       [38., 15.],
       [18., 12.],
       [26., 11.],
       [19., 39.],
       [33., 32.],
       [33., 26.],
       [12., 17.],
       [21., 17.],
       [39., 19.],
       [23., 11.],
       [22., 34.],
       [40., 25.],
       [38., 12.],
       [26., 25.],
       [16., 14.],
       [19., 30.],
       [11., 21.],
       [12., 14.],
       [39., 11.],
       [30., 39.],
       [36., 21.],
       [22., 18.],
       [32., 30.],
       [36., 40.],
       [12., 21.],
       [16., 27.],
       [34., 10.],
       [27., 22.],
       [29., 37.],
       [13., 11.],
       [28., 21.],
       [11., 39.],
       [25., 18.],
       [12., 38.],
       [40., 24.],
       [21., 15.],
       [24., 31.],
       [11., 32.],
       [29., 38.],
       [11., 33.],
       [32., 28.],
       [29., 19.],
       [23., 30.],
       [14., 31.],
       [38., 40.],
       [10., 30.],
       [10., 22.],
       [27., 20.],
       [10., 11.],
       [33., 12.],
       [14., 32.],
       [18., 36.],
       [16., 27.],
       [31., 22.],
       [28., 14.],
       [29., 32.],
       [21., 24.],
       [28., 40.],
       [31., 26.],
       [15., 40.],
       [36., 20.],
       [32., 34.],
       [32., 27.],
       [37., 33.],
       [31., 25.],
       [29., 15.],
       [35., 40.],
       [34., 36.],
       [21., 25.],
       [23., 24.],
       [21., 33.],
       [28., 17.],
       [11., 39.],
       [32., 30.],
       [32., 33.],
       [36., 25.],
       [10., 34.],
       [30., 29.],
       [15., 22.],
       [28., 34.],
       [38., 24.],
       [28., 39.],
       [16., 39.],
       [14., 39.],
       [35., 26.],
       [10., 29.],
       [23., 39.],
       [14., 39.],
       [31., 30.],
       [15., 10.],
       [24., 33.],
       [26., 18.],
       [33., 39.],
       [13., 35.],
       [34., 11.],
       [39., 21.],
       [11., 26.],
       [12., 11.],
       [29., 29.],
       [39., 21.],
       [12., 21.],
       [25., 38.],
       [37., 20.],
       [27., 33.],
       [17., 10.],
       [16., 17.],
       [39., 33.],
       [26., 12.],
       [17., 18.],
       [29., 25.],
       [27., 39.],
       [31., 36.],
       [23., 35.],
       [37., 20.],
       [16., 32.],
       [18., 32.],
       [23., 30.],
       [29., 40.],
       [35., 25.],
       [27., 19.],
       [17., 39.],
       [11., 25.],
       [32., 26.],
       [27., 14.],
       [29., 28.],
       [30., 13.],
       [27., 18.],
       [31., 13.],
       [33., 11.],
       [40., 40.],
       [34., 27.],
       [22., 26.],
       [15., 32.],
       [12., 23.],
       [17., 28.],
       [30., 21.],
       [40., 31.],
       [33., 32.],
       [13., 38.],
       [16., 39.],
       [19., 13.],
       [36., 33.],
       [35., 34.],
       [22., 19.],
       [34., 20.],
       [17., 19.],
       [11., 37.],
       [22., 40.],
       [24., 28.],
       [28., 37.],
       [38., 23.],
       [19., 35.],
       [12., 21.],
       [14., 26.],
       [34., 34.],
       [35., 24.],
       [32., 29.],
       [32., 27.],
       [15., 14.],
       [16., 17.],
       [24., 17.],
       [15., 25.],
       [20., 23.],
       [37., 38.],
       [36., 28.],
       [24., 40.],
       [26., 26.],
       [21., 34.],
       [16., 31.],
       [12., 36.],
       [17., 25.],
       [40., 14.],
       [22., 33.],
       [29., 39.],
       [35., 17.],
       [26., 36.],
       [23., 25.],
       [24., 34.],
       [24., 37.],
       [33., 14.],
       [16., 12.],
       [39., 13.],
       [18., 14.],
       [14., 28.],
       [20., 27.],
       [10., 22.],
       [11., 25.],
       [40., 27.],
       [29., 29.],
       [19., 29.],
       [20., 34.],
       [36., 32.],
       [25., 28.],
       [15., 31.],
       [18., 15.],
       [30., 31.],
       [13., 18.],
       [36., 12.],
       [26., 32.],
       [30., 22.],
       [26., 39.],
       [35., 34.],
       [18., 40.],
       [20., 31.],
       [31., 21.],
       [15., 33.],
       [39., 23.],
       [22., 23.],
       [34., 32.],
       [37., 19.]])
Sorties=np.array([[188.],
       [198.],
       [268.],
       [226.],
       [247.],
       [242.],
       [206.],
       [214.],
       [230.],
       [269.],
       [253.],
       [240.],
       [224.],
       [211.],
       [216.],
       [230.],
       [238.],
       [202.],
       [201.],
       [223.],
       [242.],
       [196.],
       [204.],
       [249.],
       [237.],
       [199.],
       [217.],
       [204.],
       [233.],
       [212.],
       [260.],
       [185.],
       [203.],
       [237.],
       [241.],
       [226.],
       [237.],
       [219.],
       [233.],
       [219.],
       [237.],
       [222.],
       [190.],
       [225.],
       [245.],
       [223.],
       [251.],
       [259.],
       [204.],
       [251.],
       [257.],
       [214.],
       [191.],
       [252.],
       [196.],
       [284.],
       [247.],
       [196.],
       [207.],
       [206.],
       [249.],
       [241.],
       [198.],
       [230.],
       [198.],
       [197.],
       [215.],
       [204.],
       [244.],
       [249.],
       [259.],
       [239.],
       [230.],
       [237.],
       [200.],
       [195.],
       [203.],
       [207.],
       [204.],
       [249.],
       [277.],
       [253.],
       [212.],
       [230.],
       [174.],
       [209.],
       [193.],
       [223.],
       [158.],
       [254.],
       [210.],
       [250.],
       [252.],
       [207.],
       [209.],
       [221.],
       [261.],
       [260.],
       [273.],
       [252.],
       [249.],
       [193.],
       [246.],
       [192.],
       [210.],
       [211.],
       [199.],
       [239.],
       [198.],
       [259.],
       [202.],
       [231.],
       [235.],
       [220.],
       [252.],
       [209.],
       [250.],
       [217.],
       [179.],
       [239.],
       [211.],
       [210.],
       [234.],
       [247.],
       [180.],
       [172.],
       [197.],
       [217.],
       [236.],
       [229.],
       [250.],
       [251.],
       [231.],
       [197.],
       [195.],
       [259.],
       [222.],
       [190.],
       [204.],
       [227.],
       [199.],
       [224.],
       [205.],
       [239.],
       [226.],
       [214.],
       [242.],
       [259.],
       [202.],
       [248.],
       [228.],
       [193.],
       [197.],
       [222.],
       [218.],
       [222.],
       [248.],
       [181.],
       [192.],
       [253.],
       [265.],
       [214.],
       [208.],
       [223.],
       [222.],
       [265.],
       [194.],
       [224.],
       [235.],
       [205.],
       [200.],
       [233.],
       [239.],
       [248.],
       [219.],
       [178.],
       [180.],
       [246.],
       [244.],
       [244.],
       [261.],
       [254.],
       [206.],
       [249.],
       [245.],
       [166.],
       [239.],
       [199.],
       [207.],
       [256.],
       [237.],
       [244.],
       [243.],
       [193.],
       [224.],
       [203.],
       [259.],
       [228.],
       [290.],
       [241.],
       [211.],
       [230.],
       [196.],
       [217.],
       [265.],
       [233.],
       [240.],
       [234.],
       [236.],
       [243.],
       [204.],
       [243.],
       [208.],
       [221.],
       [210.],
       [221.],
       [237.],
       [241.],
       [218.],
       [256.],
       [268.],
       [224.],
       [240.],
       [210.],
       [198.],
       [254.],
       [219.],
       [207.],
       [241.],
       [190.],
       [217.],
       [266.],
       [202.],
       [246.],
       [215.],
       [251.],
       [248.],
       [271.],
       [227.],
       [289.],
       [205.],
       [226.],
       [226.],
       [253.],
       [206.],
       [223.],
       [232.],
       [217.],
       [233.],
       [259.],
       [226.],
       [161.],
       [245.],
       [216.],
       [217.],
       [240.],
       [226.],
       [193.],
       [247.],
       [244.],
       [200.],
       [259.],
       [257.],
       [248.],
       [220.],
       [238.],
       [224.],
       [246.],
       [231.],
       [245.],
       [218.],
       [191.],
       [210.],
       [192.],
       [214.],
       [188.],
       [250.],
       [224.],
       [251.],
       [281.],
       [220.],
       [223.],
       [256.],
       [220.],
       [223.],
       [204.],
       [196.],
       [252.],
       [226.],
       [241.],
       [220.],
       [192.],
       [189.],
       [261.],
       [227.],
       [232.],
       [187.],
       [219.],
       [248.],
       [195.],
       [243.],
       [280.],
       [232.],
       [214.],
       [202.],
       [238.],
       [193.],
       [172.],
       [240.],
       [239.],
       [233.],
       [213.],
       [239.],
       [272.],
       [184.],
       [221.],
       [192.],
       [212.],
       [249.],
       [188.],
       [208.],
       [215.],
       [218.],
       [222.],
       [263.],
       [205.],
       [231.],
       [202.],
       [243.],
       [210.],
       [226.],
       [201.],
       [240.],
       [233.],
       [287.],
       [185.],
       [176.],
       [204.],
       [155.],
       [194.],
       [236.],
       [244.],
       [234.],
       [204.],
       [192.],
       [226.],
       [237.],
       [256.],
       [227.],
       [253.],
       [235.],
       [242.],
       [230.],
       [249.],
       [214.],
       [209.],
       [261.],
       [244.],
       [220.],
       [231.],
       [243.],
       [207.],
       [217.],
       [241.],
       [229.],
       [239.],
       [206.],
       [232.],
       [206.],
       [249.],
       [257.],
       [255.],
       [248.],
       [236.],
       [232.],
       [180.],
       [248.],
       [235.],
       [231.],
       [199.],
       [250.],
       [202.],
       [257.],
       [232.],
       [189.],
       [248.],
       [196.],
       [175.],
       [227.],
       [250.],
       [195.],
       [249.],
       [235.],
       [245.],
       [187.],
       [209.],
       [272.],
       [194.],
       [220.],
       [215.],
       [240.],
       [252.],
       [250.],
       [237.],
       [235.],
       [232.],
       [234.],
       [256.],
       [222.],
       [207.],
       [248.],
       [185.],
       [218.],
       [206.],
       [219.],
       [200.],
       [202.],
       [202.],
       [189.],
       [290.],
       [224.],
       [239.],
       [238.],
       [188.],
       [229.],
       [211.],
       [283.],
       [230.],
       [236.],
       [254.],
       [202.],
       [244.],
       [253.],
       [221.],
       [207.],
       [207.],
       [214.],
       [249.],
       [237.],
       [236.],
       [242.],
       [246.],
       [202.],
       [211.],
       [241.],
       [225.],
       [229.],
       [217.],
       [191.],
       [201.],
       [219.],
       [210.],
       [233.],
       [274.],
       [237.],
       [255.],
       [232.],
       [247.],
       [230.],
       [218.],
       [231.],
       [256.],
       [242.],
       [238.],
       [211.],
       [239.],
       [218.],
       [243.],
       [256.],
       [196.],
       [200.],
       [224.],
       [212.],
       [213.],
       [226.],
       [165.],
       [183.],
       [275.],
       [219.],
       [232.],
       [251.],
       [242.],
       [237.],
       [225.],
       [212.],
       [228.],
       [190.],
       [213.],
       [230.],
       [205.],
       [242.],
       [248.],
       [254.],
       [242.],
       [218.],
       [237.],
       [252.],
       [224.],
       [239.],
       [233.]])