import numpy as np
import random as rd
import copy 
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import time

pos_ex=[[0, 0 ,0, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0], [0, 0, 1 ,0, 0 ,0, 0], [0, 0 ,2, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0 ,1], [0, 2 ,1, 0, 1, 0 ,2]]

##pour afficher un joli jeu 

def affiche_grille(grille):
    #pions jaunes = joueur1
    ylim,xlim = 6,7
    xoffset = 0.5
    yoffset = 0.5
    decal = 0.5
    fig,ax = plt.subplots()
    ax.set(xlim = (0.5*xoffset,xlim+1.5*xoffset), ylim = (0.5*yoffset,ylim+1.5*yoffset), aspect = "equal")
    ax.add_artist(mpatches.Rectangle((xoffset,yoffset), xlim, ylim))
    
    for i in range(0,xlim) :
        for j in range(0,ylim):
            if grille[j][i] == 0 :
                ax.add_artist(mpatches.Circle(xy = (xoffset+i+decal,ylim-(yoffset+j)+decal), radius = 0.4, color = (1,1,1)))
            if grille[j][i] == 2 :
                ax.add_artist(mpatches.Circle(xy = (xoffset+i+decal,ylim-(yoffset+j)+decal), radius = 0.4, color = (1,0,0)))
            if grille[j][i]== 1 :
                ax.add_artist(mpatches.Circle(xy = (xoffset+i+decal,ylim-(yoffset+j)+decal), radius = 0.4, color = (1,1,0)))  
    plt.show()
    return None

affiche_grille(pos_ex)

def vide():
    return [[0 for j in range(7)] for i in range(6)]
#ATTENTION ne pas faire [[0]*7]*6 
# car un changement d'un coefficient se duplique sur toute la colonne!!!

def pleine(p,c):
    '''vérifie si la colonne c est pleine'''
    return p[0][c]!=0

def coups_possibles(p):
    liste_col=[]
    for c in range(7):
        if not pleine(p,c):
            liste_col.append(c)
    return liste_col

def jouer(p,c,i):
    '''renvoie la nouvelle position de jeu si le joueur i joue la colonne c'''
    newp=copy.deepcopy(p) #ATTENTION copy ne suffit pas
    for j in range(-1,-7,-1): #on cherche le coefficient nul le plus bas possible dans la colonne c
        if p[j][c]==0:
            newp[j][c]=i
            return newp
    return None #cas où la colonne i est déjà pleine


def gagne(p,i):
    #print('la position p de gagne: ',p)
    g=[i]*4 #la liste gagnante
    
    #détection des alignements horizontaux
    for i in range(len(p)):
        for j in range(4):
            if p[i][j:j+4]==g:
                #si la grille est un tableau numpy, faire list(p[i][j:j+4])
                #car sinon le test d'égalité ne fonctionnera pas
                return True
    
    #détection des alignement verticaux
    for j in range(len(p[0])):
        
        for i in range(3):
            liste=[]
            for k in range(i,i+4):
                liste.append(p[k][j])
            if liste==g:
                return True
            
    #détection des alignement diagonaux haut->bas
    for i0 in range(3):
        for j0 in range(4): #les points de départ
            liste=[]
            for k in range(4):
                liste.append(p[i0+k][j0+k])
            if liste==g:
                return True

                                 

    #détection des alignement diagonaux bas->haut
    for i0 in range(3,6):
        for j0 in range(4): #les points de départ
            liste=[]
            for k in range(4):
                liste.append(p[i0-k][j0+k])
            if liste==g:
                return True

    return False

#création du tableau des points pour calculer h(p)
points=[[3,4,5,7,5,4,3],[4,6,8,10,8,6,4],[5,8,11,13,11,8,5],[5,8,11,13,11,8,5],[4,6,8,10,8,6,4],[3,4,5,7,5,4,3]]

def h(p):
    #cas où la position contient un alignement gagnant pour J1 ou J2
    if gagne(p,1):
        return np.inf
    if gagne(p,2):
        return -np.inf
    
    #cas où ce n'est pas une position terminale gagnante pour l'un des joueurs
    somme=0
    for i in range(len(p)):
        for j in range(len(p[0])):
            if p[i][j]!=0:
                somme+=points[i][j]*(-2*p[i][j]+3) #(-2*p[i][j]+3=1 si joueur1 =-1 si joueur 2
    return somme



