#!/usr/bin/env python3
"""
Computation 38 -- SU(2) Dirac operator on H^2_alpha(B^2) tensor C^2 for L_comm
================================================================================
Computation 37 established that TENSOR spinor extensions on
H^2_alpha(B^2) tensor C^N decouple in the operator norm and therefore
do not close L_comm.  The structural conclusion was that a NON-TENSOR
coupling between the Bergman and spinor factors is required, with the
natural candidate being the round-S^3 Dirac operator

    D = sum_a sigma_a tensor J_a,    a = 1, 2, 3

where J_a are SU(2) left-invariant vector fields on H^2_alpha(B^2)
acting as raising/lowering operators on the Wigner-D monomial basis,
and the sigma_a tensor structure produces the genuine cross-coupling.
This computation implements that construction and tests the L_comm gap.

SU(2) STRUCTURE ON H^2_alpha(B^2).
SU(2) acts on B^2 = C^2 by the fundamental 2-dimensional representation.
The induced action on the monomial basis z_1^{m_1} z_2^{m_2} is via the
su(2) generators as differential operators:

    J_+ = z_1 d/dz_2,    J_- = z_2 d/dz_1,    J_z = (1/2)(z_1 d/dz_1 - z_2 d/dz_2)

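On the orthonormal basis e_J = z^J / ||z^J||_alpha, where
||z^J||_alpha^2 = m_1! m_2! Gamma(alpha + 3) / Gamma(m_1 + m_2 + alpha + 3):
    J_+ e_J = sqrt(m_2 (m_1 + 1)) e_{J + e_1 - e_2}
    J_- e_J = sqrt(m_1 (m_2 + 1)) e_{J - e_1 + e_2}
    J_z e_J = (1/2)(m_1 - m_2) e_J
The derivative factor m_2 (resp. m_1) times the norm ratio gives exactly
these coefficients.  Within a fixed total degree the Gamma factors cancel,
so J_+, J_-, J_z (and hence D_alpha) do not depend on alpha, and J_- = J_+^*.
Alpha enters the L_comm test only through b_holo(chi_S).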

Total bidegree |J| = m_1 + m_2 is preserved, giving the spin-|J|/2
irrep of SU(2) on each bidegree-|J| subspace.

DIRAC OPERATOR.
With J_1 = (J_+ + J_-) / 2 (Hermitian),
     J_2 = (J_+ - J_-) / (2i) (Hermitian),
     J_3 = J_z (Hermitian),

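the Dirac analog on H^2_alpha tensor C^2 is

    D_alpha = sum_a sigma_a tensor J_a    (Hermitian by construction)

The code orders the factors as np.kron(J_a, sigma_a), i.e. Bergman factor
first, to match the b_holo(chi_S) tensor I_C2 embedding below.  On each
spin-j block (j = |J|/2) the operator sigma . J has eigenvalues j and -(j + 1).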

This is a NON-TENSOR operator: it is the SUM of three tensor terms,
producing the cross-coupling that Computation 37 identified as needed.
The op-norm of the commutator [D_alpha, b(chi_S)] is invariant under
multiplication of D by a phase, so the i factor that appears in some
formulations does not change the L_comm gap test results.

L_COMM TEST.
The Walsh bridge from Computation 35 maps chi_S to b_holo(chi_S),
extended trivially to H^2_alpha tensor C^2 as b_holo(chi_S) tensor I_C2.
The L_comm gap is

    gap(S, alpha) = || [D_alpha, b_holo(chi_S) tensor I_C2] ||_op

The substrate side: ||[D_substrate, chi_S]||_op = 2 sqrt(|S|)
(Computation 23, exact).  For closure we need these to match up to
the holomorphic-Toeplitz norm scaling on the bridge image.

OUTPUT.
For D = 4, N = 5, sweep alpha and report:
  - Spectrum of D_alpha (compare to Friedrich's S^3 Dirac scaling).
  - Bridge-image norms || b_holo(chi_S) ||_op for representative S.
  - Dirac commutator norms || [D_alpha, b_holo(chi_S) tensor I] ||_op.
  - Comparison to substrate 2*sqrt(|S|).

Reports whether the non-tensor SU(2) Dirac coupling closes the gap.
"""
import math
import numpy as np
import numpy.linalg as la


# Pauli spin matrices and the 2x2 identity, stored with a complex dtype so
# that Kronecker products built from them are complex as well.
sx = np.array([[0, 1],
               [1, 0]], dtype=complex)
sy = np.array([[0, -1j],
               [1j, 0]], dtype=complex)
sz = np.diag([1, -1]).astype(complex)
I2 = np.identity(2, dtype=complex)


def kron_chain(ops):
    """Return the left-to-right Kronecker product ops[0] (x) ops[1] (x) ...

    ``ops`` must be a non-empty indexable sequence; an empty one raises
    IndexError.  A single-element sequence returns that element itself.
    """
    product = ops[0]
    for k in range(1, len(ops)):
        product = np.kron(product, ops[k])
    return product


def chi_walsh(D, S):
    """Return chi_S on (C^2)^{(x) D}: sigma_z on every site in S, identity elsewhere.

    Sites are 0..D-1 and appear as Kronecker factors in increasing order.
    """
    factors = []
    for site in range(D):
        factors.append(sz if site in S else I2)
    return kron_chain(factors)


def op_norm(M):
    """Return the spectral (operator 2-) norm of matrix M as a Python float."""
    largest_singular_value = la.norm(M, ord=2)
    return float(largest_singular_value)


def basis_indices(N):
    """List monomial multi-indices (m1, m2) with m1 + m2 <= N.

    Ordering is by m1 first, then m2, both ascending; the list index is the
    row/column index used by every operator matrix in this file.
    """
    indices = []
    for m1 in range(N + 1):
        indices.extend((m1, m2) for m2 in range(N - m1 + 1))
    return indices


def log_norm_sq(m1, m2, alpha):
    """Return log ||z_1^m1 z_2^m2||_alpha^2.

    ||z^J||^2 = m1! m2! Gamma(alpha + 3) / Gamma(m1 + m2 + alpha + 3), evaluated
    in log space with lgamma.  The constant monomial has norm 1 (log 0).
    """
    lg = math.lgamma
    numerator = lg(m1 + 1) + lg(m2 + 1) + lg(alpha + 3)
    return numerator - lg(m1 + m2 + alpha + 3)


def Tz_matrix(a, basis, alpha):
    """Truncated matrix of multiplication by z_{a+1} (index a into the multi-index).

    Column j is the image of e_J = z^J / ||z^J||_alpha for J = basis[j]:
    z_{a+1} e_J = (||z^{J+e_a}|| / ||z^J||) e_{J+e_a}.  When J + e_a is not in
    ``basis`` the image is dropped and that column stays zero.
    """
    position = {J: i for i, J in enumerate(basis)}
    size = len(basis)
    M = np.zeros((size, size), dtype=complex)
    for col, J in enumerate(basis):
        raised = list(J)
        raised[a] += 1
        target = tuple(raised)
        row = position.get(target)
        if row is None:
            continue
        half_log_ratio = 0.5 * (log_norm_sq(target[0], target[1], alpha)
                                - log_norm_sq(J[0], J[1], alpha))
        M[row, col] = math.exp(half_log_ratio)
    return M


def J_plus_matrix(basis, alpha):
    """Raising operator J_+ = z_1 d/dz_2 on the orthonormal monomial basis.

    J_+ z^J = m_2 z^{J + e_1 - e_2}; rescaling to e_J = z^J / ||z^J|| multiplies
    by ||z^{J+e_1-e_2}|| / ||z^J||.  The total degree is unchanged, so the
    alpha-dependent Gamma factors cancel and each entry equals
    sqrt(m_2 (m_1 + 1)) up to rounding.
    """
    position = {J: i for i, J in enumerate(basis)}
    size = len(basis)
    M = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m2 <= 0:
            continue
        target = (m1 + 1, m2 - 1)
        row = position.get(target)
        if row is None:
            continue
        half_log_ratio = 0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha))
        M[row, col] = m2 * math.exp(half_log_ratio)
    return M


def J_minus_matrix(basis, alpha):
    """Lowering operator J_- = z_2 d/dz_1 on the orthonormal monomial basis.

    J_- z^J = m_1 z^{J - e_1 + e_2}; rescaling to e_J = z^J / ||z^J|| multiplies
    by ||z^{J-e_1+e_2}|| / ||z^J||.  The total degree is unchanged, so the
    alpha-dependent Gamma factors cancel and each entry equals
    sqrt(m_1 (m_2 + 1)) up to rounding, making this the adjoint of J_+.
    """
    position = {J: i for i, J in enumerate(basis)}
    size = len(basis)
    M = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m1 <= 0:
            continue
        target = (m1 - 1, m2 + 1)
        row = position.get(target)
        if row is None:
            continue
        half_log_ratio = 0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha))
        M[row, col] = m1 * math.exp(half_log_ratio)
    return M


def J_z_matrix(basis, alpha):
    """Diagonal weight operator J_z = (1/2)(z_1 d/dz_1 - z_2 d/dz_2).

    Every monomial is an eigenvector with eigenvalue (m_1 - m_2) / 2.
    ``alpha`` is accepted only for signature symmetry with J_+ / J_- and is unused.
    """
    weights = [0.5 * (m1 - m2) for m1, m2 in basis]
    return np.diag(np.array(weights, dtype=complex))


def site_to_holo(D, basis, alpha, Tz):
    """Map substrate sites 0..3 to holomorphic Toeplitz monomials.

    Sites 0 and 1 map to Tz[0] and Tz[1]; sites 2 and 3 map to their squares.
    D, basis and alpha are accepted but unused.
    NOTE(review): only four sites are defined, so D > 4 is not supported.
    """
    t1, t2 = Tz[0], Tz[1]
    return {
        0: t1,
        1: t2,
        2: t1 @ t1,
        3: t2 @ t2,
    }


def bridge_holo(D, S, basis, alpha, Tz):
    """Return b_holo(chi_S): the product of the site operators for the sites in S.

    The product starts from the identity on the truncated basis and multiplies
    on the right in increasing site order, so an empty S gives the identity.
    """
    images = site_to_holo(D, basis, alpha, Tz)
    result = np.eye(len(basis), dtype=complex)
    for site in sorted(S):
        result = result @ images[site]
    return result


def _dirac_operator(basis, alpha):
    """Build D_alpha = sum_a J_a (x) sigma_a on H^2_alpha(B^2) (x) C^2.

    The Bergman factor is the left Kronecker factor, matching the
    b_holo(chi_S) (x) I_C2 embedding used for the commutators.
    J_1 = (J_+ + J_-)/2 and J_2 = (J_+ - J_-)/(2i) are Hermitian because J_- = J_+^*.
    """
    Jp = J_plus_matrix(basis, alpha)
    Jm = J_minus_matrix(basis, alpha)
    Jz = J_z_matrix(basis, alpha)
    J1 = (Jp + Jm) / 2
    J2 = (Jp - Jm) / (2j)
    return np.kron(J1, sx) + np.kron(J2, sy) + np.kron(Jz, sz)


def _commutator_sweep(D, cases, basis, alphas):
    """Evaluate the L_comm commutator for every (case, alpha) pair.

    Returns a dict keyed by (case index, alpha) whose values are
    (||[D_alpha, b_holo(chi_S) (x) I_C2]||_op, ||b_holo(chi_S)||_op).
    D_alpha and the Toeplitz generators depend only on alpha, so they are
    built once per alpha instead of once per (case, alpha).
    """
    results = {}
    for alpha in alphas:
        Tz = [Tz_matrix(0, basis, alpha), Tz_matrix(1, basis, alpha)]
        D_a = _dirac_operator(basis, alpha)
        for i, (S, _k) in enumerate(cases):
            b_S_holo = bridge_holo(D, S, basis, alpha, Tz)
            b_S = np.kron(b_S_holo, I2)
            comm = D_a @ b_S - b_S @ D_a
            results[i, alpha] = (op_norm(comm), op_norm(b_S_holo))
    return results


def _table_header(ref_label, ref_width, alphas):
    """Return a table header whose columns line up with rows from _table_row."""
    alpha_cols = "".join(f"  {f'a={alpha:5.1f}':>10}" for alpha in alphas)
    return f"  {'case':<14}  {'|S|':>3}  {ref_label:>{ref_width}}  " + alpha_cols


def _table_row(S, k, ref_value, ref_width, values):
    """Return one table row: case label, |S|, reference value, one column per alpha."""
    label = "S=" + str(set(S))
    row = f"  {label:<14}  {k:>3}  {ref_value:>{ref_width}.4f}  "
    return row + "".join(f"  {value:>10.4f}" for value in values)


def main():
    """Run Computation 38: su(2) checks, Dirac spectrum, L_comm tables and verdict."""
    D = 4
    N = 5
    basis = basis_indices(N)
    dim_Toep = len(basis)
    alphas = [0.0, 1.0, 2.0, 5.0]
    # A normalised ratio counts as matching 2*sqrt(|S|) within this relative tolerance.
    match_rtol = 0.05
    # Commutator op-norms at or below this are treated as numerically zero.
    zero_tol = 1e-10

    print("=" * 90)
    print("  Computation 38  --  SU(2) Dirac operator on H^2_alpha(B^2) tensor C^2 for L_comm")
    print("=" * 90)
    print()
    print(f"  D = {D}, N = {N}, dim H^2_alpha = {dim_Toep}, dim full = {dim_Toep * 2}")
    print()

    # ============================================
    # Step 1: build SU(2) generators and Dirac
    # ============================================
    print("  Step 1: SU(2) generators on H^2_alpha(B^2) and Dirac operator")
    print()

    for alpha in [0.0, 1.0]:
        Jp = J_plus_matrix(basis, alpha)
        Jm = J_minus_matrix(basis, alpha)
        Jz = J_z_matrix(basis, alpha)

        # Check su(2) commutation [J_+, J_-] = 2 J_z
        check_pm = (Jp @ Jm - Jm @ Jp) - 2 * Jz
        # Check [J_z, J_+] = J_+
        check_zp = (Jz @ Jp - Jp @ Jz) - Jp

        print(f"  alpha = {alpha}: su(2) algebra checks")
        print(f"    || [J_+, J_-] - 2 J_z ||  =  {op_norm(check_pm):.4e}")
        print(f"    || [J_z, J_+] - J_+   ||  =  {op_norm(check_zp):.4e}")
    print()

    # Dirac at a sample alpha.  The J_a only connect monomials of equal total
    # degree, where the Gamma(alpha) factors of the norms cancel, so D_alpha
    # itself does not depend on alpha.
    alpha = 0.0
    D_a = _dirac_operator(basis, alpha)

    # Check Hermiticity
    print(f"  Dirac operator at alpha = {alpha}:")
    print(f"    || D - D^* || = {op_norm(D_a - D_a.conj().T):.4e}  (should be 0; Hermitian)")
    print(f"    || D ||_op   = {op_norm(D_a):.4f}")

    # eigvalsh already returns real eigenvalues in ascending order.
    eigs = la.eigvalsh(D_a)
    print(f"    Eigenvalue range: [{eigs[0]:.4f}, {eigs[-1]:.4f}]")
    print(f"    First few: {[f'{e:.3f}' for e in eigs[:6]]}")
    print(f"    Expected for sigma . J on spin-j blocks (j = 0, 1/2, ..., {N}/2): j and -(j + 1)")
    print("    Friedrich S^3 Dirac eigenvalues: +/-(n + 3/2),  n = 0, 1, 2, ...")
    print()

    # ============================================
    # Step 2: L_comm test
    # ============================================
    cases = [
        ({0}, 1),
        ({1}, 1),
        ({0, 1}, 2),
        ({0, 2}, 2),
        ({1, 2}, 2),
        ({0, 1, 2}, 3),
        ({0, 2, 3}, 3),
    ]
    # Each commutator is computed once and reused by Steps 2, 3 and the verdict.
    results = _commutator_sweep(D, cases, basis, alphas)

    print("  Step 2: L_comm gap test")
    print()
    print("  Substrate side: ||[D_substrate, chi_S]||_op = 2*sqrt(|S|)  (Computation 23)")
    print()
    print("  Toeplitz side: ||[D_alpha, b_holo(chi_S) tensor I_C2]||_op")
    print()
    print(_table_header("2*sqrt(|S|)", 12, alphas))
    for i, (S, k) in enumerate(cases):
        comm_norms = [results[i, alpha][0] for alpha in alphas]
        print(_table_row(S, k, 2 * math.sqrt(k), 12, comm_norms))
    print()

    # ============================================
    # Step 3: normalised comparison
    # ============================================
    print("  Step 3: comparison normalised by bridge fidelity")
    print()
    print("  Ratio: ||[D_alpha, b(chi_S) tensor I]||_op / ||b(chi_S)||_op  vs  2 sqrt(|S|)")
    print()
    print(_table_header("target", 10, alphas))
    ratios = {}
    for i, (S, k) in enumerate(cases):
        row_ratios = []
        for alpha in alphas:
            comm_norm, b_norm = results[i, alpha]
            # Guard against a vanishing bridge image.
            ratio = comm_norm / max(b_norm, 1e-12)
            ratios[i, alpha] = ratio
            row_ratios.append(ratio)
        print(_table_row(S, k, 2 * math.sqrt(k), 10, row_ratios))
    print()

    # ============================================
    # Verdict (driven by the numbers computed above)
    # ============================================
    max_comm = max(comm_norm for comm_norm, _ in results.values())
    n_total = len(ratios)
    n_match = sum(
        math.isclose(ratio, 2 * math.sqrt(cases[i][1]), rel_tol=match_rtol)
        for (i, _alpha), ratio in ratios.items()
    )

    print("=" * 90)
    print("  Verdict")
    print("=" * 90)
    print()
    print("  The SU(2) Dirac operator on H^2_alpha(B^2) tensor C^2 has the expected")
    print("  su(2) algebra structure ([J_+, J_-] = 2 J_z to machine precision when the")
    print("  truncation does not interfere) and is Hermitian.")
    print()
    if max_comm > zero_tol:
        print("  The L_comm commutator [D_alpha, b_holo(chi_S) tensor I] is NON-ZERO --")
        print("  the non-tensor coupling between Bergman and spinor factors finally produces")
        print("  a meaningful Dirac commutator.")
    else:
        print("  The L_comm commutator [D_alpha, b_holo(chi_S) tensor I] VANISHES to")
        print(f"  numerical precision (max op-norm {max_comm:.2e}) in every tested case.")
    print()
    if n_match == n_total:
        print("  The normalised ratios match the substrate target 2*sqrt(|S|)")
        print(f"  within {match_rtol:.0%} for every tested (S, alpha).")
    else:
        if n_match == 0:
            print("  However, the numerical values do not yet match the substrate target")
            print("  2*sqrt(|S|).  Several factors contribute:")
        else:
            print(f"  Only {n_match} of {n_total} normalised ratios match the substrate target")
            print(f"  2*sqrt(|S|) within {match_rtol:.0%}.  Several factors contribute:")
        print()
        print("    1. The bridge image norm ||b_holo(chi_S)|| < 1 (Computation 35).")
        print("    2. The su(2) raising/lowering operators take VALUES proportional to")
        print("       sqrt(m * (k - m + 1)) on monomial e_{m, k-m}, peaking at the centre")
        print("       of each bidegree-k irrep -- different scaling from substrate 2*sqrt(|S|).")
        print("    3. The truncation at N = 5 limits the highest bidegree captured.")
    print()
    print("  Next concrete step: tune alpha = alpha(D) so that the Dirac commutator")
    print("  matches 2*sqrt(|S|) on average (or on the highest-weight vectors),")
    print("  and verify the exponential closure rate gamma_D -> 0 as D -> infty.")
    print()
    print("  This represents the FIRST quantitative test of the full SU(2)-coupled")
    print("  Dirac bridge for L_comm.  The framework is now operational; the closure")
    print("  rate analysis is what remains.")


if __name__ == "__main__":
    main()
