import sys
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from PyQt5.QtWidgets import (
    QApplication,
    QMainWindow,
    QVBoxLayout,
    QWidget,
    QSlider,
    QLabel,
    QComboBox,
    QHBoxLayout,
)
from PyQt5.QtCore import Qt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.patches import FancyArrowPatch, Arc

# Global Matplotlib styling applied once at import time.
plt.rcParams.update(
    {
        # Render every text element through LaTeX.
        "text.usetex": True,
        # NOTE(review): with text.usetex enabled, Matplotlib may not honour a
        # specific font name here (usetex expects a generic family such as
        # "serif"); confirm EB Garamond is actually applied.
        "font.family": "EB Garamond",
        # Larger axis and tick labels.
        "axes.labelsize": "x-large",
        "xtick.labelsize": "x-large",
        "ytick.labelsize": "x-large",
        # White figure background.
        "figure.facecolor": "w",
    }
)
# LaTeX packages loaded for every rendered string.
mpl.rcParams["text.latex.preamble"] = r"\usepackage{xcolor}\usepackage{amsmath}"


def compute_Q(R, L, C):
    """Return the quality factor Q = (1/R) * sqrt(L/C).

    If R or C is zero, the expression would divide by zero, so an
    unbounded quality factor (``np.inf``) is returned instead.
    """
    if 0 in (R, C):
        return np.inf
    inverse_resistance = 1 / R
    return inverse_resistance * np.sqrt(L / C)


