#!/usr/bin/env python3
"""
Computation 43 -- Phase 2 analytical-vs-numerical bridge norm check
====================================================================
Phase 2 analytical work (research/lcomm_relaxed_lemma.md sections 9-11)
posits closed-form formulas:

  || T_F ||_op,infinite-N = sup_{B^2} |F|
  || [D_alpha, T_F tensor I] ||_op,infinite-N
      = sup_{B^2} sqrt(|J_+ F|^2/2 + |J_- F|^2/2 + |J_z F|^2)

via the Leibniz identity [J_a, T_F] = T_{J_a F} for holomorphic F.

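This script tests these formulas at D = 4, k = 1, alpha = 0, by:
  (a) computing the matrix norms of T_F and [D_alpha, T_F ⊗ I]
      at growing Bergman truncation N,
  (b) computing the analytical sup over B^2 by numerical
      grid search,
  (c) checking that (a) converges to (b) as N grows.

In step (b) the commutator target is the sup of the exact pointwise
operator norm of the 2x2 matrix-valued symbol
M(z) = sigma_1 J_1 F + sigma_2 J_2 F + sigma_3 J_z F, not the square-root
expression above.  ||M(z)||_op^2 equals the quantity under that square
root plus a nonnegative correction term (see M_norm_squared), so the
posited formula can only underestimate the operator norm.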

If this works at D=4, k=1, the analytical framework is validated
and Phase 2 can proceed with confidence to higher-D extrapolation.
"""
import math
import numpy as np
import numpy.linalg as la


# Pauli matrices sigma_x, sigma_y, sigma_z and the 2x2 identity, all complex128.
# They act on the C^2 (spin) tensor factor of the Dirac operator built below.
sx = np.array([[0, 1],
               [1, 0]], dtype=complex)
sy = np.array([[0, -1j],
               [1j, 0]], dtype=complex)
sz = np.diag([1, -1]).astype(complex)
I2 = np.identity(2, dtype=complex)


def op_norm(M):
    """Return the spectral (operator 2-) norm of ``M`` -- its largest singular value -- as a Python float."""
    spectral = la.norm(M, 2)
    return float(spectral)


def basis_indices(N):
    """List the multi-indices (m1, m2) with m1 + m2 <= N.

    Ordered by m1, then by m2.  Each pair labels the monomial z_0^m1 z_1^m2 of the
    degree-<=N truncation, and the list position is that monomial's row/column
    index in every matrix built from this basis.
    """
    indices = []
    for m1 in range(N + 1):
        indices.extend((m1, m2) for m2 in range(N - m1 + 1))
    return indices


def log_norm_sq(m1, m2, alpha):
    """Return log( m1! m2! Gamma(alpha + 3) / Gamma(m1 + m2 + alpha + 3) ).

    This is the log of the squared norm of z_0^m1 z_1^m2 in the weighted Bergman
    space on the unit ball B^2 (normalised measure, weight (1 - |z|^2)^alpha).
    Working in log space via lgamma keeps large degrees from overflowing.
    """
    lg = math.lgamma
    numerator = lg(m1 + 1) + lg(m2 + 1) + lg(alpha + 3)
    return numerator - lg(m1 + m2 + alpha + 3)


def Tz_matrix(a, basis, alpha):
    """Matrix of multiplication by z_a (a = 0 or 1) on the truncated space.

    Rows and columns follow ``basis``; the basis vectors are the normalised
    monomials e_J = z^J / ||z^J||, so z_a e_J = (||z^{J+e_a}|| / ||z^J||) e_{J+e_a}.
    A column is left zero when J + e_a falls outside ``basis`` (its degree
    exceeds the truncation).
    """
    row_of = {multi: k for k, multi in enumerate(basis)}
    size = len(basis)
    mat = np.zeros((size, size), dtype=complex)
    for col, multi in enumerate(basis):
        shifted = list(multi)
        shifted[a] += 1
        shifted = tuple(shifted)
        row = row_of.get(shifted)
        if row is None:
            continue
        # Norm ratio computed from log squared norms, then exponentiated.
        half_log = 0.5 * (log_norm_sq(shifted[0], shifted[1], alpha)
                          - log_norm_sq(multi[0], multi[1], alpha))
        mat[row, col] = math.exp(half_log)
    return mat


def J_plus_matrix(basis, alpha):
    """Matrix of the raising operator J_+ = z_0 d/dz_1 in the normalised monomial basis.

    On monomials, z_0 d/dz_1 (z_0^m1 z_1^m2) = m2 z_0^(m1+1) z_1^(m2-1); the
    matrix entry also carries the norm ratio ||z^{target}|| / ||z^{source}||
    because the basis vectors are z^J / ||z^J||.
    """
    position = {multi: k for k, multi in enumerate(basis)}
    out = np.zeros((len(basis),) * 2, dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        target = (m1 + 1, m2 - 1)
        if m2 <= 0 or target not in position:
            continue
        scale = math.exp(0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha)))
        out[position[target], col] = m2 * scale
    return out


def J_minus_matrix(basis, alpha):
    """Matrix of the lowering operator J_- = z_1 d/dz_0 in the normalised monomial basis.

    On monomials, z_1 d/dz_0 (z_0^m1 z_1^m2) = m1 z_0^(m1-1) z_1^(m2+1); the
    matrix entry also carries the norm ratio ||z^{target}|| / ||z^{source}||
    because the basis vectors are z^J / ||z^J||.
    """
    position = {multi: k for k, multi in enumerate(basis)}
    out = np.zeros((len(basis),) * 2, dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        target = (m1 - 1, m2 + 1)
        if m1 <= 0 or target not in position:
            continue
        scale = math.exp(0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha)))
        out[position[target], col] = m1 * scale
    return out


def J_z_matrix(basis, alpha):
    """Diagonal matrix of J_z, which scales z_0^m1 z_1^m2 by (m1 - m2) / 2.

    ``alpha`` is not used: J_z only rescales each monomial, so the basis
    normalisation drops out.  It is accepted for signature parity with
    J_plus_matrix / J_minus_matrix.
    """
    eigenvalues = np.array([0.5 * (m1 - m2) for m1, m2 in basis], dtype=complex)
    return np.diag(eigenvalues)


def dirac_alpha(basis, alpha):
    """Truncated Dirac operator D_alpha = J_1 ⊗ sigma_x + J_2 ⊗ sigma_y + J_z ⊗ sigma_z.

    J_1 = (J_+ + J_-) / 2 and J_2 = (J_+ - J_-) / (2i).  np.kron places the
    Bergman factor first and the 2-dimensional spin factor second, so the
    result has shape (2 * len(basis), 2 * len(basis)).
    """
    raise_op = J_plus_matrix(basis, alpha)
    lower_op = J_minus_matrix(basis, alpha)
    j3 = J_z_matrix(basis, alpha)
    j1 = (raise_op + lower_op) / 2
    j2 = (raise_op - lower_op) / (2j)
    return (np.kron(j1, sx)
            + np.kron(j2, sy)
            + np.kron(j3, sz))


def F_symbol_D4_k1(z0, z1):
    """Symbol F(z_0, z_1) = (z_0 + z_1 + z_0^2 + z_1^2) / 2 for D = 4, k = 1.

    Works elementwise on numpy arrays as well as on scalars.
    """
    degree_one = z0 + z1
    return 0.5 * (degree_one + z0 ** 2 + z1 ** 2)


def Jp_F_D4_k1(z0, z1):
    """J_+ F = z_0 dF/dz_1 = z_0 (1 + 2 z_1) / 2 for the D = 4, k = 1 symbol F."""
    from_linear = z0 * 1          # z_0 * d/dz_1 (z_1)
    from_quadratic = z0 * 2 * z1  # z_0 * d/dz_1 (z_1^2)
    return 0.5 * (from_linear + from_quadratic)


def Jm_F_D4_k1(z0, z1):
    """J_- F = z_1 dF/dz_0 = z_1 (1 + 2 z_0) / 2 for the D = 4, k = 1 symbol F."""
    from_linear = z1 * 1          # z_1 * d/dz_0 (z_0)
    from_quadratic = z1 * 2 * z0  # z_1 * d/dz_0 (z_0^2)
    return 0.5 * (from_linear + from_quadratic)


def Jz_F_D4_k1(z0, z1):
    """J_z F for the D = 4, k = 1 symbol F.

    J_z scales z_0^m1 z_1^m2 by (m1 - m2) / 2, so the linear terms of F pick up
    +1/2 and -1/2 and the quadratic terms +1 and -1.
    """
    linear = 0.5 * z0 - 0.5 * z1
    return 0.5 * (linear + z0 ** 2 - z1 ** 2)


def M_norm_squared(z0, z1):
    """
    Pointwise op-norm squared of M(z) = sigma_1 J_1 F + sigma_2 J_2 F + sigma_3 J_z F,
    where M(z) is the 2x2 matrix-valued symbol, J_1 = (J_+ + J_-)/2 and
    J_2 = (J_+ - J_-)/(2i).

    With p = J_+ F, m = J_- F, c = J_z F, the Pauli sum is the matrix
        M = [[c,  m],
             [p, -c]],
    hence
        M^dagger M = [[|c|^2 + |p|^2,          conj(c) m - conj(p) c],
                      [conj(m) c - conj(c) p,  |c|^2 + |m|^2        ]]
    and its largest eigenvalue is
        ||M||_op^2 = (|p|^2 + |m|^2)/2 + |c|^2 + sqrt(X^2 + |Y|^2),
        X = (|p|^2 - |m|^2)/2,
        Y = conj(c) m - conj(p) c.
    The first two terms are the "naive" formula from the module docstring; the
    square root is the extra contribution from the 2x2 structure.

    Note: |Y| is NOT 2|Im(conj(c) m)| in general -- that shortcut only holds where
    p = m (e.g. wherever z_0 = z_1).  Example: p = 1, m = 0, c = 1 gives
    ||M||_op^2 = (3 + sqrt(5))/2, not 2.

    z0, z1 may be scalars or broadcast-compatible numpy arrays; the result has the
    broadcast shape (a numpy float for scalar input).
    """
    p = Jp_F_D4_k1(z0, z1)
    m = Jm_F_D4_k1(z0, z1)
    c = Jz_F_D4_k1(z0, z1)
    p_sq = np.abs(p) ** 2
    m_sq = np.abs(m) ** 2
    naive = 0.5 * (p_sq + m_sq) + np.abs(c) ** 2
    X = 0.5 * (p_sq - m_sq)
    # Off-diagonal entry of M^dagger M (see docstring).
    Y = np.conj(c) * m - np.conj(p) * c
    return naive + np.sqrt(X ** 2 + np.abs(Y) ** 2)


def analytical_sup_F(n_theta=60, n_phi=60):
    """Grid-search estimate of sup |F| over the 3-sphere |z_0|^2 + |z_1|^2 = 1.

    Samples z_0 = cos(theta) e^{i phi_0}, z_1 = sin(theta) e^{i phi_1} with theta
    on n_theta + 1 evenly spaced points of [0, pi/2] (both endpoints) and
    phi_0, phi_1 each on n_phi evenly spaced points of [0, 2 pi).  phi = 2 pi is
    omitted because it repeats phi = 0.  Only the sphere is sampled: F is a
    holomorphic polynomial, so by the maximum modulus principle its sup over
    the closed ball B^2 is attained on the boundary.

    Parameters
    ----------
    n_theta : int, default 60
        Number of theta intervals; must be >= 1.
    n_phi : int, default 60
        Number of phase samples per angle; must be >= 1.

    Returns
    -------
    float
        Largest sampled |F|.  This is a lower bound on the true sup that
        tightens as the grid is refined.

    Raises
    ------
    ValueError
        If n_theta or n_phi is smaller than 1.
    """
    if n_theta < 1 or n_phi < 1:
        raise ValueError(
            f"grid sizes must be >= 1, got n_theta={n_theta}, n_phi={n_phi}")
    theta = np.linspace(0.0, np.pi / 2, n_theta + 1)
    phase = np.exp(1j * np.linspace(0.0, 2 * np.pi, n_phi, endpoint=False))
    # F mixes degree-1 and degree-2 terms, so a common phase rotation of
    # (z_0, z_1) changes |F|.  Both phases are therefore sampled independently,
    # not just their difference.  Broadcasting builds the full
    # (theta, phi_0, phi_1) grid in one vectorised evaluation.
    z0 = np.cos(theta)[:, None, None] * phase[None, :, None]
    z1 = np.sin(theta)[:, None, None] * phase[None, None, :]
    return float(np.max(np.abs(F_symbol_D4_k1(z0, z1))))


def analytical_sup_M(n_theta=40, n_phi=40):
    """Grid-search estimate of sup ||M(z)||_op over the 3-sphere |z_0|^2 + |z_1|^2 = 1.

    Uses the same parametrisation as analytical_sup_F:
    z_0 = cos(theta) e^{i phi_0}, z_1 = sin(theta) e^{i phi_1}.  Theta takes
    n_theta + 1 points of [0, pi/2]; the phases take n_phi points of [0, 2 pi),
    with the duplicate phi = 2 pi omitted.  Every entry of M is holomorphic.
    ||M(z)||_op is therefore a supremum of moduli of holomorphic functions, so
    its sup over the ball is attained on the sphere.

    Parameters
    ----------
    n_theta : int, default 40
        Number of theta intervals; must be >= 1.
    n_phi : int, default 40
        Number of phase samples per angle; must be >= 1.

    Returns
    -------
    float
        Largest sampled sqrt(M_norm_squared); a lower bound on the true sup.

    Raises
    ------
    ValueError
        If n_theta or n_phi is smaller than 1.
    """
    if n_theta < 1 or n_phi < 1:
        raise ValueError(
            f"grid sizes must be >= 1, got n_theta={n_theta}, n_phi={n_phi}")
    theta = np.linspace(0.0, np.pi / 2, n_theta + 1)
    # Phase factors are shared by every grid point; compute them once.
    phase = np.exp(1j * np.linspace(0.0, 2 * np.pi, n_phi, endpoint=False))
    z0 = np.cos(theta)[:, None, None] * phase[None, :, None]
    z1 = np.sin(theta)[:, None, None] * phase[None, None, :]
    # np.vectorize calls M_norm_squared point by point, so it need not accept arrays.
    norm_sq = np.vectorize(M_norm_squared, otypes=[float])
    # sqrt is monotone, so sqrt(max) equals max(sqrt).
    return float(np.sqrt(np.max(norm_sq(z0, z1))))


def main():
    """Run the D = 4, k = 1, alpha = 0 comparison and print a report to stdout.

    Prints, in order:
    - the grid-search sups of |F| and of the pointwise operator norm of the
      matrix symbol M;
    - both quantities evaluated at z_0 = z_1 = 1/sqrt(2);
    - the operator norms of the truncated T_F and of [D_alpha, T_F ⊗ I] for
      N = 5, 10, ..., 40;
    - the limits those norms are expected to approach.
    """
    rule = "=" * 90

    def banner(*titles):
        # Title lines framed above and below by a horizontal rule.
        print(rule)
        for title in titles:
            print(title)
        print(rule)

    banner("  Computation 43  --  Analytical vs numerical bridge / commutator norms",
           "                       D = 4, k = 1, alpha = 0")
    print()

    alpha = 0.0

    # Analytical side: grid-search sups over the unit sphere.
    print("  Analytical sup-norm calculation (grid search over partial B^2):")
    sup_F = analytical_sup_F()
    sup_M = analytical_sup_M()
    print(f"    sup |F(z_0, z_1)|       = {sup_F:.6f}")
    print(f"    sup |M(z_0, z_1)|       = {sup_M:.6f}")
    sup_ratio = sup_M / sup_F
    print(f"    sup |M| / sup |F|       = {sup_ratio:.6f}")
    print(f"    ratio to 2*sqrt(1) = 2  = {sup_ratio / 2:.6f}")
    print()

    # Spot check on the diagonal, where |F| has a closed form.
    print("  Closed-form diagonal evaluation (z_0 = z_1 = 1/sqrt(2)):")
    w = 1 / math.sqrt(2)
    F_diag = abs(F_symbol_D4_k1(w, w))
    M_diag = math.sqrt(M_norm_squared(w, w))
    F_diag_exact = (math.sqrt(2) + 1) / 2
    print(f"    |F(diag)|               = {F_diag:.6f}    (analytic: (sqrt(2)+1)/2 = {F_diag_exact:.6f})")
    print(f"    |M(diag)|               = {M_diag:.6f}")
    print(f"    |M|/|F| at diag         = {M_diag / F_diag:.6f}")
    print()

    # Numerical side: operator norms of the truncated matrices.
    print("  Numerical matrix-norm computation at growing Bergman truncation N:")
    print()
    print(f"  {'N':>3}  {'dim':>5}  {'||T_F||_op':>14}  {'||[D, T_F⊗I]||_op':>20}  {'ratio':>10}  {'rescaled/2':>12}")

    for N in range(5, 45, 5):
        basis = basis_indices(N)
        mult0 = Tz_matrix(0, basis, alpha)
        mult1 = Tz_matrix(1, basis, alpha)
        # T_F for F = (z_0 + z_1 + z_0^2 + z_1^2) / 2, from the truncated multipliers.
        T_F = 0.5 * (mult0 + mult1 + mult0 @ mult0 + mult1 @ mult1)
        TF_norm = op_norm(T_F)

        dirac = dirac_alpha(basis, alpha)
        lifted = np.kron(T_F, I2)  # T_F ⊗ I on the spin-doubled space
        comm_norm = op_norm(dirac @ lifted - lifted @ dirac)

        ratio = comm_norm / TF_norm
        print(f"  {N:>3}  {len(basis):>5}  {TF_norm:>14.6f}  {comm_norm:>20.6f}  {ratio:>10.6f}  {ratio / 2.0:>12.6f}")

    print()
    banner("  Verdict")
    print()
    print("  Expectation if analytical framework is correct:")
    print(f"    ||T_F||_op,N        -> sup |F| = {sup_F:.6f} as N -> infinity")
    print(f"    ||[D, T_F⊗I]||_op,N -> sup |M| = {sup_M:.6f} as N -> infinity")
    print(f"    ratio_N             -> sup |M| / sup |F| = {sup_M / sup_F:.6f}")
    print()
    print("  If matrix norms grow with N indefinitely, the framework needs revision.")


if __name__ == "__main__":
    main()
