import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

### Fermer toutes les figures avant de commencer
plt.close("all")

### Couleurs des courbes
colorlist = ['seagreen','firebrick','steelblue','orangered','darkorchid']


"""
RÃ‰SOLUTION PHYSIQUE DE L'Ã‰QUATION DE LA CHALEUR
"""

### ParamÃ¨tres Ã  modifier en classe
Dth = .01    # diffusivitÃ©
DeltaT = 10 # Ã©cart de tempÃ©rature


### Parametres physiques et numeriques
L_time=10       # durÃ©e totale
L_space=1

N_time = 20000  # 20000
N_space = 100   # 100

# Increment spatial et temporel
dx=L_space/(N_space-1)
x = np.array([dx*i for i in range(N_space)])

dt=L_time/(N_time-1)


def laplacien(T,dx,Tinit,Tlast):
    n = len(T)
    lapl = np.zeros_like(T)
    for i in range(n):
        if i!=0 and i!=n-1:
            lapl[i] = (T[i+1] - 2*T[i] + T[i-1])/dx**2
        lapl[0] = (T[1] - 2*T[0] + Tinit)/dx**2
        lapl[n-1] = (Tlast - 2*T[n-1] + T[n-2])/dx**2
    return lapl

### Fonction qui fait la rÃ©solution
def resolution(Dth, DeltaT):
    print("Le temps caracteristique de diffusion est tdif = "+str(round(L_space**2/Dth))+" min")    
    
    print('StabilitÃ© de la rÃ©solution :')
    print('\t D dt =', Dth*L_time/N_time)
    print('\t .5 dx**2=', .5*(L_space/N_space)**2)
    
    if Dth*L_time/N_time > .5*(L_space/N_space)**2:
        print('RÃ©solution instable : il faut diminuer le pas de temps')
        return
    else:
        print('RÃ©solution stable, je fais les calculs')
        
    # Ã‰tat initial
    Ti = np.array([20 for i in range(N_space)])
    Tg = Ti[0] + DeltaT
    Td = Ti[-1]
    T  = [Ti]
    
    # Initialisation de la couleur des courbes    
    img = 0
    
    plt.figure(figsize=(14,8))

    for i in range(N_time):
        delta_T = laplacien(Ti,dx,Tg,Td)
        Ti = Ti + dt * Dth * delta_T
#        for j in range(N_space):
#            Ti[j] = Ti[j] + dt*Dth*delta_T[j]
        T.append(Ti)
    
    # On ne trace que quelques instants
        if i==10 or i== 200 or i ==2000 or i==5000 or i==N_time-1 :
            plt.plot(x,T[i],label=r'$t$='+str(round(i*dt,3)),color=colorlist[img])
            img += 1
    
    plt.grid(True)
    plt.xlabel(r'$x$ (m)')
    plt.ylabel(r'Temp\'{e}rature (\si{\celsius})')
    plt.legend(fontsize='small')
    plt.ylim(Td-2,Tg+2)    

    return T


T = resolution(Dth, DeltaT)

