# -*- coding: utf-8 -*-
"""
Created on Wed Sep  3 12:08:33 2025

@author: TSI2
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, RadioButtons
from matplotlib.gridspec import GridSpec

# =========================
# Paramètres et utilitaires
# =========================
FS_MAX = 2 * 10**6     # sampling frequency in Hz (2 MHz)
DURATION_MIN = 0.05    # minimum displayed duration in s (5 periods at 100 Hz = 0.05 s)
NFFT_POW2 = 2**17      # 131072, FFT size (power of two); NOTE(review): not referenced in the visible code, FFTs use len(x)

# Conversions
def log_slider_to_freq(val_log10):
    """Convert a slider position expressed as log10(f) back to a frequency in Hz."""
    return pow(10, val_log10)

def freq_to_omega(f):
    """Return the angular frequency (rad/s) corresponding to ``f`` in Hz."""
    two_pi = 2 * np.pi
    return two_pi * f

# Génération des signaux d'entrée
def gen_input_signal(t, f0, kind="sine"):
    """Sample an input waveform of fundamental ``f0`` (Hz) at times ``t`` (s).

    kind: "sine" gives sin(2*pi*f0*t); "square" gives sign(sin(2*pi*f0*t)),
    which is +/-1 except exactly at zero crossings where it is 0. Any other
    kind returns an all-zero array shaped like ``t``.
    """
    if kind not in ("sine", "square"):
        return np.zeros_like(t)
    carrier = np.sin(2*np.pi*f0 * t)
    # Square wave without SciPy: take the sign of the sine carrier.
    return carrier if kind == "sine" else np.sign(carrier)

# Spectre amplitude (module)
def amplitude_spectrum(x, fs):
    """Return the one-sided amplitude spectrum of ``x`` sampled at ``fs`` Hz.

    A Hann window is applied to limit leakage, which matters most for the
    square wave. Amplitudes are normalised by the window's coherent gain
    (sum of the window). A sinusoid of amplitude A landing on a bin therefore
    reads approximately A, and a constant offset c reads c at 0 Hz.

    Returns
    -------
    f : ndarray
        Bin frequencies in Hz (``np.fft.rfftfreq``).
    amp : ndarray
        Amplitude of each bin.
    """
    n = len(x)
    w = np.hanning(n)
    X = np.fft.rfft(x * w, n=n)
    f = np.fft.rfftfreq(n, d=1.0/fs)
    # Window coherent-gain correction, times 2 to fold the negative
    # frequencies into the one-sided spectrum...
    amp = (2.0 / np.sum(w)) * np.abs(X)
    # ...except for bins that have no negative-frequency twin: DC, and the
    # Nyquist bin when n is even. Doubling them would overstate them by 2x.
    amp[0] /= 2.0
    if n % 2 == 0:
        amp[-1] /= 2.0
    return f, amp

# =========================
# Réponses fréquentielles H(jw)
# =========================
def H_lpf_2nd(omega, wc, Q=np.sqrt(2)/2):
    """Second-order low-pass response evaluated at s = j*omega.

    H(s) = wc^2 / (s^2 + (wc/Q) s + wc^2). The default Q = 1/sqrt(2) gives a
    Butterworth response, with |H| = Q at omega = wc.
    """
    jw = 1j * omega
    denominator = jw**2 + (wc/Q)*jw + wc**2
    return (wc**2) / denominator

def H_hpf_2nd(omega, wc, Q=np.sqrt(2)/2):
    """Second-order high-pass response evaluated at s = j*omega.

    H(s) = s^2 / (s^2 + (wc/Q) s + wc^2). The default Q = 1/sqrt(2) gives a
    Butterworth response, with |H| = Q at omega = wc.
    """
    jw = 1j * omega
    numerator = jw**2
    denominator = jw**2 + (wc/Q)*jw + wc**2
    return numerator / denominator

def H_bpf_3rd(omega, w0, Q=2.0):
    """Third-order band-pass response evaluated at s = j*omega.

    Built as a second-order band-pass
        H2(s) = (s/(Q*w0)) / ((s/w0)^2 + s/(Q*w0) + 1)
    cascaded with a first-order low-pass pole at w0:
        H1(s) = 1 / (1 + s/w0).
    At omega = w0, H2 = 1 and H1 = 1/(1+j), so the gain there is 1/sqrt(2).
    """
    jw = 1j * omega
    bw_term = jw/(Q*w0)
    normalized = jw/w0
    second_order = bw_term / (normalized**2 + bw_term + 1.0)
    first_order_pole = 1.0 / (1.0 + normalized)
    return second_order * first_order_pole

# Application du filtre en fréquence
def apply_filter_freq(x, fs, H_func, f_param, mode):
    """Filter ``x`` by multiplying its spectrum by a continuous-time response.

    The signal is transformed with a real FFT. Each bin at frequency f is
    multiplied by H(j*2*pi*f), and the product is transformed back. Because
    this is a product of DFTs, the result is a circular filtering: ``x`` is
    treated as one period of a periodic signal. Edge artefacts can therefore
    appear when ``x`` does not span an integer number of periods.

    Parameters
    ----------
    x : array_like
        Real time-domain samples.
    fs : float
        Sampling frequency in Hz.
    H_func : callable or None
        Frequency response called as ``H_func(omega, wc)``, with ``omega`` in
        rad/s. When None, the response is selected from ``mode``.
    f_param : float
        Cut-off (low/high-pass) or centre (band-pass) frequency in Hz.
    mode : str
        'lpf', 'hpf' or 'bpf'. It is only consulted when ``H_func`` is None,
        and any value other than 'lpf'/'hpf' selects the band-pass.

    Returns
    -------
    y : ndarray
        Filtered signal, same length as ``x``.
    freqs : ndarray
        rfft bin frequencies in Hz.
    H : ndarray
        Complex response evaluated at ``freqs``.
    """
    n = len(x)
    X = np.fft.rfft(x, n=n)
    freqs = np.fft.rfftfreq(n, d=1.0/fs)
    omega = 2*np.pi*freqs
    wc = 2*np.pi*f_param

    if H_func is None:
        # Fallback: choose the response from the mode string.
        if mode == 'lpf':
            H_func = H_lpf_2nd
        elif mode == 'hpf':
            H_func = H_hpf_2nd
        else:  # 'bpf'
            H_func = H_bpf_3rd

    H = H_func(omega, wc)
    Y = X * H
    y = np.fft.irfft(Y, n=n)
    return y, freqs, H

# =========================
# Figure et mise en page
# =========================
# Start from a clean slate, then lay out a 3x2 grid. Time signals go on top,
# spectra in the middle, and a taller full-width Bode plot at the bottom.
plt.close('all')
fig = plt.figure(figsize=(12, 9))
gs = GridSpec(
    nrows=3,
    ncols=2,
    height_ratios=[1.0, 1.0, 1.3],
    hspace=0.35,
    wspace=0.3,
)

ax_e_t, ax_s_t = fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])  # e(t), s(t)
ax_E_f, ax_S_f = fig.add_subplot(gs[1, 0]), fig.add_subplot(gs[1, 1])  # |E(f)|, |S(f)|
ax_bode = fig.add_subplot(gs[2, :])  # Bode gain, spanning both columns

# =========================
# Paramètres initiaux
# =========================
f0_init = 1e3          # input fundamental frequency (Hz)
fc_init = 2e3          # cut-off / centre frequency (Hz)
input_kind = "sine"    # 'sine' or 'square'
filter_mode = 'lpf'    # 'lpf', 'hpf' or 'bpf'; NOTE(review): not read elsewhere in the visible code

# Durée d'affichage : 5 périodes du fondamental mais au moins DURATION_MIN
def compute_timebase(f0):
    """Build the sample-time vector used to display a signal of fundamental ``f0``.

    The duration is the larger of 5 periods of ``f0`` and ``DURATION_MIN``.
    The sampling rate is fixed at ``FS_MAX``, and the sample count is capped
    at 200 000 to keep the interface responsive. Assumes ``f0`` > 0.

    Returns
    -------
    t : ndarray
        Sample times in seconds, starting at 0.
    fs : float
        The sampling frequency used (always ``FS_MAX``).
    """
    fs = FS_MAX
    period = 1.0 / f0
    duration = max(5*period, DURATION_MIN)
    n_samples = min(int(duration * fs), 200_000)
    return np.arange(n_samples) / fs, fs

# Initial signal chain with the starting parameters (low-pass filter).
t, fs = compute_timebase(f0_init)
e = gen_input_signal(t, f0_init, input_kind)
s, freq_axis, H_init = apply_filter_freq(e, fs, H_lpf_2nd, fc_init, 'lpf')


def _label_axes(ax, title, xlabel, ylabel, log_x=False):
    """Set the title and axis labels of ``ax`` and enable its grid.

    With ``log_x`` the x-axis is switched to a log scale first, then both
    major and minor grid lines are enabled.
    """
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if log_x:
        ax.set_xscale("log")
        ax.grid(True, which='both')
    else:
        ax.grid(True)


# Time-domain curves.
(line_e_t,) = ax_e_t.plot(t, e)
_label_axes(ax_e_t, "Entrée e(t)", "Temps (s)", "Amplitude")

(line_s_t,) = ax_s_t.plot(t, s)
_label_axes(ax_s_t, "Sortie s(t)", "Temps (s)", "Amplitude")

# Amplitude spectra, on a log frequency axis.
fE, ampE = amplitude_spectrum(e, fs)
(line_E_f,) = ax_E_f.plot(fE, ampE)
_label_axes(ax_E_f, "Spectre |E(f)|", "Fréquence (Hz)", "Amplitude", log_x=True)

fS, ampS = amplitude_spectrum(s, fs)
(line_S_f,) = ax_S_f.plot(fS, ampS)
_label_axes(ax_S_f, "Spectre |S(f)|", "Fréquence (Hz)", "Amplitude", log_x=True)

# Bode gain over 10 Hz .. 1 MHz. The gain is floored at 1e-12 (-240 dB) so
# log10 never sees 0.
f_bode = np.logspace(1, 6, 1000)
w_bode = 2*np.pi*f_bode
H_bode = np.abs(H_lpf_2nd(w_bode, 2*np.pi*fc_init))
(line_bode,) = ax_bode.plot(f_bode, 20*np.log10(np.maximum(H_bode, 1e-12)))
_label_axes(ax_bode, "Diagramme de Bode (gain)", "Fréquence (Hz)", "Gain (dB)", log_x=True)

# =========================
# Widgets (sliders + radios)
# =========================
# Reserve the bottom of the figure for the control widgets.
plt.subplots_adjust(left=0.1, right=0.95, bottom=0.19, top=0.95)


def _log_freq_slider(ax, label, f_min, f_max, f_init):
    """Create a Slider on ``ax`` whose value is log10 of a frequency in Hz."""
    return Slider(
        ax=ax,
        label=label,
        valmin=np.log10(f_min),
        valmax=np.log10(f_max),
        valinit=np.log10(f_init),
    )


# Input fundamental f0: 100 Hz .. 100 kHz.
ax_f0 = plt.axes([0.10, 0.12, 0.55, 0.03])
slider_f0 = _log_freq_slider(ax_f0, "f0 (Hz) [log]", 100.0, 100000.0, f0_init)

# Cut-off / centre frequency fc: 50 Hz .. 200 kHz.
ax_fc = plt.axes([0.10, 0.08, 0.55, 0.03])
slider_fc = _log_freq_slider(ax_fc, "fc (Hz) [log]", 50.0, 200000.0, fc_init)

# Filter-type radio buttons (left).
ax_filter = plt.axes([0.70, 0.065, 0.12, 0.11])
radio_filter = RadioButtons(
    ax_filter,
    labels=("Passe-bas 2ᵉ", "Passe-haut 2ᵉ", "Passe-bande 3ᵉ"),
    active=0,
)

# Input-waveform radio buttons (right).
ax_input = plt.axes([0.85, 0.08, 0.10, 0.08])
radio_input = RadioButtons(ax_input, labels=("Sinus", "Créneau"), active=0)

# =========================
# Callback de mise à jour
# =========================
def _selected_filter():
    """Return ``(mode, H_func)`` for the filter chosen in ``radio_filter``.

    Matching is by substring of the radio label. Any label that is neither
    low-pass nor high-pass falls back to the band-pass.
    """
    label = radio_filter.value_selected
    if "Passe-bas" in label:
        return 'lpf', H_lpf_2nd
    if "Passe-haut" in label:
        return 'hpf', H_hpf_2nd
    return 'bpf', H_bpf_3rd


def _set_padded_ylim(ax, y, margin=0.1):
    """Fit the y-limits of ``ax`` to the range of ``y`` plus a relative ``margin``.

    The limits are left unchanged when an extremum of ``y`` is non-finite or
    ``y`` is (numerically) flat. Either case would hand matplotlib a
    degenerate range.
    """
    ymin, ymax = np.min(y), np.max(y)
    if np.isfinite(ymin) and np.isfinite(ymax) and (ymax - ymin) > 1e-12:
        pad = margin * (ymax - ymin)
        ax.set_ylim(ymin - pad, ymax + pad)


def update(_=None):
    """Recompute every panel from the current widget values.

    Registered on both sliders and both radio groups. The value or label
    those widgets pass is ignored, because all state is re-read here.
    """
    # Read widgets (the sliders hold log10 of the frequency).
    f0 = log_slider_to_freq(slider_f0.val)
    fc = log_slider_to_freq(slider_fc.val)
    in_kind = "sine" if "Sinus" in radio_input.value_selected else "square"
    mode, Hfunc = _selected_filter()

    # Signal chain: time base -> input -> filtered output.
    t_new, fs_new = compute_timebase(f0)
    e_new = gen_input_signal(t_new, f0, in_kind)
    s_new = apply_filter_freq(e_new, fs_new, Hfunc, fc, mode)[0]

    # Time-domain plots e(t) and s(t).
    for line, ax, y in ((line_e_t, ax_e_t, e_new), (line_s_t, ax_s_t, s_new)):
        line.set_data(t_new, y)
        ax.set_xlim(t_new[0], t_new[-1])
        _set_padded_ylim(ax, y)

    # Amplitude spectra |E(f)| and |S(f)|. The log x-axis starts at the first
    # non-DC bin, since 0 Hz cannot be placed on a log scale.
    for line, ax, y in ((line_E_f, ax_E_f, e_new), (line_S_f, ax_S_f, s_new)):
        f, amp = amplitude_spectrum(y, fs_new)
        line.set_data(f, amp)
        ax.set_xlim(max(1.0, f[1]), f[-1])
        peak = np.max(amp)
        ax.set_ylim(0, peak * 1.1 if peak > 0 else 1)

    # Bode gain of the same response used for filtering. set_ydata does not
    # refresh the data limits, so rescale y explicitly. Otherwise curves that
    # leave the initial range (e.g. very low or very high fc) are clipped.
    gain = np.abs(Hfunc(w_bode, 2*np.pi*fc))
    line_bode.set_ydata(20*np.log10(np.maximum(gain, 1e-12)))
    ax_bode.relim()
    ax_bode.autoscale_view(scalex=False)

    fig.canvas.draw_idle()

# Connexions
# Every widget triggers the same full recomputation.
for _widget in (slider_f0, slider_fc):
    _widget.on_changed(update)
for _widget in (radio_filter, radio_input):
    _widget.on_clicked(update)
del _widget

# Render once so every panel matches the widgets' initial state, then hand
# control to the GUI event loop.
update()

plt.show()
