#!/usr/bin/env python3
"""
Computation 41 -- L_comm homomorphism-gap test (explicit construction)
=======================================================================
Computations 38-40 measured the NORM of [D_alpha, bridge(chi_S)] vs
the substrate norm 2*sqrt(|S|).  This is NOT the genuine L_comm
criterion -- it tests whether the bridge produces Dirac commutators
of approximately the right size, but does not test whether the bridge
preserves the Dirac action.

The genuine L_comm criterion:

    HOMOMORPHISM GAP  =  || bridge([D_sub, chi_S]) - [D_alpha, bridge(chi_S)] ||_op

For an exact homomorphism this is zero.  L_comm closure asks that the gap
shrink exponentially in D.

CONSTRUCTION OF THE BRIDGE ON [D_sub, chi_S].

The substrate Dirac commutator decomposes:

    [D_sub, chi_S^Walsh]  =  2 sum_{a in S} chi_a^{Cliff} * chi_S^Walsh

Natural multiplicative-bridge extension:

    bridge(chi_a^{Cliff})  :=  J_a tensor sigma_a   (frame-component of D_alpha)
    bridge(chi_S^Walsh)    :=  b_holo(chi_S) tensor I_C2
    bridge(prod)           :=  prod bridge(factor)

For substrate sites a in {0, 1, 2, 3} we need a Pauli assignment.  Use:

    site 0, 1, 2 -> sigma_1, sigma_2, sigma_3
    site 3 -> sigma_1 (the cyclic 4-th element; introduces structural
                       degeneracy that the test will surface)

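With these choices, and provided no two sites of S share a Pauli index
(true whenever S lies inside {0, 1, 2}), the homomorphism gap has the
explicit form below, where a runs over the three Pauli indices and
"a in S" means some site of S is assigned sigma_a: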

    bridge_LHS - bridge_RHS  =  sum_a [ 2 [a in S] J_a b_holo
                                       - [J_a, b_holo] ] tensor sigma_a
                              =  sum_{a in S} {J_a, b_holo} tensor sigma_a
                                 - sum_{a not in S} [J_a, b_holo] tensor sigma_a

For exact homomorphism we need {J_a, b_holo} = 0 for a in S AND
[J_a, b_holo] = 0 for a not in S, simultaneously.  These are STRONG
conditions on b_holo that are unlikely to hold without special structure.

OUTPUT.
For D in {4, 6} at alpha = 0: print the operator norms of LHS, RHS and
the homomorphism gap for low-weight subsets S, together with the ratio
gap / ||LHS||.  The decomposition above into "in-S anti-commutator" and
"out-of-S commutator" terms identifies which terms dominate and what
structural fix would zero them out.
"""
import math
from itertools import combinations
import numpy as np
import numpy.linalg as la


# Pauli matrices sigma_1, sigma_2, sigma_3 and the 2x2 identity (complex128).
sx = np.array(((0, 1), (1, 0)), dtype=complex)
sy = np.array(((0, -1j), (1j, 0)), dtype=complex)
sz = np.diag((1, -1)).astype(complex)
I2 = np.identity(2, dtype=complex)


def kron_chain(ops):
    """Return ops[0] (x) ops[1] (x) ... (x) ops[-1], folded left to right.

    A single-element sequence returns that element unchanged; an empty
    sequence raises IndexError.
    """
    result, remaining = ops[0], ops[1:]
    for factor in remaining:
        result = np.kron(result, factor)
    return result


def chi_walsh(D, S):
    """Walsh character chi_S on D qubits: sigma_z on each site in S, identity elsewhere."""
    factors = [(sz if site in S else I2) for site in range(D)]
    return kron_chain(factors)


def chi_clifford(D, a):
    """Clifford generator for site a on D qubits (Jordan-Wigner form).

    sigma_z on sites 0 .. a-1, sigma_x on site a, identity on sites a+1 .. D-1.
    """
    factors = [sz] * a
    factors.append(sx)
    factors.extend([I2] * (D - 1 - a))
    return kron_chain(factors)


def op_norm(M):
    """Operator (spectral, ord=2) norm of M, returned as a Python float."""
    spectral_norm = la.norm(M, ord=2)
    return float(spectral_norm)


def basis_indices(N):
    """Return all exponent pairs (m1, m2) with m1 + m2 <= N.

    Ordered by m1 ascending, then m2 ascending; there are (N+1)(N+2)/2 pairs.
    """
    pairs = []
    for m1 in range(N + 1):
        pairs.extend((m1, m2) for m2 in range(N - m1 + 1))
    return pairs


def log_norm_sq(m1, m2, alpha):
    """Return log( m1! m2! Gamma(alpha + 3) / Gamma(m1 + m2 + alpha + 3) ).

    Presumably the log squared norm of z1^m1 z2^m2 in the alpha-weighted
    Bergman space on the unit ball of C^2 (the formula matches; confirm
    against the normalisation used in earlier computations).
    """
    lg = math.lgamma
    numerator = lg(m1 + 1) + lg(m2 + 1) + lg(alpha + 3)
    return numerator - lg(m1 + m2 + alpha + 3)


def Tz_matrix(a, basis, alpha):
    """Matrix of multiplication by the coordinate z_a (a = 0 -> z1, a = 1 -> z2).

    Columns and rows are indexed by `basis`.  The monomial with exponent J
    is sent to J + e_a with weight exp(0.5 * (log_norm_sq(J + e_a) -
    log_norm_sq(J))), i.e. the action rescaled by the monomial norms.
    Targets outside `basis` (beyond the truncation degree) are dropped.
    """
    position = {J: i for i, J in enumerate(basis)}
    dim = len(basis)
    T = np.zeros((dim, dim), dtype=complex)
    for col, J in enumerate(basis):
        raised = list(J)
        raised[a] += 1
        target = tuple(raised)
        row = position.get(target)
        if row is None:
            continue
        half_log_ratio = 0.5 * (log_norm_sq(target[0], target[1], alpha)
                                - log_norm_sq(J[0], J[1], alpha))
        T[row, col] = math.exp(half_log_ratio)
    return T


def J_plus_matrix(basis, alpha):
    """Matrix of J_+ = z1 d/dz2 on span(basis).

    z1^m1 z2^m2 -> m2 z1^(m1+1) z2^(m2-1); the coefficient is multiplied by
    exp(0.5 * (log_norm_sq(target) - log_norm_sq(source))) so the matrix
    acts on norm-rescaled monomials.  Monomials with m2 == 0 are annihilated.
    J_+ raises J_z = (m1 - m2)/2 by one.
    """
    position = {J: i for i, J in enumerate(basis)}
    dim = len(basis)
    mat = np.zeros((dim, dim), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m2 <= 0:
            continue
        target = (m1 + 1, m2 - 1)
        row = position.get(target)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha))
        mat[row, col] = m2 * math.exp(half_log)
    return mat


def J_minus_matrix(basis, alpha):
    """Matrix of J_- = z2 d/dz1 on span(basis).

    z1^m1 z2^m2 -> m1 z1^(m1-1) z2^(m2+1); the coefficient is multiplied by
    exp(0.5 * (log_norm_sq(target) - log_norm_sq(source))) so the matrix
    acts on norm-rescaled monomials.  Monomials with m1 == 0 are annihilated.
    J_- lowers J_z = (m1 - m2)/2 by one.
    """
    lookup = {J: i for i, J in enumerate(basis)}
    dim = len(basis)
    out = np.zeros((dim, dim), dtype=complex)
    for src, (m1, m2) in enumerate(basis):
        if m1 <= 0:
            continue
        dst_key = (m1 - 1, m2 + 1)
        dst = lookup.get(dst_key)
        if dst is None:
            continue
        half_log = 0.5 * (log_norm_sq(*dst_key, alpha) - log_norm_sq(m1, m2, alpha))
        out[dst, src] = m1 * math.exp(half_log)
    return out


def J_z_matrix(basis, alpha):
    """Diagonal matrix of J_z with entries (m1 - m2) / 2 over `basis`.

    `alpha` is not used here; presumably it is kept so the signature matches
    J_plus_matrix / J_minus_matrix.
    """
    weights = np.array([0.5 * (m1 - m2) for m1, m2 in basis], dtype=complex)
    return np.diag(weights)


def site_pauli_pair(site, basis, alpha):
    """Return the frame component (J_a, sigma_a) assigned to substrate `site`.

    The three SU(2) generators are assigned cyclically with period 3:

        site 0, 3, 6, ... -> (J_1, sigma_1)
        site 1, 4, 7, ... -> (J_2, sigma_2)
        site 2, 5, 8, ... -> (J_z, sigma_3)

    so site 3 reuses (J_1, sigma_1) -- the 4-vs-3 mismatch between the
    substrate sites and the three Pauli matrices.  J_1 = (J_+ + J_-)/2 and
    J_2 = (J_+ - J_-)/(2i), matching dirac_alpha.
    """
    Jp = J_plus_matrix(basis, alpha)
    Jm = J_minus_matrix(basis, alpha)
    frame = (
        ((Jp + Jm) / 2, sx),
        ((Jp - Jm) / (2j), sy),
        (J_z_matrix(basis, alpha), sz),
    )
    # Period 3 is a true cycle over the three Paulis.  A period-4 table with
    # key 3 duplicating key 0 would send site 4 back to sigma_1 (and site 5 to
    # sigma_2) instead of continuing the cycle with sigma_2, sigma_3.
    return frame[site % 3]


def dirac_alpha(basis, alpha):
    """Return the SU(2) Dirac operator on span(basis) tensor C^2.

    D_alpha = sum_{a=1}^{3} J_a tensor sigma_a, built as np.kron(J_a, sigma_a):
    the holomorphic factor comes first and the spinor factor second, so it
    composes with bridge operators of the form np.kron(b_holo, I2).
    Here J_1 = (J_+ + J_-)/2, J_2 = (J_+ - J_-)/(2i) and J_3 = J_z.
    """
    Jp = J_plus_matrix(basis, alpha)
    Jm = J_minus_matrix(basis, alpha)
    Jz = J_z_matrix(basis, alpha)
    J1 = (Jp + Jm) / 2
    J2 = (Jp - Jm) / (2j)
    # Holomorphic factor first, spinor factor second (not sigma_a tensor J_a).
    return np.kron(J1, sx) + np.kron(J2, sy) + np.kron(Jz, sz)


def site_to_holo(D, basis, alpha, Tz):
    """Map each site 0 .. D-1 to a power of a coordinate multiplication matrix.

    site -> Tz[site % 2] ** (site // 2 + 1), i.e. with Tz = [T_z1, T_z2]:
    0 -> z1, 1 -> z2, 2 -> z1^2, 3 -> z2^2, 4 -> z1^3, ...
    Powers are formed by repeated right-multiplication.  `basis` and `alpha`
    are not used.
    """
    def _matrix_power(base, exponent):
        acc = base
        for _ in range(exponent - 1):
            acc = acc @ base
        return acc

    return {site: _matrix_power(Tz[site % 2], site // 2 + 1) for site in range(D)}


def bridge_holo(D, S, basis, alpha, Tz):
    """Holomorphic image b_holo(chi_S) of the Walsh character chi_S.

    Product of the site operators from site_to_holo, taken in increasing site
    order and starting from the identity on span(basis); identity when S is
    empty.
    """
    holo_of_site = site_to_holo(D, basis, alpha, Tz)
    product = np.eye(len(basis), dtype=complex)
    ordered_sites = sorted(S)
    for site in ordered_sites:
        product = product @ holo_of_site[site]
    return product


def _weight_cases(D, max_weight=4, per_weight=3):
    """Return the subsets of range(D) tested by `main`.

    For each weight k = 1 .. min(D, max_weight), take the first `per_weight`
    k-subsets in itertools.combinations order (kept small for speed).
    """
    cases = []
    for k in range(1, min(D, max_weight) + 1):
        first_few = list(combinations(range(D), k))[:per_weight]
        cases.extend(set(S) for S in first_few)
    return cases


def _bridge_commutators(D, S, basis, alpha, Tz, D_a, frame):
    """Return (LHS, RHS) for one subset S.

    LHS = bridge([D_sub, chi_S]) = 2 sum_{a in S} (J_a b_holo) tensor sigma_a
    RHS = [D_alpha, bridge(chi_S)], with bridge(chi_S) = b_holo tensor I_C2.
    `frame` maps each site to its (J_a, sigma_a) pair.
    """
    b_S_holo = bridge_holo(D, S, basis, alpha, Tz)
    bridge_chi_S = np.kron(b_S_holo, I2)

    # (J_a tensor sigma_a)(b_holo tensor I) = (J_a b_holo) tensor sigma_a
    LHS = np.zeros_like(bridge_chi_S)
    for a in sorted(S):
        Ja, sigma_a = frame[a]
        LHS = LHS + 2 * np.kron(Ja @ b_S_holo, sigma_a)

    RHS = D_a @ bridge_chi_S - bridge_chi_S @ D_a
    return LHS, RHS


def _print_verdict():
    """Print the closing interpretation of the gap table."""
    print()
    print("=" * 90)
    print("  Verdict")
    print("=" * 90)
    print()
    print("  The homomorphism gap || LHS - RHS ||_op directly measures L_comm failure.")
    print("  A gap close to zero means the bridge is a Dirac homomorphism on that")
    print("  weight class.  A gap close to ||LHS||_op means the bridge fails completely.")
    print()
    print("  The ratio gap / ||LHS|| at fixed weight indicates which fraction of the")
    print("  substrate Dirac action is preserved.  If this ratio shrinks with (D, N),")
    print("  the bridge is asymptotically a homomorphism (L_comm closure).")
    print()
    print("  Structural note on the 4-vs-3 Pauli mismatch: substrate has D Clifford")
    print("  sites, but the SU(2) Dirac uses only 3 Pauli matrices.  Sites 3, 4, 5, ...")
    print("  must be assigned cyclically or grouped, introducing degeneracy.  A clean")
    print("  L_comm closure may require a larger spinor fibre (C^{2^D} matching the")
    print("  substrate's Cl(0, 2D) representation) rather than C^2.")


def main(*, configs=((4, 5), (4, 10), (6, 12)), alpha=0.0):
    """Print the homomorphism-gap table for each (D, N) in `configs`.

    For every tested subset S prints ||LHS||_op, ||RHS||_op, ||LHS - RHS||_op
    and gap / ||LHS||_op, then a textual verdict.

    Parameters
    ----------
    configs : iterable of (D, N)
        Substrate size D and Bergman truncation degree N.
    alpha : float
        Bergman weight used for every configuration (default 0.0).
    """
    print("=" * 90)
    print("  Computation 41  --  L_comm homomorphism-gap test (explicit construction)")
    print("=" * 90)
    print()
    print("  Test:  || bridge([D_sub, chi_S]) - [D_alpha, bridge(chi_S)] ||_op")
    print("  with bridge(chi_a^Cliff) := J_a tensor sigma_a (frame-component of D_alpha)")
    print("  and  bridge(chi_S^Walsh) := b_holo(chi_S) tensor I_C2")
    print()

    for D, N in configs:
        basis = basis_indices(N)
        Tz = [Tz_matrix(0, basis, alpha), Tz_matrix(1, basis, alpha)]
        D_a = dirac_alpha(basis, alpha)
        # site_pauli_pair depends only on (site, basis, alpha): build each pair
        # once per configuration instead of once per (case, site).
        frame = {site: site_pauli_pair(site, basis, alpha) for site in range(D)}

        print(f"  D = {D}, N = {N}, Bergman dim = {len(basis)}, alpha = {alpha}")
        print()
        print(f"  {'case':>18}  {'||LHS||_op':>10}  {'||RHS||_op':>10}  {'||gap||_op':>12}  {'gap/LHS':>10}")
        for S in _weight_cases(D):
            LHS, RHS = _bridge_commutators(D, S, basis, alpha, Tz, D_a, frame)
            n_LHS = op_norm(LHS)
            n_RHS = op_norm(RHS)
            n_gap = op_norm(LHS - RHS)
            # Guard against division by zero when LHS vanishes.
            ratio = n_gap / max(n_LHS, 1e-12)
            case_str = f"k={len(S)}, S={str(S):<10}"
            print(f"  {case_str:>18}  {n_LHS:>10.4f}  {n_RHS:>10.4f}  {n_gap:>12.4f}  {ratio:>10.4f}")
        print()

    _print_verdict()


if __name__ == "__main__":
    # Run the diagnostic only when executed as a script, not on import.
    main()
