import numpy as np
import matplotlib.pyplot as plt
import random as rd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()

##
# Two scatter plots of the iris data set: left panel sepal length vs width
# (features 0 and 1), right panel petal length vs width (features 2 and 3);
# each point is coloured by its species (iris.target in {0, 1, 2}).
couleurs = np.array(['indigo', 'darkcyan', 'gold'])
fig, ax = plt.subplots(1, 2, figsize=(16,7))
for i in range(2):
    ax[i].scatter(iris.data[:,2*i], iris.data[:,2*i+1], color = couleurs[iris.target])
    # Legend handles: one invisible-line marker per species name.
    h = [plt.Line2D([0,0],[0,0],color=couleurs[j],marker='o',linestyle='', \
         label=l) for j, l in enumerate(iris.target_names)]
    ax[i].legend(handles=h, title='Variété :')
    ax[i].set_xlabel(iris.feature_names[2*i])
    ax[i].set_ylabel(iris.feature_names[2*i + 1])
plt.show()

## Question 1
# Inspect the containers returned by load_iris and count samples per species.
print(type(iris.data))
print(type(iris.target))
print(len(iris.data))

# Map numeric label -> species name, then tally the class sizes.
variete = {0 : "setosa", 1: "versicolor", 2 : "virginica"}
effectifs = [0, 0, 0]
for etiquette in iris.target:
    print(variete[etiquette])
    effectifs[etiquette] += 1
print(effectifs)

## Question 2
# Show the species and the four measurements for three sample rows
# (first, middle, last). Output is identical to printing each feature
# line by line.
caracteristiques = ("longueur des sépales", "largeur des sépales",
                    "longueur des pétales", "largeur des pétales")
for idx in (0, 50, -1):
    print(variete[iris.target[idx]])
    for j, nom in enumerate(caracteristiques):
        print(nom + " : ", iris.data[idx][j], "cm")


## Training/test split
# 50/50 split of the 150 samples into training and test sets (75 each);
# random_state fixes the shuffle so the whole script is reproducible.
train_data, test_data, train_target, test_target = train_test_split(iris.data, iris.target, test_size=0.5, shuffle=True, random_state=12)

## Question 3
# Size and class balance of the training set.
print(len(train_data))
comptes_train = [0, 0, 0]
for etiquette in train_target:
    comptes_train[etiquette] += 1
print(comptes_train)

# Species of the first three training samples.
for etiquette in train_target[:3]:
    print(variete[etiquette])


## Question 4
def d(u, v):
    """Return the Euclidean distance between the points u and v.

    Generalized from the original hard-coded 4 components: works for
    sequences of any common length (the 4 iris features are just the
    usual case here).
    """
    return sum((a - b) ** 2 for a, b in zip(u, v)) ** 0.5

# Example: distance between two sample flowers (rows 54 and 67).
print(d(iris.data[54],iris.data[67]))

## Question 5
def variete_plus_proche(x):
    """Return the species label (0/1/2) of the training sample nearest to x.

    The strict `0 <` test skips training points at distance exactly 0,
    so a sample that also appears in the training set never votes for
    itself.
    """
    n = len(train_data)
    dist_min = np.inf
    for k in range(n):
        # Fix: the distance was computed twice per candidate (once in the
        # test, once in the assignment); compute it once.
        dist = d(train_data[k], x)
        if 0 < dist < dist_min:
            k_min = k
            dist_min = dist
    return train_target[k_min]

## Question 6
# 1-nearest-neighbour evaluation: fraction of test points whose nearest
# training neighbour has the correct species.
n = len(test_data)
s = 0
for echantillon, attendu in zip(test_data, test_target):
    if variete_plus_proche(echantillon) == attendu:
        s += 1
print("sur", n, "tests")
print("taux de réussite : ", s/n)


## Question 7
def partition(L, x):
    """Split the index list L around its first element (the pivot).

    Returns (pivot, inf, sup) where inf holds the indices of training
    points strictly closer to x than the pivot, and sup the rest.
    """
    pivot = L[0]
    # Hoisted out of the loop: the pivot's distance is loop-invariant but
    # was recomputed for every element of L.
    d_pivot = d(train_data[pivot], x)
    inf = []
    sup = []
    for e in L[1:]:
        if d(train_data[e], x) < d_pivot:
            inf.append(e)
        else:
            sup.append(e)
    return pivot, inf, sup

def tri_rapide(L, x):
    """Quicksort the index list L by increasing distance to the point x."""
    if len(L) <= 1:
        return L
    pivot, plus_proches, plus_loin = partition(L, x)
    return tri_rapide(plus_proches, x) + [pivot] + tri_rapide(plus_loin, x)

## Question 8
def k_plus_proches_voisins(x, k):
    """Return the indices of the k training points closest to x."""
    indices_tries = tri_rapide(list(range(len(train_data))), x)
    return indices_tries[:k]

