#!/usr/bin/env python3
"""
Computation 39 -- Mosco averaging + alpha rescaling for L_comm closure
=======================================================================
Computation 38 brought the L_comm framework to operational status:
the SU(2) Dirac D = sum_a J_a tensor sigma_a on H^2_alpha(B^2) tensor C^2
produces non-zero Dirac commutators with the holomorphic Walsh bridge,
and the normalised ratios approach the substrate target 2*sqrt(|S|).
Two within-framework gaps remained:

  (1) VARIANCE across same-weight S.  Different Walsh modes chi_S
      with the same weight |S| = k produced different commutator
      norms because they map to different SU(2) irrep states (z_0^k
      is the highest-weight state, z_0^{k-1} z_1 is one rung below,
      etc.).  The substrate side has a uniform 2*sqrt(|S|) per weight class.

  (2) ALPHA-INDEPENDENCE of ratios.  The SU(2) generators J_a on the
      ONB have alpha-canceling norm factors, so the Dirac commutator
      itself is essentially alpha-independent (the alpha-dependence
      enters only via the bridge image norm).

This computation tests the two within-framework tunings:

  MOSCO AVERAGING.  Replace bridge(chi_S) for individual S with a
  WEIGHT-CLASS-AVERAGED bridge: at weight k, define

      b_Mosco_k := (1 / sqrt(C(D, k))) sum_{|T| = k} b_holo(chi_T)

  so every weight-k Walsh mode maps to the same bridge image.  This
  collapses C(D, k) modes to one symmetric combination and removes
  the variance across same-weight S.

  ALPHA RESCALING.  Multiply the Mosco-averaged bridge by an
  alpha-dependent factor c_k(alpha) chosen to make the bridge image
  have unit op-norm (matching ||chi_S||_op = 1 on the substrate):

      bridge_final_k(chi_S) := (1 / || b_Mosco_k ||_op) * b_Mosco_k

  The rescaling is per-weight-class (k = |S|) but uniform across S
  within a class.

L_COMM TEST.
For each (S, k = |S|), measure
    || [D_alpha, bridge_final_k ⊗ I_C2] ||_op   vs   2 sqrt(k)

If the ratio converges to 1 (or some specific limit independent of
S within each weight class), the L_comm closure is reachable in the
operational framework.

OUTPUT.
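D = 4, N = 5 (N is the maximal total degree m1 + m2 of the truncated
monomial basis).  Sweep alpha; report Mosco bridge image norms, Dirac
commutator norms, ratios to 2*sqrt(k), and a per-S variance check on
individually normalised bridges.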
"""
import math
from itertools import combinations
import numpy as np
import numpy.linalg as la


# Pauli matrices and the 2x2 identity acting on the C^2 spinor factor.
sx = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=np.complex128)
sy = np.array([[0.0, -1.0j], [1.0j, 0.0]], dtype=np.complex128)
sz = np.diag([1.0, -1.0]).astype(np.complex128)
I2 = np.identity(2, dtype=np.complex128)


def kron_chain(ops):
    """Return the left-associated Kronecker product ops[0] ⊗ ops[1] ⊗ ...

    A single-element sequence returns that element unchanged; an empty
    sequence raises IndexError.
    """
    product = ops[0]
    for factor in ops[1:]:
        product = np.kron(product, factor)
    return product


def op_norm(M):
    """Spectral (operator 2-) norm of matrix M, returned as a Python float."""
    spectral = la.norm(M, ord=2)
    return float(spectral)


def basis_indices(N):
    """Multi-indices (m1, m2) with m1 + m2 <= N, ordered by m1 and then by m2."""
    indices = []
    for m1 in range(N + 1):
        for m2 in range(N - m1 + 1):
            indices.append((m1, m2))
    return indices


def log_norm_sq(m1, m2, alpha):
    return (math.lgamma(m1 + 1) + math.lgamma(m2 + 1)
            + math.lgamma(alpha + 3) - math.lgamma(m1 + m2 + alpha + 3))


def Tz_matrix(a, basis, alpha):
    """Matrix of multiplication by z_a on the normalised monomial basis.

    Column j (monomial J) gets one entry, in the row of J + e_a, equal to
    the norm ratio ||z^(J + e_a)|| / ||z^J||. The ratio is evaluated from
    half the difference of ``log_norm_sq`` values. When J + e_a lies outside
    the truncated basis, the column is left zero.
    """
    size = len(basis)
    position = {mono: i for i, mono in enumerate(basis)}
    result = np.zeros((size, size), dtype=complex)
    for col, mono in enumerate(basis):
        shifted = list(mono)
        shifted[a] += 1
        shifted = tuple(shifted)
        row = position.get(shifted)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(shifted[0], shifted[1], alpha)
                          - log_norm_sq(mono[0], mono[1], alpha))
        result[row, col] = math.exp(half_log)
    return result


def J_plus_matrix(basis, alpha):
    """Matrix of the SU(2) raising operator J_+ on the normalised monomial basis.

    On monomials, J_+ sends z^(m1, m2) to m2 * z^(m1 + 1, m2 - 1). On the
    normalised basis, that coefficient is multiplied by the norm ratio
    ||z^(m1+1, m2-1)|| / ||z^(m1, m2)||. Images outside the truncated
    basis are dropped.
    """
    position = {mono: i for i, mono in enumerate(basis)}
    size = len(basis)
    out = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m2 <= 0:
            continue
        raised = (m1 + 1, m2 - 1)
        row = position.get(raised)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(*raised, alpha) - log_norm_sq(m1, m2, alpha))
        out[row, col] = m2 * math.exp(half_log)
    return out


def J_minus_matrix(basis, alpha):
    """Matrix of the SU(2) lowering operator J_- on the normalised monomial basis.

    On monomials, J_- sends z^(m1, m2) to m1 * z^(m1 - 1, m2 + 1). On the
    normalised basis, that coefficient is multiplied by the norm ratio
    ||z^(m1-1, m2+1)|| / ||z^(m1, m2)||. Images outside the truncated
    basis are dropped.
    """
    position = {mono: i for i, mono in enumerate(basis)}
    size = len(basis)
    out = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m1 <= 0:
            continue
        lowered = (m1 - 1, m2 + 1)
        row = position.get(lowered)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(*lowered, alpha) - log_norm_sq(m1, m2, alpha))
        out[row, col] = m1 * math.exp(half_log)
    return out


def J_z_matrix(basis, alpha):
    """Diagonal matrix of J_z with eigenvalue (m1 - m2) / 2 on monomial (m1, m2).

    ``alpha`` is accepted for signature parity with the J_+/J_- builders.
    J_z does not depend on it.
    """
    eigenvalues = [0.5 * (m1 - m2) for m1, m2 in basis]
    return np.diag(np.array(eigenvalues, dtype=complex))


def dirac_alpha(basis, alpha):
    """SU(2) Dirac operator D_alpha = sum_a J_a ⊗ sigma_a.

    The Cartesian generators are J_1 = (J_+ + J_-) / 2 and
    J_2 = (J_+ - J_-) / (2i); J_z is diagonal. The scalar factor comes
    first in every Kronecker product, so an operator X on the scalar space
    lifts as ``np.kron(X, I2)``.
    """
    raising = J_plus_matrix(basis, alpha)
    lowering = J_minus_matrix(basis, alpha)
    cartesian = (
        ((raising + lowering) / 2, sx),
        ((raising - lowering) / (2j), sy),
        (J_z_matrix(basis, alpha), sz),
    )
    generator, pauli = cartesian[0]
    total = np.kron(generator, pauli)
    for generator, pauli in cartesian[1:]:
        total = total + np.kron(generator, pauli)
    return total


def site_to_holo(D, basis, alpha, Tz):
    """Map each Walsh site index to its holomorphic multiplication operator.

    With Tz = [T_{z_0}, T_{z_1}] (multiplication by z_0 and z_1), the
    sites are sent to
        0 -> z_0,   1 -> z_1,   2 -> z_0^2,   3 -> z_1^2.

    Only these four sites are defined. ``basis`` and ``alpha`` are accepted
    for signature compatibility and are not used here.

    Raises:
        ValueError: if D > 4. Previously, a larger D surfaced much later as
            an opaque KeyError inside ``bridge_holo``.
    """
    if D > 4:
        raise ValueError(
            f"site_to_holo defines only 4 Walsh sites (0..3); got D = {D}")
    return {0: Tz[0], 1: Tz[1], 2: Tz[0] @ Tz[0], 3: Tz[1] @ Tz[1]}


def bridge_holo(D, S, basis, alpha, Tz):
    """Holomorphic bridge image of the Walsh mode chi_S.

    Starts from the identity on the basis and right-multiplies the site
    operators from ``site_to_holo`` in increasing site order. The empty set
    maps to the identity.
    """
    multipliers = site_to_holo(D, basis, alpha, Tz)
    result = np.eye(len(basis), dtype=complex)
    for site in sorted(S):
        result = result @ multipliers[site]
    return result


def mosco_bridge(D, k, basis, alpha, Tz):
    """Weight-class (Mosco) average of the holomorphic bridge.

    Returns (1 / sqrt(C(D, k))) * sum_{|S| = k} b_holo(chi_S), with the
    subsets S drawn from ``combinations(range(D), k)``. When no subset
    exists (k > D), the zero operator is returned instead of dividing by
    zero.
    """
    size = len(basis)
    zero = np.zeros((size, size), dtype=complex)
    images = [bridge_holo(D, set(subset), basis, alpha, Tz)
              for subset in combinations(range(D), k)]
    total = sum(images, zero)
    if not images:
        return total
    return total / math.sqrt(len(images))


def _alpha_header(alphas):
    """Return 'a=...' column headers right-aligned to the 10-wide numeric cells."""
    return "  ".join(("a=" + format(alpha, "5.1f")).rjust(10) for alpha in alphas)


def _cell(value):
    """Format one table cell: a 10-wide number, or '(zero)' when value is None."""
    if value is None:
        return f"{'(zero)':>10}"
    return f"{value:>10.4f}"


def _commutator_norm(D_a, bridge):
    """Return || [D_a, bridge ⊗ I_2] ||_op for an operator ``bridge`` on the scalar space."""
    bridge_full = np.kron(bridge, I2)
    return op_norm(D_a @ bridge_full - bridge_full @ D_a)


def main():
    """Run Computation 39 and print Steps 1-4 followed by the verdict.

    - Step 1 tabulates || b_Mosco_k ||_op.
    - Steps 2-3 report the commutator norm of the unit-normalised Mosco
      bridge with D_alpha, first raw and then divided by 2 sqrt(k).
    - Step 4 reports per-S commutator norms for individually normalised
      bridges.

    Bridges whose op-norm is below ``zero_tol`` cannot be normalised and
    are shown as '(zero)'.
    """
    D = 4
    N = 5
    basis = basis_indices(N)
    alphas = [0.0, 1.0, 2.0, 5.0, 10.0]
    variance_alphas = [0.0, 1.0, 5.0]
    zero_tol = 1e-9

    # Tz and D_alpha depend only on alpha.  Build them once per alpha
    # rather than once per table cell.
    operators = {}
    for alpha in dict.fromkeys(alphas + variance_alphas):
        Tz = [Tz_matrix(0, basis, alpha), Tz_matrix(1, basis, alpha)]
        operators[alpha] = (Tz, dirac_alpha(basis, alpha))

    print("=" * 90)
    print("  Computation 39  --  Mosco averaging + alpha rescaling for L_comm closure")
    print("=" * 90)
    print()
    print(f"  D = {D}, N = {N}")
    print()

    # ============================================
    # Step 1: Mosco-averaged bridge norms vs alpha
    # ============================================
    print("  Step 1: Mosco-averaged bridge norms")
    k_values = ", ".join(str(k) for k in range(D + 1))
    print(f"  || b_Mosco_k ||_op  for k = {k_values}  "
          "(rescaling factor c_k(alpha) = 1 / norm)")
    print()
    print(f"  {'k':>3}  {'C(D,k)':>6}  " + _alpha_header(alphas))
    mosco = {}  # (k, alpha) -> (b_Mosco_k, || b_Mosco_k ||_op)
    for k in range(D + 1):
        cells = []
        for alpha in alphas:
            Tz, _ = operators[alpha]
            bm = mosco_bridge(D, k, basis, alpha, Tz)
            norm = op_norm(bm)
            mosco[(k, alpha)] = (bm, norm)
            cells.append(_cell(norm))
        print(f"  {k:>3}  {math.comb(D, k):>6}  " + "  ".join(cells))
    print()

    # Steps 2 and 3 report the same commutators (raw, then divided by the
    # target), so evaluate each one exactly once.  None marks a weight class
    # whose Mosco bridge vanishes on the truncated basis.
    comm_norms = {}
    for k in range(1, D + 1):
        for alpha in alphas:
            bm, scale = mosco[(k, alpha)]
            if scale < zero_tol:
                comm_norms[(k, alpha)] = None
                continue
            _, D_a = operators[alpha]
            comm_norms[(k, alpha)] = _commutator_norm(D_a, bm / scale)

    # ============================================
    # Step 2: rescaled-bridge Dirac commutator
    # ============================================
    print("  Step 2: alpha-rescaled bridge Dirac commutator")
    print()
    print("  bridge_final_k = (1 / || b_Mosco_k ||_op) * b_Mosco_k  (unit op-norm)")
    print("  Test: || [D_alpha, bridge_final_k tensor I_C2] ||_op  vs  2 sqrt(k)")
    print()
    print(f"  {'k':>3}  {'2 sqrt(k)':>10}  " + _alpha_header(alphas))
    for k in range(1, D + 1):
        target = 2 * math.sqrt(k)
        cells = [_cell(comm_norms[(k, alpha)]) for alpha in alphas]
        print(f"  {k:>3}  {target:>10.4f}  " + "  ".join(cells))
    print()

    # ============================================
    # Step 3: ratio to target
    # ============================================
    print("  Step 3: ratios || comm ||_op / (2 sqrt(k))  --  closure when ratio -> 1")
    print()
    print(f"  {'k':>3}  " + _alpha_header(alphas))
    for k in range(1, D + 1):
        target = 2 * math.sqrt(k)
        cells = []
        for alpha in alphas:
            value = comm_norms[(k, alpha)]
            cells.append(_cell(None if value is None else value / target))
        print(f"  {k:>3}  " + "  ".join(cells))
    print()

    # ============================================
    # Step 4: variance test (uniform across same-weight S?)
    # ============================================
    print("  Step 4: variance check -- compare per-S Dirac commutator norms after rescaling")
    print()
    print("  If Mosco averaging worked, all |S| = k modes should have the same")
    print("  commutator norm under the rescaled bridge.  Per-S after rescaling:")
    print()
    print(f"  {'k':>3}  {'S':<14}  {'2 sqrt(k)':>10}  " + _alpha_header(variance_alphas))
    for k in range(1, D + 1):
        target = 2 * math.sqrt(k)
        # Show a few representative S: the first 3 subsets of weight k.
        for S in list(combinations(range(D), k))[:3]:
            S_set = set(S)
            cells = []
            for alpha in variance_alphas:
                Tz, D_a = operators[alpha]
                # Per-S rescaled bridge: individual b_holo scaled by its own norm.
                b_S = bridge_holo(D, S_set, basis, alpha, Tz)
                scale_S = op_norm(b_S)
                if scale_S < zero_tol:
                    cells.append(_cell(None))
                else:
                    cells.append(_cell(_commutator_norm(D_a, b_S / scale_S)))
            print(f"  {k:>3}  {'S=' + str(S_set):<14}  {target:>10.4f}  "
                  + "  ".join(cells))
    print()

    print("=" * 90)
    print("  Verdict")
    print("=" * 90)
    print()
    print("  Read Step 3 ratios: a ratio close to 1 indicates the Mosco-averaged")
    print("  bridge with alpha-rescaling achieves the substrate L_comm target 2*sqrt(k)")
    print("  for that weight class.  Ratios that vary with alpha indicate the BS 2022")
    print("  alpha-tuning enters into the closure rate as predicted.")
    print()
    print("  Step 4 checks whether the per-S commutator norms are uniform across")
    print("  same-weight S after rescaling.  If they vary widely, the simple norm-")
    print("  rescaling per-S is not equivalent to a true Mosco average; an extra")
    print("  Wigner-D-component projection is needed.")
    print()
    print("  If both ratios cluster near 1 and the per-S variance is small, L_comm")
    print("  is closing in the operational framework at this D.")


if __name__ == "__main__":
    # Run the sweep only when executed as a script, not when imported.
    main()
