import numpy as np
import matplotlib.pyplot as plt

def S(n, x):
    """Return the symmetric partial sum of x**(k**2) for k from -n to n.

    The terms for k and -k are equal, so the sum is computed as
    1 + 2 * (x**1 + x**4 + ... + x**(n**2)).

    n -- largest index included in the sum (n <= 0 gives 1).
    x -- base of the powers; a number or a NumPy array.
    """
    total = 1  # the k = 0 term
    # The upper bound is n + 1 so that the k = n term is part of the sum.
    for k in range(1, n + 1):
        total += 2 * x**(k**2)
    return total

# Find the smallest n >= 0 for which (1/2)**(n-1) is at most 1e-5, then
# print S at x = 1/2 for that n.
# NOTE(review): the threshold presumably bounds the truncation error of S
# at x = 1/2 -- confirm against the problem statement.
n = 0
while True:
    if (1/2)**(n-1) <= 1e-5:
        break
    n += 1
print(S(n, 1/2))

def g(x):
    """Return sqrt(pi / (1 - x)); x may be a number or a NumPy array."""
    denominator = 1 - x
    return np.sqrt(np.pi / denominator)
    
# Plot g(x) and the partial sum S(20, x) on the same axes over [-1, 1).
# endpoint=False leaves out x = 1, where g divides by zero.
X = np.linspace(-1, 1, 1000, endpoint=False)
n = 20
# g and S only use arithmetic that broadcasts, so they are applied to the
# whole sample array at once.
Y = g(X)
Z = S(n, X)

plt.plot(X, Y)
plt.plot(X, Z)
plt.show()

def r2(n):
    """Return the number of integer pairs (h, k) with h**2 + k**2 == n.

    Pairs are ordered and signed, so (1, 2), (2, 1) and (-1, 2) are counted
    separately. r2(0) is 1 (only the pair (0, 0)); a negative n gives 0.
    """
    if n == 0:
        return 1
    count = 0
    k = 0
    # Scan non-negative k, h with k*k + h*h <= n. Bounding by the squares
    # rather than by n itself keeps k = 1 reachable when n == 1 and avoids
    # testing pairs that are already too large.
    while k * k <= n:
        h = 0
        while k * k + h * h <= n:
            if k * k + h * h == n:
                # With a zero coordinate only one sign can flip: 2 pairs.
                # Otherwise (h, k), (-h, k), (h, -k), (-h, -k): 4 pairs.
                count += 2 if (k == 0 or h == 0) else 4
            h += 1
        k += 1
    return count

# Print r2(n) for n = 0, 1, ..., 100, one value per line.
for n in range(0, 101):
    print(r2(n))