##Algorithme Min-Max

def maxiJ1(p,n): #on cherche à maximiser l'heuristique
                    #concerne le joueur 1
    if n==0:
        return (h(p),None) 
    if gagne(p,1):
        return (np.inf,None) #le jeu s'arrête dans ce cas

    coups=coups_possibles(p)
    if coups==[]: #la grille est pleine, la récursivité s'arrête
        return (h(p),None)
    
    #si la partie peut se poursuivre:
    maxi=-np.inf
    c_maxi=coups[0]
    for c in coups:
        pos=jouer(p,c,1) #la position suivante si J1 joue le coup c
        valeur=miniJ2(pos,n-1)[0] #J2 va chercher à minimiser la valeur
        if valeur>=maxi:
            maxi=valeur
            c_maxi=c
    return (maxi,c_maxi)

def miniJ2(p,n): #on cherche à minimiser l'heuristique
                    #concerne le joueur 2
    if n==0:
        return (h(p),None)
    if gagne(p,2):
        return (-np.inf,None) #le jeu s'arrête dans ce cas

    coups=coups_possibles(p)
    if coups==[]:  #la grille est pleine, la récursivité s'arrête
        return (h(p),None)
    
    #si la partie peut se poursuivre:
    mini=np.inf
    c_maxi=coups[0]
    for c in coups:
        pos=jouer(p,c,2) #la position suivante si J2 joue le coup c
        valeur=maxiJ1(pos,n-1)[0] #J1 va chercher à maximiser la valeur
        if valeur<=mini:
            mini=valeur
            c_mini=c
    return (mini,c_mini)

def simulation_random():
    '''simule une partie où chaque joueur joue au hasard'''
    p=vide()
    i=1 #le joueur qui commence à jouer
    Fin=False #pour vérifier que la grille n'est pas pleine
    while  gagne(p,1)==False and gagne(p,2)==False and not Fin : 
             
        colonnes_dispos=coups_possibles(p)
        if colonnes_dispos!=[]:
            c=rd.choice(colonnes_dispos)
            p=jouer(p,c,i)
            affiche_grille(p)
            time.sleep(0.5)
            i=-i+3  #on change de joueur
        else:
            Fin=True #le jeu est fini car la grille est pleine
        
    if gagne(p,1):
        print('le joueur 1 a gagné')
        #return 1
    elif gagne(p,2):
        print('le joueur 2 a gagné')
        #return 2
    else:
        print('match nul')
        #return 0
    
def stat_random(N):#on lance N simulations de jeu
    res=[0,0,0]
    for i in range(N):
        res[simulation_random()]+=1
    return res       
        
#quand on joue au hasard, très très peu de matchs nuls, le joueur1 a un peu plus de chance de gagner que joueur 2

#-------------------------------------------------------------------------
#on fait jouer les joueurs avec l'algo minmax

def simulation_joueur1et2(n): #le joueur1 prévoit n coups d'avance
    #le joueur2 prévoit AUSSI n coups d'avance 
    p=vide()
    i=1 #le joueur qui commence à jouer
    Fin=False #pour 
    while gagne(p,1)==False and gagne(p,2)==False and not Fin:
        
        
        colonnes_dispos=coups_possibles(p)
        if colonnes_dispos!=[]:
            if i==1:
                c=maxiJ1(p,n)[1]
                
                p=jouer(p,c,i)
                
                i=2  #on change de joueur
            elif i==2: #ATTENTION à ne pas mettre seulement un if
                 c=miniJ2(p,n)[1]
                 p=jouer(p,c,i)
                 i=1
            affiche_grille(p)
            time.sleep(0.5)     
        else:
            Fin=True #le jeu est fini car tout est plein
    
    #print(np.array(pos))
    if gagne(p,1):
        print('le joueur 1 a gagné')
        #return 1
    elif gagne(p,2):
        print('le joueur 2 a gagné')
        #return 2
    else:
        print('match nul')
        #return 0

