import numpy as np
import matplotlib.pyplot as plt
import random

# Load the sample set from 'facile.csv' as comma-separated numeric values.
# The code below reads column 0 as x and column 1 as y; tracer() also reads
# column 2 as the label of each row.
data_facile = np.genfromtxt('facile.csv', delimiter=',')
# Second data set, currently disabled:
#data_moyen = np.genfromtxt('moyen.csv', delimiter=',')

def tracer(data, k):
    """Plot 2-D points in a new figure, one series per label.

    Parameters:
        data: sequence of rows, each indexable with the x coordinate at
            index 0, the y coordinate at index 1 and the label at index 2.
        k: number of labels. One series is drawn for every label in
            ``range(k)``, even when no row carries it; rows whose label
            is outside that range are not drawn.

    The figure is created but not displayed here; nothing is returned.
    """
    plt.figure()
    # A distinct loop name keeps the parameter ``k`` intact (the loop used
    # to rebind ``k`` itself, shadowing the argument).
    for label in range(k):
        # Filter once per label, then split the coordinates.
        members = [row for row in data if row[2] == label]
        x = [row[0] for row in members]
        y = [row[1] for row in members]
        plt.plot(x, y, '.', label=str(label))
    plt.legend(loc='right')

# Plot the loaded data using its own label column (labels 0 to 9).
tracer(data_facile, 10)
#tracer(data_moyen, 20)

def barycentre(l):
    """Return the barycentre ``(mean_x, mean_y)`` of a non-empty point list.

    Each point is indexable, with x at index 0 and y at index 1; any
    further components are ignored. An empty ``l`` fails the assertion.
    """
    count = len(l)
    assert count > 0

    total_x = total_y = 0
    for point in l:
        # Accumulate left to right so the rounding order stays fixed.
        total_x, total_y = total_x + point[0], total_y + point[1]
    return (total_x / count, total_y / count)