## Question 9
def occurrences(L):
    """Count how many elements of L equal 0, 1 and 2 (one species each)."""
    return [sum(1 for etiquette in L if etiquette == v) for v in range(3)]

print(occurrences([0,2,1,1,1,2,0,1,2,0]))

## Question 10
def variete_majoritaire(L):
    """Return the index of the largest value of L (first one on ties, 0 if empty)."""
    return max(range(len(L)), key=L.__getitem__, default=0)

## Question 11
def prediction_variete(x, k):
    """Predict the species of x by majority vote among its k nearest neighbours."""
    voisins = k_plus_proches_voisins(x, k)
    # NumPy fancy indexing: the labels of the k selected training points.
    votes = occurrences(train_target[voisins])
    return variete_majoritaire(votes)

# Sanity check: predicted vs. true label for one test sample.
print(prediction_variete(test_data[37], 3))
print(test_target[37])

##
def plot_k_plus_proches(x, k):
    """Plot the training set (sepal panel and petal panel) together with
    the point x — drawn as a large 'X' coloured by its predicted species —
    and its k nearest training neighbours drawn as large dots."""
    indices_voisins = k_plus_proches_voisins(x, k)
    fig, ax = plt.subplots(1, 2, figsize=(16,7))
    for i in range(2):
        # Training points, coloured by their true species.
        ax[i].scatter(train_data[:,2*i], train_data[:,2*i+1], color = couleurs[train_target])
        # Legend handles: one invisible-line marker per species name.
        h = [plt.Line2D([0,0],[0,0],color=couleurs[j],marker='o',linestyle='', \
            label=l) for j, l in enumerate(iris.target_names)]
        ax[i].legend(handles=h, title='Variété :')
        ax[i].set_xlabel(iris.feature_names[2*i])
        ax[i].set_ylabel(iris.feature_names[2*i + 1])
        # The query point, coloured by its PREDICTED species.
        ax[i].scatter(x[2*i], x[2*i + 1], s=100, marker='X', \
              color=couleurs[prediction_variete(x, k)])
        # Highlight the k nearest neighbours with bigger markers.
        for j in indices_voisins:
            v = train_data[j]
            ax[i].scatter(v[2*i], v[2*i + 1], s=100, marker='o', \
                color=couleurs[train_target[j]])
    plt.show()

# Example: show the 4 nearest neighbours of one test sample.
plot_k_plus_proches(test_data[36], 4)

## Question 12
def precision(k):
    """Return the accuracy of the k-nearest-neighbours classifier on the test set.

    Bug fix: the original returned ``s/n`` where ``n`` was a module-level
    variable set in an earlier cell (it happened to equal len(test_data),
    so the result was right by accident); the function now only depends
    on the test set itself.
    """
    total = len(test_data)
    corrects = 0
    for i in range(total):
        if prediction_variete(test_data[i], k) == test_target[i]:
            corrects += 1
    return corrects / total

## Question 13
# Scan neighbourhood sizes k = 1..50 and keep the best test accuracy.
# Fix: the original named these lists `abs` and `ord`, shadowing the
# Python builtins; also initialise k_max so it is always defined.
k_values = list(range(1, 51))
precisions = []
precision_max = 0
k_max = None  # best k found so far
for k in k_values:
    p = precision(k)
    precisions.append(p)
    if p > precision_max:
        k_max = k
        precision_max = p
print("Précision maximale : ", precision_max)
print("pour k = ", k_max)

## Question 14
# Plot accuracy as a function of k.
plt.plot(k_values, precisions)
plt.xlabel("k")
plt.ylabel("précision")
plt.show()


## Questions bonus
def moyenne(L):
    """Return the arithmetic mean of L (elementwise when L holds NumPy rows)."""
    return sum(L) / len(L)


# The 150 samples are ordered by species (50 of each), so slicing in blocks
# of 50 gives, per species, the mean of the four features (a length-4 array).
moyennes_var = []
for k in range(3):
    m = moyenne(iris.data[50*k:50*(k+1)])
    print(m)
    moyennes_var.append(m)

# Same two scatter plots as at the top of the file, with each species'
# mean point overlaid as a large '+' marker in the species colour.
fig, ax = plt.subplots(1, 2, figsize=(16,7))
for i in range(2):
    ax[i].scatter(iris.data[:,2*i], iris.data[:,2*i+1], color = couleurs[iris.target])
    # Legend handles: one invisible-line marker per species name.
    h = [plt.Line2D([0,0],[0,0],color=couleurs[j],marker='o',linestyle='', \
         label=l) for j, l in enumerate(iris.target_names)]
    ax[i].legend(handles=h, title='Variété :')
    ax[i].set_xlabel(iris.feature_names[2*i])
    ax[i].set_ylabel(iris.feature_names[2*i + 1])
    for k in range(3):
        ax[i].scatter(moyennes_var[k][2*i],moyennes_var[k][2*i+1], color=couleurs[k], marker="+",s=200)
plt.show()


