import numpy as np
import matplotlib.pyplot as plt

from cycler import cycler

# Default property cycle for all axes: the i-th line style is paired with the
# i-th colour (adding two cyclers of equal length combines them element-wise).
style_cycler = cycler(
    linestyle=[
        "-",
        "--",
        "-.",
        ":",
        (0, (3, 5, 1, 5, 1, 5)),  # custom dash pattern: (offset, on/off lengths)
    ]
) + cycler(
    color=["teal", "firebrick", "forestgreen", "darkviolet", "darkorange"]
)
plt.rc("axes", prop_cycle=style_cycler)



# Sample times shared by every plot below: from -0.5 up to (but excluding) 2.5
# in steps of 0.01.
T = np.arange(-0.5, 2.5, 0.01)

#%%

def sol(K: float, x0: float, t0: float, t: float) -> float:
    """Evaluate the piecewise solution x(t) at a single time ``t``.

    The value is constant at ``x0`` before ``t0``, then follows
    ``(sqrt(x0) - K*(t - t0)/2)**2`` (an expression satisfying
    ``dx/dt = -K*sqrt(x)``) until ``tv = t0 + 2*sqrt(x0)/K``, the time at
    which that expression reaches zero, and stays at zero afterwards.

    Args:
        K: Rate constant of the decay; it divides in the computation of ``tv``.
        x0: Value held before ``t0``; its square root is taken.
        t0: Time at which the decay starts.
        t: Scalar time at which the solution is evaluated.

    Returns:
        The value of the solution at ``t``, always as a float so that every
        branch yields the same type (this matters for ``np.vectorize``, which
        otherwise infers its output dtype from the first value returned).
    """
    # Time at which the decaying branch hits zero.
    tv = t0 + 2 * np.sqrt(x0) / K
    if t < t0:
        return float(x0)
    if t < tv:
        return float((np.sqrt(x0) - K * (t - t0) / 2) ** 2)
    return 0.0


# Element-wise version of ``sol``. The output dtype is fixed explicitly:
# without ``otypes`` numpy infers it from the first call, so an integer ``x0``
# returned for an early sample would truncate every later value to an integer.
sol_v = np.vectorize(sol, otypes=[float])
#% variation of x0 (K=2 and t0=0 held fixed)
plt.clf()
plt.ylim([0, 3.5])
for x0 in (1, 2, 3):
    # Evaluate the scalar solution at every sample time; keep T's dtype so the
    # curve is a float array just like a copy of T would be.
    y = np.array([sol(2, x0, 0, t) for t in T], dtype=T.dtype)
    plt.plot(T, y, label=f"x0={x0}")
plt.legend()
plt.show()
#% variation of t0 (K=2 and x0=3 held fixed)
plt.clf()
plt.ylim([0, 3.5])
for t0 in (0, 0.5, 1):
    # Evaluate the scalar solution at every sample time; keep T's dtype so the
    # curve is a float array just like a copy of T would be.
    y = np.array([sol(2, 3, t0, t) for t in T], dtype=T.dtype)
    plt.plot(T, y, label=f"t0={t0}")
plt.legend()
plt.show()