def dist2(a, b):
    """Return the squared Euclidean distance between 2-D points a and b.

    Only indices 0 and 1 of each point are read.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return dx ** 2 + dy ** 2

def dist(a, b):
    """Return the Euclidean distance between 2-D points a and b."""
    squared = dist2(a, b)
    return squared ** 0.5

def plus_proche(a, l):
    """Return the index in ``l`` of the point nearest to ``a``.

    ``l`` must be non-empty (checked by assertion). When several points
    are equally near, the smallest index is returned.
    """
    count = len(l)
    assert count > 0

    # min() keeps the first item whose key is strictly smallest, which is
    # the same tie-breaking as a scan that only updates on "<".
    return min(range(count), key=lambda index: dist(a, l[index]))

def k_moyennes(k, l, iterations=20):
    """Cluster 2-D points into ``k`` groups with the k-means iteration.

    The initial centres are ``k`` points drawn at random from ``l``. Each
    round assigns every point to its nearest centre, then moves each
    centre to the barycentre of the points assigned to it.

    Parameters:
        k: number of clusters; must be at least 2 (checked by assertion).
        l: iterable of points, each indexable with x at index 0 and y at
            index 1.
        iterations: number of assignment/update rounds. Defaults to 20,
            the value that used to be hard-coded. Must be at least 1.

    Returns:
        A pair ``(C, S)``. ``C`` is the list of the ``k`` centres after
        the last update; ``S[i]`` is the list of points that were assigned
        to centre ``i`` during the last round.

    Raises:
        ValueError: if ``iterations`` is below 1, or (from
            ``random.sample``) if ``k`` exceeds the number of points.
    """
    assert(k >= 2)
    if iterations < 1:
        # Without at least one round there would be no assignment to return.
        raise ValueError("iterations must be at least 1")
    C = random.sample(list(l), k)
    for _ in range(iterations):
        S = [[] for _ in range(k)]
        for x in l:
            S[plus_proche(x, C)].append(x)
        for i, groupe in enumerate(S):
            # An empty cluster has no barycentre: keep its previous centre.
            if groupe:
                C[i] = barycentre(groupe)
    return (C, S)

def formater(C, S):
    """Flatten clusters into a list of ``(x, y, cluster_index)`` tuples.

    ``S`` is a sequence of clusters, each a sequence of points with x at
    index 0 and y at index 1. The output follows cluster order, then the
    order of points within each cluster. ``C`` is not used.
    """
    return [
        (point[0], point[1], index)
        for index, cluster in enumerate(S)
        for point in cluster
    ]

# Cluster the data into k groups, plot the resulting partition, then
# display the figures created so far.
k = 8
(Cf, Sf) = k_moyennes(k, data_facile)
tracer(formater(Cf, Sf), k)

plt.show()

def score(C, S):
    """Return the sum of squared distances from each point to its centre.

    ``S[i]`` holds the points of cluster ``i`` and ``C[i]`` is that
    cluster's centre.
    """
    total = 0
    for index, cluster in enumerate(S):
        for point in cluster:
            total += dist2(point, C[index])
    return total
    
def k_moyennes_v2(k, l, iterations=20):
    """Run k-means like ``k_moyennes`` and print the score after each round.

    The initial centres are ``k`` points drawn at random from ``l``. Each
    round assigns every point to its nearest centre, moves each centre to
    the barycentre of its points, then prints ``score(C, S)``.

    Parameters:
        k: number of clusters; must be at least 2 (checked by assertion).
        l: iterable of points, each indexable with x at index 0 and y at
            index 1.
        iterations: number of rounds. Defaults to 20, the value that used
            to be hard-coded. Must be at least 1.

    Returns:
        A pair ``(C, S)``: the centres after the last update and, for each
        centre, the list of points assigned to it during the last round.

    Raises:
        ValueError: if ``iterations`` is below 1, or (from
            ``random.sample``) if ``k`` exceeds the number of points.
    """
    assert(k >= 2)
    if iterations < 1:
        # Without at least one round there would be no assignment to return.
        raise ValueError("iterations must be at least 1")
    C = random.sample(list(l), k)
    for _ in range(iterations):
        S = [[] for _ in range(k)]
        for x in l:
            S[plus_proche(x, C)].append(x)
        for i, groupe in enumerate(S):
            # An empty cluster has no barycentre: keep its previous centre.
            if groupe:
                C[i] = barycentre(groupe)
        print(score(C, S))
    return (C, S)

# Run the variant that prints the score after each round; Cf and Sf are
# rebound to its result.
(Cf, Sf) = k_moyennes_v2(k, data_facile)

def trace_q(k, l, iterations=20):
    """Run k-means and record the score after every round.

    The initial centres are ``k`` points drawn at random from ``l``. Each
    round assigns every point to its nearest centre, moves each centre to
    the barycentre of its points, then records ``score(C, S)``. Nothing is
    plotted here; the two returned lists are meant to be plotted by the
    caller.

    Parameters:
        k: number of clusters; must be at least 2 (checked by assertion).
        l: iterable of points, each indexable with x at index 0 and y at
            index 1.
        iterations: number of rounds. Defaults to 20, the value that used
            to be hard-coded.

    Returns:
        A pair ``(X, Q)`` of equal-length lists: ``X`` holds the round
        indices ``0 .. iterations - 1`` and ``Q[t]`` is the score recorded
        at round ``t``.
    """
    assert(k >= 2)
    C = random.sample(list(l), k)
    X = list(range(iterations))
    Q = []
    for _ in X:
        S = [[] for _ in range(k)]
        for x in l:
            S[plus_proche(x, C)].append(x)
        for i, groupe in enumerate(S):
            # An empty cluster has no barycentre: keep its previous centre.
            if groupe:
                C[i] = barycentre(groupe)
        Q.append(score(C, S))
    return X, Q
    
# Plot the score against the round index for a fresh random start.
X, Q = trace_q(k, data_facile)
plt.plot(X, Q)
plt.show()

def k_moyennes_v3(k, l, epsilon):
    """Run k-means until the score changes by at most ``epsilon``.

    The initial centres are ``k`` points drawn at random from ``l``. Each
    round assigns every point to its nearest centre, moves each centre to
    the barycentre of its points, and computes ``score(C, S)``. The loop
    stops as soon as two consecutive scores differ by at most ``epsilon``.
    The first score is compared against infinity, so at least one round is
    always run.

    Parameters:
        k: number of clusters; must be at least 2 (checked by assertion).
        l: iterable of points, each indexable with x at index 0 and y at
            index 1.
        epsilon: non-negative stopping threshold on the score variation.

    Returns:
        A pair ``(C, S)``: the centres after the last update and, for each
        centre, the list of points assigned to it during the last round.

    Raises:
        ValueError: if ``epsilon`` is negative or NaN, or (from
            ``random.sample``) if ``k`` exceeds the number of points.
    """
    assert(k >= 2)
    # abs(...) is never below zero, so a negative (or NaN) threshold could
    # never be met and the loop below would never terminate.
    if not epsilon >= 0:
        raise ValueError("epsilon must be non-negative")
    C = random.sample(list(l), k)
    Q = float('inf')
    while True:
        S = [[] for _ in range(k)]
        for x in l:
            S[plus_proche(x, C)].append(x)
        for i, groupe in enumerate(S):
            # An empty cluster has no barycentre: keep its previous centre.
            if groupe:
                C[i] = barycentre(groupe)
        prev_Q, Q = Q, score(C, S)
        if abs(prev_Q - Q) <= epsilon:
            return (C, S)

# Cluster again, this time stopping once the score changes by at most 0.01
# between two consecutive rounds, then plot and display the partition.
(Cf, Sf) = k_moyennes_v3(k, data_facile, 0.01)
tracer(formater(Cf, Sf), k)
plt.show()
