from math import *
import matplotlib.pyplot as plt

from scipy.integrate import odeint

## Mouvements à force centrale

## Le champ centrale (force/masse système) : n'est conservatif que si Fr
## ne dépend que de r
# -1/r^n
def FrPuiss(r,n):
    return -1/r**n

# gravité
def FrGrav(r):
    return FrPuiss(r,2)

# rappel élastique longueur nulle (ou modèle de Thomson de l'atome - k vec OM)
def FrHarm(r):
    return -r

# rappel élastique : ressort de longueur à vide 1, de raideur 1
def FrRessort(r):
    return 1-r

# 1/r cube
def FrCube(r):
    return FrPuiss(r,3)


# Equadiff sur r : puisque r²θpt = C, cte, on trouve
# rptpt -C²/r^3 = Fr

# Fonction de l'équadiff vectorisée : passage de Y=(r,rpt,th) à Ypt=(rpt, rptpt,thpt)
# Les deux premiers arguments doivent être Y et t; extra args ensuite.
def Feq(Y,t,champ,C):
    r,rpt,th=Y
    return(rpt,champ(r)+C**2/r**3,C/r**2)

# Construit la liste de tout, avec
# Fr : le champ considéré
# C : constante des aires
# Y0 : CI = (rpt0,r0,th0)
# N : nombre de points
# dt : le pas temporel
def EulerVect(Fr,C,Y0,N,dt):
    r,rpt,th=Y0
    t=0
    rPtpt=0
    tList=[0]*N
    rList=[0]*N
    rPtList=[0]*N
    thList=[0]*N
    for i in range (N):
        tList[i]=t
        rList[i]=r
        thList[i]=th
        rPt,rPtpt,thPt=Feq(Y0,0,Fr,C)
        rPtList[i]=rPt
        t+=dt
        r+=rPt*dt
        rPt+=rPtpt*dt
        th+=thPt*dt
        Y0=(r,rPt,th)
    return tList,rList,rPtList,thList

def odeVect(Fr,C,Y0,N,dt):
    tList=[i*dt for i in range(N)]
    res=odeint(Feq,Y0,tList,args=(Fr,C))
    rList=[]
    rPtList=[]
    thList=[]
    for Y in res :
        r,rPt,th=Y
        rList.append(r)
        rPtList.append(rPt)
        thList.append(th)
    return tList,rList,rPtList,thList


def trace(v0,alphDeg,r0=1,champ=FrGrav,N=2000,dt=0.01,meth=odeVect):
    alpha=alphDeg*pi/180
    C=r0*v0*sin(alpha)
    rpt0=v0*cos(alpha)
    Y0=(r0,rpt0,0)

    tList,rList,rPtList,thList=meth(champ,C,Y0,N,dt)
    plt.polar(thList,rList)
    plt.title('Trajectoire')
    return C,tList,rList,rPtList,thList

plt.subplot(221,polar=True)
C,tList,rList,rPtList,thList=trace(.9,70,dt=0.004,champ=FrGrav)

plt.subplot(222)
plt.plot(tList,rList,label='r(t)')
plt.legend()
plt.subplot(224)
plt.plot(tList,thList,label='θ(t)')
plt.legend()
plt.subplot(223)

def vitesseOrtho(C,r):
    return C/r
def vitesse(C,r,rPt):
    return sqrt(rPt**2+(C/r)**2)

vList=[]
for i in range(len(tList)):
    vList.append(vitesse(C,rList[i],rPtList[i]))
plt.plot(tList,vList,label='v(t)')
vOrthList=[]
for i in range(len(tList)):
    vOrthList.append(vitesseOrtho(C,rList[i]))
plt.plot(tList,vOrthList,'g--',label='vθ(t)')
plt.plot(tList,rPtList,'r--',label='vr(t)')

plt.legend()
plt.show()
