# -*- coding: utf-8 -*-
"""
Created on Wed Oct  9 19:39:28 2024

@author: arnau
"""

# -*- coding: utf-8 -*-
"""
Created on Fri Oct  4 13:43:18 2024

@author: arnau
"""

import numpy as np
import scipy.integrate as integr
import matplotlib.pyplot as plt


#%% Fonctions communes

a = 0.036
kappa = 760
y0 = 16
T= 200

def solution(x):
    '''solution exacte de l'équation de Gompertz'''
    c = np.log(y0/kappa)
    return kappa*np.exp(c*np.exp(-a*x))



#%% Résolution avec scipy. integrate

def F(y,t):
    return a*y*np.log(kappa/y)

Abs = np.linspace(0,T,1000)
Y = integr.odeint(F,y0,Abs)
plt.plot(Abs,Y,label='solution scipy')
plt.plot(Abs,solution(Abs),'-.', label='solution exacte')
plt.legend()
plt.plot()

#%% Méthode d'Euler

def Euler(y0,T,n):
    '''
    Entrées : [0,T] les extremités de l'intervalle de définition d'une solution
              y0 une condition initiale
              n le nombre de subdivision
    Sortie : une approximation de la solution de y'=F(y)
    sur [a,b] telle que y(a)=y0 avec Euler
    '''
    h=T/n  # pas 
    Y=[y0]
    for _ in range(1,n+1):
        yold=Y[-1]
        ynew = yold + h*a*yold*np.log(kappa/yold) 
        Y.append(ynew)# liste des approximation de y(a+k(b-a)/n)
    return Y

for n in [5,10,50,100]:
    Temps = [k*T/n for k in range(n+1)]
    Ord = Euler(y0,T,n)
    plt.plot(Temps,Ord,label='Euler avec n='+str(n))
#plt.plot(Abs,Y,label='solution scipy')

plt.plot(Abs,solution(Abs),'--', label='solution exacte')
plt.legend()
plt.plot()    
plt.show()







#%% Erreur

def Erreur(y0,T,n):
    Y1 = Euler(y0,T,n) # approximation par Euler
    E1 = [np.abs(solution(k*T/n)-Y1[k]) for k in range(n+1)] # erreur pour Euler
    return max(E1)

#%% erreur à T constant et pas qui tend vers $0$

Nb = range(1,100)  # nombre de subdivision
Erreur1= [Erreur(y0,T,k) for k in Nb]
plt.plot(Nb,Erreur1, color='red', label='Erreur de la méthode Euler')  # erreur de la méthode d'Euler
plt.title('Erreur lorsque le pas tend vers 0')
plt.xlabel('Nb de subdivision n')
plt.ylabel('Erreur')
plt.legend()
plt.show()

