import matplotlib.pyplot as plt
import numpy as np

def inittab():  # Creation d'un tableau vierge avec NxN éléments et conditions aux limites
    tab = []
    for i in range(N):  # Création  d'un tableau avec N éléments nuls
        tab.append(0.)
    # Ajout des conditions initiales
    tab[a]  = 100.
    tab[b]  = -100.
    return tab

def iteration(tab): #Effectue une itéraion et renvoie le nouveau tableau
    Vnew = inittab()
    for i in range(1,N-1):
        if (i==a or i==b):
            pass
        else:
            Vnew[i] = (tab[i-1]+tab[i+1])/2+(tab[i+1]-tab[i-1])/(4*r[i])
    return Vnew


def converge():
    tab = inittab()
    ecart = 200
    Npas = 0        # nombre d'itérations
    while ecart > eps:      # Tant que l'écart entre le nouveau potentiel et l'ancien est plus grand la tolérance, itération
        tab2 = iteration(tab)
        diff = [abs(tab2[i]-tab[i]) for i in range(len(tab))]
        ecart = max(diff)
        tab = tab2
        Npas += 1       # On incrémente aussi le nombre de pas
    print('Npas =', Npas)
    return tab

def matrice(V):
    M = np.zeros([2*C+1, 2*C+1])
    for i in range(0,2*C+1):  #(-C,C):
        for j in range(0,2*C+1):
            M[i,j] = V[int(np.sqrt((i-C)**2 + (j-C)**2))]
    return M

# Initialisation
N = 101     # dimension du tableau
a = 15
b = 35
eps = 0.01  # ecart toléré
Npas = 0    # nombre d'itérations
dr = 1
r = [i*dr for i in range(N)]
C = 70     # coordonnées en x et y du centre de l'image
V = inittab()   # premiere distribution

V2 = converge()
M2= matrice(V2)

# Affichage
fig = plt.figure(1)
plt.plot(V2)
ax = fig.add_subplot(111)
ax.set_xlabel('rayon')
ax.set_ylabel('potentiel')

plt.figure(2)
plt.plot(M2[0])

fig = plt.figure(3)
ax = fig.add_subplot(1,1,1)
ax.set_aspect('equal')

theta = np.linspace(0, 6.3, 70)
x1, y1 = C+a*np.cos(theta), C+a*np.sin(theta)
x2, y2 = C+b*np.cos(theta), C+b*np.sin(theta)
plt.plot(x1, y1)
plt.plot(x2, y2)

plt.imshow(M2, interpolation='nearest', cmap=plt.cm.seismic) #RdBu)    #hot)  #ocean)
plt.colorbar()
plt.show()