class RCFilterWidget(QWidget):
    """Interactive view of the transfer function of RC, RL and series RLC filters.

    Three Matplotlib panels are embedded in the widget:

    * the complex plane, with an arrow for H at the selected pulsation, an arc
      for its phase and the unit circle as a reference;
    * the gain G(x) = |H| against the reduced pulsation x = omega / omega_0;
    * the phase shift arg(H) (radians) against x.

    The main slider selects x in [0, 10]; the R, L and C sliders each cover
    [0.01, 10]. omega_0 is the natural pulsation of the selected circuit (see
    ``_natural_pulsation``) and every transfer function is evaluated at the
    physical pulsation omega = x * omega_0, so the plots match their x axis.
    """

    # Upper bound of the reduced pulsation x shown on the curves and slider.
    _X_MAX = 10.0
    # Number of samples along the gain/phase curves.
    _N_POINTS = 500
    # Integer slider positions are divided by this to obtain x, R, L and C.
    _SLIDER_SCALE = 100.0
    # Radius (data units) of the phase label; the arc itself has radius 0.3
    # (width/height 0.6), so the label sits just outside it.
    _ARC_LABEL_RADIUS = 0.4

    def __init__(self):
        super().__init__()
        # Initial component values.
        self.R = 1.0
        self.L = 1.0
        self.C = 1.0
        # Natural pulsation omega_0 of the selected circuit and the physical
        # pulsation grid; both are recomputed by _refresh_curves. The first
        # dropdown entry is an RC circuit, hence omega_0 = 1 / (RC).
        self.init_w = 1 / (self.R * self.C)
        self.w = np.linspace(0, self._X_MAX * self.init_w, self._N_POINTS)
        self.init_ui()

    def get_transfer_functions(self):
        """Return the available transfer functions keyed by display name.

        Each value maps a physical angular frequency omega (scalar or NumPy
        array) to the complex output/input voltage ratio for the current R, L
        and C. The RLC entries are voltage dividers of a series RLC circuit.
        Insertion order is the order shown in the dropdown.
        """
        R, L, C = self.R, self.L, self.C

        def rlc_denominator(w):
            # Total series impedance multiplied by j*omega*C.
            return 1 - w**2 * L * C + 1j * w * R * C

        return {
            "RC (C voltage)": lambda w: 1 / (1 + 1j * w * R * C),
            "RC (R voltage)": lambda w: 1j * w * R * C / (1 + 1j * w * R * C),
            "RL (R voltage)": lambda w: R / (R + 1j * w * L),
            "RL (L voltage)": lambda w: 1j * w * L / (R + 1j * w * L),
            "RLC (C voltage)": lambda w: 1 / rlc_denominator(w),
            "RLC (R voltage)": lambda w: 1j * w * R * C / rlc_denominator(w),
            # Numerator is (j*omega)^2 * L * C = -omega^2 * L * C, so the
            # inductor voltage tends to the input voltage (H -> 1) at high
            # frequency.
            "RLC (L voltage)": lambda w: -(w**2) * L * C / rlc_denominator(w),
        }

    def _natural_pulsation(self, name):
        """Return omega_0 for the circuit named *name* with the current R, L, C.

        RC: 1 / (RC); RL: R / L; series RLC: 1 / sqrt(LC).
        """
        # "RLC" must be tested before "RL" since both share that prefix.
        if name.startswith("RLC"):
            return 1 / np.sqrt(self.L * self.C)
        if name.startswith("RL"):
            return self.R / self.L
        return 1 / (self.R * self.C)

    def init_ui(self):
        """Build the figure and the Qt controls, then draw the initial state."""
        layout = QVBoxLayout(self)
        self._build_figure(layout)
        self._build_controls(layout)
        # Fill curves, markers and labels for the initially selected circuit.
        self._refresh_curves()

    def _build_figure(self, layout):
        """Create the three-panel figure, its static decorations and movable artists."""
        self.fig, self.axs = plt.subplots(1, 3, figsize=(15, 5))
        self.fig.subplots_adjust(bottom=0.18)
        self.canvas = FigureCanvas(self.fig)
        layout.addWidget(self.canvas)

        # Complex plane: unit circle, axes through the origin, fixed square view.
        ax0 = self.axs[0]
        unit_circle = plt.Circle((0, 0), 1, color="gray", fill=False, linestyle="--")
        ax0.add_artist(unit_circle)
        ax0.axhline(0, color="black", lw=1)
        ax0.axvline(0, color="black", lw=1)
        ax0.set_xlim(-1.2, 1.2)
        ax0.set_ylim(-1.2, 1.2)
        ax0.set_aspect("equal")
        ax0.set_title("Amplitude complexe")
        ax0.set_xlabel("Re")
        ax0.set_ylabel("Im")

        # Movable artists; update_plot sets their real positions.
        self.vector = FancyArrowPatch(
            (0, 0),
            (1, 0),
            color="#d62626",
            arrowstyle="->",
            mutation_scale=20,
            lw=3,
        )
        ax0.add_patch(self.vector)
        self.arc = Arc(
            (0, 0),
            0.6,
            0.6,
            angle=0,
            theta1=0,
            theta2=0,
            color="#2e8b73",
            lw=2,
        )
        ax0.add_patch(self.arc)
        self.arrow_text = ax0.text(
            1,
            0,
            r"$\underline{H}(x)$",
            color="#733bb3",
            fontsize=14,
            va="center",
            ha="left",
        )
        self.middle_text = ax0.text(
            0.5,
            0,
            r"$G(x)$",
            color="#d62626",
            fontsize=14,
            va="bottom",
            ha="right",
        )
        label_x, label_y = self._arc_label_position(0.0)
        self.arc_text = ax0.text(
            label_x,
            label_y,
            r"$\Delta\varphi_{s/e}(x)$",
            color="#2e8b73",
            fontsize=14,
            va="center",
            ha="left",
        )

        x_label = r"$x = \displaystyle\frac{\omega}{\omega_0}$"

        # Gain against reduced pulsation; data is filled by _refresh_curves.
        ax1 = self.axs[1]
        (self.mod_line,) = ax1.plot([], [], ls="-", color="#d62626")
        ax1.set_title("Gain selon pulsation")
        ax1.set_xlabel(x_label)
        ax1.set_ylabel(r"$G(x)$")
        ax1.grid(True)
        (self.mod_point,) = ax1.plot([1.0], [0.0], "ko")

        # Phase shift against reduced pulsation.
        ax2 = self.axs[2]
        (self.arg_line,) = ax2.plot([], [], ls="-", color="#2e8b73")
        ax2.set_title("Déphasage selon pulsation")
        ax2.set_xlabel(x_label)
        ax2.set_ylabel(r"$\Delta\varphi_{s/e}(x)$ (rad)")
        ax2.grid(True)
        (self.arg_point,) = ax2.plot([1.0], [0.0], "ko")

    def _build_controls(self, layout):
        """Add the pulsation slider, circuit dropdown, R/L/C sliders and Q label."""
        # QLabel does not render TeX, so the ratio is written in Unicode.
        self.slider_label = QLabel("Pulsation (ω/ω₀): {:.2f}".format(1.0))
        layout.addWidget(self.slider_label)
        self.slider = QSlider(Qt.Horizontal)
        self.slider.setMinimum(0)
        self.slider.setMaximum(round(self._X_MAX * self._SLIDER_SCALE))
        self.slider.setValue(round(self._SLIDER_SCALE))  # x = 1
        self.slider.setTickInterval(1)
        layout.addWidget(self.slider)
        # Connect after setValue so no slot fires before the UI is complete.
        self.slider.valueChanged.connect(self.update_plot)

        names = list(self.get_transfer_functions())
        self.current_tf = names[0]
        self.dropdown = QComboBox()
        self.dropdown.addItems(names)
        layout.addWidget(self.dropdown)
        self.dropdown.currentIndexChanged.connect(self.change_transfer_function)

        param_layout = QHBoxLayout()
        self.r_label, self.r_slider = self._add_param_slider(param_layout, "R", self.R)
        self.l_label, self.l_slider = self._add_param_slider(param_layout, "L", self.L)
        self.c_label, self.c_slider = self._add_param_slider(param_layout, "C", self.C)
        layout.addLayout(param_layout)

        self.q_label = QLabel(f"Q = {compute_Q(self.R, self.L, self.C):.2f}")
        layout.addWidget(self.q_label)

    def _add_param_slider(self, param_layout, symbol, value):
        """Append a 'symbol = value' label and a [0.01, 10] slider; return both."""
        label = QLabel(f"{symbol} = {value:.2f}")
        slider = QSlider(Qt.Horizontal)
        slider.setMinimum(1)
        slider.setMaximum(1000)
        # round() avoids truncation such as int(0.29 * 100) == 28.
        slider.setValue(round(value * self._SLIDER_SCALE))
        slider.valueChanged.connect(self.change_params)
        param_layout.addWidget(label)
        param_layout.addWidget(slider)
        return label, slider

    def _refresh_curves(self):
        """Recompute omega_0, the pulsation grid and the gain/phase curves.

        Reads the circuit selected in the dropdown, updates ``current_tf``,
        ``init_w`` and ``w``, moves the markers to the slider position and
        rescales the gain and phase axes.
        """
        name = self.dropdown.currentText()
        self.current_tf = name
        self.init_w = self._natural_pulsation(name)
        x = np.linspace(0, self._X_MAX, self._N_POINTS)
        self.w = x * self.init_w
        H_vals = self.get_transfer_functions()[name](self.w)
        self.mod_line.set_data(x, np.abs(H_vals))
        self.arg_line.set_data(x, np.angle(H_vals))
        # Move the markers first so autoscaling sees their final positions.
        self.update_plot(self.slider.value())
        for ax in self.axs[1:]:
            ax.relim()
            ax.autoscale_view()
        self.canvas.draw_idle()

    @staticmethod
    def _arc_span(angle_deg):
        """Return (theta1, theta2) spanning 0 -> angle_deg with theta1 <= theta2."""
        if angle_deg >= 0:
            return 0, angle_deg
        return angle_deg, 0

    @classmethod
    def _arc_label_position(cls, angle):
        """Return the (x, y) position on the bisector of a phase arc of *angle* rad."""
        half = angle / 2
        return cls._ARC_LABEL_RADIUS * np.cos(half), cls._ARC_LABEL_RADIUS * np.sin(half)

    def update_plot(self, value):
        """Slot for the pulsation slider: move every marker to x = value / 100.

        Updates the slider label, the arrow, phase arc and text labels in the
        complex plane, and the markers on the gain and phase curves, then
        schedules a canvas redraw.
        """
        puls_ratio = value / self._SLIDER_SCALE
        # Evaluate H at the physical pulsation omega = x * omega_0.
        puls = puls_ratio * self.init_w
        H_func = self.get_transfer_functions()[self.dropdown.currentText()]
        H_val = H_func(puls)
        self.slider_label.setText("Pulsation (ω/ω₀): {:.2f}".format(puls_ratio))

        tip_x, tip_y = np.real(H_val), np.imag(H_val)
        angle = np.angle(H_val)

        # Complex-plane arrow and its labels (tip and midpoint).
        self.vector.set_positions((0, 0), (tip_x, tip_y))
        self.arrow_text.set_position((tip_x, tip_y))
        self.middle_text.set_position((tip_x / 2, tip_y / 2))

        # Phase arc between the real axis and the arrow, plus its label.
        self.arc.theta1, self.arc.theta2 = self._arc_span(np.degrees(angle))
        self.arc_text.set_position(self._arc_label_position(angle))

        # Markers on the gain and phase curves.
        self.mod_point.set_data([puls_ratio], [np.abs(H_val)])
        self.arg_point.set_data([puls_ratio], [angle])
        self.canvas.draw_idle()

    def change_transfer_function(self, index):
        """Slot for the dropdown: redraw everything for the newly selected circuit.

        ``index`` comes from Qt's currentIndexChanged signal; the selection is
        read from the dropdown text instead.
        """
        self._refresh_curves()

    def change_params(self, value):
        """Slot for the R/L/C sliders: read all three values, update labels, Q and plots.

        ``value`` comes from the emitting slider; all three sliders are re-read
        so the handler can be shared.
        """
        scale = self._SLIDER_SCALE
        self.R = self.r_slider.value() / scale
        self.L = self.l_slider.value() / scale
        self.C = self.c_slider.value() / scale
        self.r_label.setText(f"R = {self.R:.2f}")
        self.l_label.setText(f"L = {self.L:.2f}")
        self.c_label.setText(f"C = {self.C:.2f}")
        self.q_label.setText(f"Q = {compute_Q(self.R, self.L, self.C):.2f}")
        self._refresh_curves()


class MainWindow(QMainWindow):
    """Application main window whose only content is an RCFilterWidget."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("RC/RLC Filter Transfer Function Animation")
        # Keep a Python-side reference and hand the same widget to Qt.
        central = RCFilterWidget()
        self.widget = central
        self.setCentralWidget(central)


if __name__ == "__main__":
    # Create the QApplication before any widget.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    # Run the Qt event loop and hand its return code to sys.exit.
    sys.exit(qt_app.exec_())
