import random
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# -------------------
# Fonction générique de visualisation
# -------------------
def visualize(data, steps, title="Tri"):
    """Animate a sorting run as a bar chart.

    Args:
        data: Initial values; one bar is drawn per element.
        steps: Iterable of ``(values, highlighted_indices, comparisons)``
            tuples, as returned by the ``record_*`` functions. Each frame
            redraws every bar from ``values``, colours the bars at
            ``highlighted_indices`` red, and shows the step number and the
            comparison count.
        title: Axes title.

    The function builds the figure and then calls ``plt.show()``.
    """
    # Materialise once so the frames can be indexed and counted; this also
    # keeps any iterable (not just lists) working as before.
    steps = list(steps)
    total = len(steps)

    fig, ax = plt.subplots()
    ax.set_title(title)
    bar_rects = ax.bar(range(len(data)), data, align="edge", color="blue")

    # Guard against empty / all-zero data: max([]) raises, and a zero-width
    # range gives singular axis limits.
    ax.set_xlim(0, max(len(data), 1))
    top = 1.1 * max(data, default=0)
    ax.set_ylim(0, top if top > 0 else 1)

    text_step = ax.text(0.02, 0.95, "", transform=ax.transAxes)
    text_comp = ax.text(0.7, 0.95, "", transform=ax.transAxes)

    def update_fig(frame):
        # ``frame`` is an index into ``steps`` so the step counter can be shown.
        A, indices, comparisons = steps[frame]
        for rect, val in zip(bar_rects, A):
            rect.set_height(val)
            rect.set_color("blue")
        for idx in indices:
            bar_rects[idx].set_color("red")
        text_step.set_text(f"Étape : {frame + 1}/{total}")
        text_comp.set_text(f"Comparaisons : {comparisons}")

    if total:
        # Keep a reference: the Animation object owns the timer that drives
        # it, and would be garbage-collected (stopping playback) otherwise.
        anim = animation.FuncAnimation(
            fig, update_fig, frames=total, interval=50, repeat=False
        )
    plt.show()


# ============================================================
# TRIS SIMPLES — version étudiante
# ============================================================

def tri_bulle(A):
    """Sort the list ``A`` in place in ascending order (bubble sort).

    Each pass bubbles the largest remaining element to the end of the
    unsorted prefix, which then shrinks by one. Returns None.
    """
    for last in range(len(A) - 1, 0, -1):
        # Compare neighbours inside the still-unsorted prefix A[0..last].
        for k in range(last):
            if A[k] > A[k + 1]:
                A[k], A[k + 1] = A[k + 1], A[k]

def tri_insertion(A):
    """Sort the list ``A`` in place in ascending order (insertion sort).

    Each element is lifted out and the larger elements before it are
    shifted one slot right until its position is found. Returns None.
    """
    for pos in range(1, len(A)):
        value = A[pos]
        hole = pos
        # Shift larger predecessors right, moving the hole leftwards.
        while hole > 0 and A[hole - 1] > value:
            A[hole] = A[hole - 1]
            hole -= 1
        A[hole] = value

def tri_selection(A):
    """Sort the list ``A`` in place in ascending order (selection sort).

    For each position, the smallest element of the remaining suffix (the
    first one on ties) is swapped into place. Returns None.
    """
    size = len(A)
    for slot in range(size):
        # min() only replaces its candidate on a strict "<", so ties keep
        # the earliest index.
        smallest = min(range(slot, size), key=A.__getitem__)
        A[slot], A[smallest] = A[smallest], A[slot]


# ============================================================
# TRIS SIMPLES — version espion (avec enregistrement des étapes)
# ============================================================

def record_bubble_sort(A):
    """Bubble-sort ``A`` in place and return the recorded animation steps.

    After every comparison of neighbours ``k`` and ``k + 1`` (and the swap,
    if any), a tuple ``(copy_of_A, [k, k + 1], comparisons_so_far)`` is
    appended to the returned list.
    """
    steps = []
    count = 0
    for last in range(len(A) - 1, 0, -1):
        for k in range(last):
            count += 1
            if A[k] > A[k + 1]:
                A[k], A[k + 1] = A[k + 1], A[k]
            steps.append((A.copy(), [k, k + 1], count))
    return steps

