# -*- coding: utf-8 -*-
"""
Created on Sun Jan 29 20:36:54 2023

@author: benja
"""



import matplotlib.pyplot as plt

import numpy as np




""" Tracer la propagation d'onde dans une chaine de ressorts"""


# Donnees du probleme
k = 3      # raideur
a = 1      # distance interatomique au repos
m = 1      # masse atome

# Discretisation spatiale
n_x = 1001                # nb d'atomes
# L = a*(n_x-1)             # longueur totale au repos
L_x = [j*a for j in range(n_x)]    # liste des positions au repos des atomes

# Discretisation temporelle
dt = 0.01              # pas de temps << 2pi * sqrt(m/k)
n_t = 20000            # nb de pas de temps



# Initialisations de liste de deplacement au repos
L_xi = [0]*n_x             # liste des deplacements xi
L_xipoint = [0]*n_x        # liste des vitesse d(xi)/dt
# Perturbation initiale
ampl = 0.1*a               # perturbation max
sigma = 4                  # ecart-type (30:non dispersion, 10:amplitude varie, 4:forte dispersion, 3:artefacts)
x0 = 500                   # centre de la perturbation
for i in range(n_x):
    L_xi[i] = ampl*np.exp(-(i-x0)**2/(2*sigma**2))
   


# pour graphe entre x1 et x2
x1 = 0
x2 = n_x

# Resolution numerique
plt.figure()
plt.plot(L_x[x1:x2],L_xi[x1:x2])      # tracer xi(x) initial
for i in range(n_t):                  # pour chaque instant :
    L_xi_temp = [0]                      # xi(x) temporaire pour ce t, initialisee a 0
    L_xipoint_temp = [0]                 # dxi/dt(x) temporaire pour ce t, initialisee a 0
    for j in range(1,n_x-1):          # pour chaque position suivante :
        L_xi_temp.append(L_xi[j]+(L_xipoint[j])*dt)                                # Euler sur xi
        L_xipoint_temp.append(L_xipoint[j]+(L_xi[j+1]+L_xi[j-1]-2*L_xi[j])*dt*k/m) # Euler sur d(xi)/dt
    L_xi_temp.append(0)               # extremite a 0
    L_xipoint_temp.append(0)          # extremite a 0
    L_xi[:] = L_xi_temp[:]            # copie complete dans la liste L_xi
    L_xipoint[:] = L_xipoint_temp[:]  # copie complete dans la liste L_xipoint
    if ((i+1)%5000 == 0):             # trace tous les xxxx dt
        plt.plot(L_x[x1:x2],L_xi[x1:x2])   
plt.xlabel('Position x (avec a = 1)'); plt.ylabel('perturbation xi (avec a = 1)')
plt.title('Perturbation fine (sigma = 4a) : propagation dispersive')
plt.show()


# figures pour sigma = 40 et 4