def record_insertion_sort(A):
    """Insertion-sort ``A`` in place and return the recorded animation steps.

    Recorded tuples are ``(copy_of_A, highlighted_indices, comparisons)``:
    one ``[hole - 1, hole]`` step per shift of a larger element (the lifted
    value is not yet written back at that point), and one ``[hole]`` step
    when the lifted value is placed.
    """
    steps = []
    comparisons = 0
    for pos in range(1, len(A)):
        value = A[pos]
        hole = pos
        while hole > 0:
            comparisons += 1
            if not A[hole - 1] > value:
                break
            A[hole] = A[hole - 1]
            steps.append((A.copy(), [hole - 1, hole], comparisons))
            hole -= 1
        A[hole] = value
        steps.append((A.copy(), [hole], comparisons))
    return steps

def record_selection_sort(A):
    """Selection-sort ``A`` in place and return the recorded animation steps.

    A step ``(copy_of_A, [slot, candidate], comparisons)`` is recorded after
    every comparison, plus one ``(copy_of_A, [slot, best], comparisons)``
    after each swap that fills ``slot``.
    """
    steps = []
    comparisons = 0

    def snapshot(highlight):
        # Reads the current value of the enclosing counter at call time.
        steps.append((A.copy(), highlight, comparisons))

    size = len(A)
    for slot in range(size):
        best = slot
        for candidate in range(slot + 1, size):
            comparisons += 1
            if A[candidate] < A[best]:
                best = candidate
            snapshot([slot, candidate])
        A[slot], A[best] = A[best], A[slot]
        snapshot([slot, best])
    return steps


# ============================================================
# TRIS RÉCURSIFS — version espion
# ============================================================

def record_merge_sort(A):
    """Top-down merge-sort ``A`` in place and return the recorded steps.

    Every write into ``A`` during a merge appends
    ``(copy_of_A, [written_index], comparisons_so_far)`` to the result.
    """
    steps = []
    comparisons = 0

    def merge_range(start, mid, end):
        """Merge the sorted runs A[start:mid] and A[mid:end] back into A."""
        nonlocal comparisons
        left, right = A[start:mid], A[mid:end]
        li = ri = 0
        pos = start
        while li < len(left) and ri < len(right):
            comparisons += 1
            if left[li] <= right[ri]:
                A[pos] = left[li]
                li += 1
            else:
                A[pos] = right[ri]
                ri += 1
            steps.append((A.copy(), [pos], comparisons))
            pos += 1
        # At most one tail is non-empty; copy it back, left before right.
        for pos, value in enumerate(left[li:] + right[ri:], start=pos):
            A[pos] = value
            steps.append((A.copy(), [pos], comparisons))

    def sort_range(start, end):
        """Recursively sort A[start:end]."""
        if end - start > 1:
            mid = (start + end) // 2
            sort_range(start, mid)
            sort_range(mid, end)
            merge_range(start, mid, end)

    sort_range(0, len(A))
    return steps


def record_quick_sort(A):
    """Quicksort ``A`` in place (Lomuto partition, last element as pivot).

    Returns the recorded steps ``(copy_of_A, [i, j], comparisons)``: one for
    each swap inside the partition loop, and one for the final pivot swap.
    """
    steps = []
    comparisons = 0

    def partition(start, end):
        """Partition A[start:end] around A[end - 1]; return the pivot's index."""
        nonlocal comparisons
        last = end - 1
        pivot = A[last]
        boundary = start
        for k in range(start, last):
            comparisons += 1
            if A[k] <= pivot:
                A[boundary], A[k] = A[k], A[boundary]
                steps.append((A.copy(), [boundary, k], comparisons))
                boundary += 1
        A[boundary], A[last] = A[last], A[boundary]
        steps.append((A.copy(), [boundary, last], comparisons))
        return boundary

    def sort_range(start, end):
        """Recursively sort A[start:end]; ranges shorter than 2 are sorted."""
        if end - start >= 2:
            p = partition(start, end)
            sort_range(start, p)
            sort_range(p + 1, end)

    sort_range(0, len(A))
    return steps


# ============================================================
# TEST
# ============================================================

if __name__ == "__main__":
    N = 40
    data = [random.randint(1, 100) for _ in range(N)]

    # Each recorder sorts its own copy of the same random data, so every
    # algorithm is shown on identical input, one window after another.
    demos = [
        ("Tri par sélection", record_selection_sort),
        ("Tri à bulles", record_bubble_sort),
        ("Tri par insertion", record_insertion_sort),
        ("Tri fusion", record_merge_sort),
        ("Tri rapide", record_quick_sort),
    ]
    for demo_title, recorder in demos:
        visualize(data.copy(), recorder(data.copy()), title=demo_title)
