#!/usr/bin/env python3
"""
Computation 40 -- L_comm convergence at larger (D, N): asymptotic closure test
================================================================================
Computation 39 demonstrated that combining Mosco averaging with alpha
rescaling on the SU(2)-coupled Bergman framework gives L_comm ratios
that increase monotonically toward 1 with the weight k at D = 4, N = 5:

    k = 1:  0.60
    k = 2:  0.71
    k = 3:  0.87
    k = 4:  zero (truncation collapse)

This computation tests the convergence by scaling up:
  (i)  D = 4, N = 10  (larger Bergman truncation; checks whether the k = 4
                       truncation collapse is resolved)
  (ii) D = 6, N = 12  (larger substrate; tests behavior at k = 4, 5, 6)

The hypothesis: the symmetric-sum Mosco + alpha-rescaling bridge produces
ratios that converge to 1 as (D, N) increases, demonstrating asymptotic
closure of L_comm in the operational framework.

If confirmed, the L_comm gap is reduced to a quantitative convergence
problem rather than a structural one: the bridge framework correctly
captures the substrate <-> S^3 spectral-triple correspondence, and the
remaining work is the analytical rate proof.
"""
import math
from itertools import combinations
import numpy as np
import numpy.linalg as la


# 2x2 building blocks with complex dtype: the three Pauli matrices (combined
# with J_x, J_y, J_z in dirac_alpha) and the identity (used to lift Bergman
# operators onto the spinor factor via np.kron).
I2 = np.identity(2, dtype=complex)
sx = np.array([[0, 1], [1, 0]]).astype(complex)
sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
sz = np.diag([1, -1]).astype(complex)


def op_norm(M):
    """Return the spectral (ord=2) norm of matrix M as a plain Python float.

    For a 2-D array this is the largest singular value.
    """
    spectral = la.norm(M, 2)
    return float(spectral)


def basis_indices(N):
    """List all exponent pairs (m1, m2) with m1, m2 >= 0 and m1 + m2 <= N.

    Pairs are ordered by m1 ascending, then m2 ascending; a negative N gives
    an empty list.
    """
    pairs = []
    for first in range(N + 1):
        pairs.extend((first, second) for second in range(N - first + 1))
    return pairs


def log_norm_sq(m1, m2, alpha):
    """Return log( m1! * m2! * Gamma(alpha + 3) / Gamma(m1 + m2 + alpha + 3) ).

    Presumably this is the log squared norm of the monomial z^(m1, m2) in the
    alpha-weighted Bergman space; it is normalized so the constant monomial
    (m1 = m2 = 0) gives 0.
    """
    lg = math.lgamma
    numerator = lg(m1 + 1) + lg(m2 + 1) + lg(alpha + 3)
    return numerator - lg(m1 + m2 + alpha + 3)


def Tz_matrix(a, basis, alpha):
    """Matrix of multiplication by the a-th coordinate on the monomial basis.

    Column j (exponent tuple J) has one entry, at the row of J with its a-th
    exponent raised by one, equal to exp(0.5 * (log_norm_sq(target) -
    log_norm_sq(J))), i.e. the ratio of monomial norms when log_norm_sq is the
    log squared norm.  If the raised tuple is not in ``basis`` (e.g. its degree
    exceeds N for a basis_indices(N) basis) the column stays zero.
    """
    position = {mono: i for i, mono in enumerate(basis)}
    size = len(basis)
    mat = np.zeros((size, size), dtype=complex)
    for col, mono in enumerate(basis):
        bumped = list(mono)
        bumped[a] += 1
        target = tuple(bumped)
        row = position.get(target)
        if row is None:
            # Raised monomial falls outside the truncated basis.
            continue
        half_log = 0.5 * (log_norm_sq(target[0], target[1], alpha)
                          - log_norm_sq(mono[0], mono[1], alpha))
        mat[row, col] = math.exp(half_log)
    return mat


def J_plus_matrix(basis, alpha):
    """Raising operator: (m1, m2) -> (m1 + 1, m2 - 1) on the monomial basis.

    The coefficient is m2 (from differentiating the second variable and
    multiplying by the first) times exp(0.5 * (log_norm_sq(target) -
    log_norm_sq(source))) to convert to the normalized basis.  Monomials with
    m2 == 0, or whose target is absent from ``basis``, map to zero.
    """
    lookup = {mono: i for i, mono in enumerate(basis)}
    size = len(basis)
    mat = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m2 <= 0:
            continue
        target = (m1 + 1, m2 - 1)
        row = lookup.get(target)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha))
        mat[row, col] = m2 * math.exp(half_log)
    return mat


def J_minus_matrix(basis, alpha):
    """Lowering operator: (m1, m2) -> (m1 - 1, m2 + 1) on the monomial basis.

    The coefficient is m1 (from differentiating the first variable and
    multiplying by the second) times exp(0.5 * (log_norm_sq(target) -
    log_norm_sq(source))) to convert to the normalized basis.  Monomials with
    m1 == 0, or whose target is absent from ``basis``, map to zero.
    """
    lookup = {mono: i for i, mono in enumerate(basis)}
    size = len(basis)
    mat = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m1 <= 0:
            continue
        target = (m1 - 1, m2 + 1)
        row = lookup.get(target)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha))
        mat[row, col] = m1 * math.exp(half_log)
    return mat


def J_z_matrix(basis, alpha):
    """Diagonal J_z with entry (m1 - m2) / 2 for each exponent pair in ``basis``.

    ``alpha`` is accepted for signature symmetry with the other builders and
    is not used.
    """
    weights = np.array([0.5 * (m1 - m2) for m1, m2 in basis], dtype=complex)
    return np.diag(weights)


def dirac_alpha(basis, alpha):
    """Return J_x (x) sigma_x + J_y (x) sigma_y + J_z (x) sigma_z.

    J_x = (J_+ + J_-) / 2 and J_y = (J_+ - J_-) / (2i) are assembled from the
    raising/lowering matrices; the result acts on the Bergman basis tensored
    with C^2, so it is 2 * len(basis) square.
    """
    raise_op = J_plus_matrix(basis, alpha)
    lower_op = J_minus_matrix(basis, alpha)
    jz = J_z_matrix(basis, alpha)
    jx = (raise_op + lower_op) / 2
    jy = (raise_op - lower_op) / (2j)
    terms = [np.kron(jx, sx), np.kron(jy, sy), np.kron(jz, sz)]
    return terms[0] + terms[1] + terms[2]


def site_to_holo(D, basis, alpha, Tz):
    """Map each substrate site 0..D-1 to a power of a coordinate Toeplitz matrix.

    Site s uses coordinate ``s % 2`` raised to power ``s // 2 + 1``:
      sites 0, 1 -> T_{z_0}, T_{z_1}
      sites 2, 3 -> T_{z_0^2}, T_{z_1^2}
      sites 4, 5 -> T_{z_0^3}, T_{z_1^3}
      ... and so on for any D.
    Powers are formed as left-to-right matrix products of ``Tz[coord]``; sites
    0 and 1 map to the ``Tz`` entries themselves (not copies).  ``basis`` and
    ``alpha`` are accepted for signature compatibility and are not used.
    """
    site_map = {}
    for site in range(D):
        base = Tz[site % 2]
        # Site s shares its coordinate with site s - 2 and has one more power,
        # so extend that product by one factor on the right.
        site_map[site] = base if site < 2 else site_map[site - 2] @ base
    return site_map


def bridge_holo(D, S, basis, alpha, Tz):
    """Return the product of the site operators for the sites in S.

    Factors are multiplied left to right in ascending site order, starting
    from the complex identity of size len(basis); an empty S yields that
    identity.
    """
    ops = site_to_holo(D, basis, alpha, Tz)
    product = np.eye(len(basis), dtype=complex)
    ordered_sites = sorted(S)
    for site in ordered_sites:
        product = product @ ops[site]
    return product


def mosco_bridge(D, k, basis, alpha, Tz):
    """Mosco-averaged weight-k bridge: normalized sum over all k-subsets of sites.

    For every k-element subset S of range(D) the site operators from
    site_to_holo are multiplied left to right in ascending site order,
    starting from the identity.  This is the same product bridge_holo forms.
    The products are summed and divided by sqrt(number of subsets), which is
    sqrt(C(D, k)).

    Returns the zero matrix when there are no subsets (k > D).
    """
    # The site operators depend only on (D, Tz).  Build them once here rather
    # than once per subset, as a per-subset bridge_holo call would.
    site_map = site_to_holo(D, basis, alpha, Tz)
    n = len(basis)
    identity = np.eye(n, dtype=complex)
    out = np.zeros((n, n), dtype=complex)
    count = 0
    for S in combinations(range(D), k):
        # combinations() yields ascending tuples, matching bridge_holo's sorted(S).
        prod = identity
        for a in S:
            prod = prod @ site_map[a]
        out = out + prod
        count += 1
    return out / math.sqrt(count) if count > 0 else out


def run_at(D, N, alphas, label):
    basis = basis_indices(N)
    print(f"  {label}: D = {D}, N = {N}, Bergman dim = {len(basis)}, full dim = {len(basis) * 2}")
    print()
    print(f"  Step 1: Mosco bridge norms || b_Mosco_k ||_op")
    print(f"  {'k':>3}  C(D,k)  " + "  ".join(f"a={alpha:5.1f}" for alpha in alphas))
    norms_mosco = {}
    for k in range(D + 1):
        cdk = math.comb(D, k)
        row = f"  {k:>3}  {cdk:>6}  "
        for alpha in alphas:
            Tz = [Tz_matrix(0, basis, alpha), Tz_matrix(1, basis, alpha)]
            bm = mosco_bridge(D, k, basis, alpha, Tz)
            n = op_norm(bm)
            norms_mosco[(k, alpha)] = n
            row += f"  {n:>10.4f}"
        print(row)
    print()

    print(f"  Step 2: ratios || [D, bridge_final_k tensor I] ||_op / (2 sqrt(k))")
    print(f"  Closure when ratio -> 1")
    print()
    print(f"  {'k':>3}  {'2 sqrt(k)':>10}  " + "  ".join(f"a={alpha:5.1f}" for alpha in alphas))
    ratios = {}
    for k in range(1, D + 1):
        target = 2 * math.sqrt(k)
        row = f"  {k:>3}  {target:>10.4f}  "
        for alpha in alphas:
            Tz = [Tz_matrix(0, basis, alpha), Tz_matrix(1, basis, alpha)]
            D_a = dirac_alpha(basis, alpha)
            bm = mosco_bridge(D, k, basis, alpha, Tz)
            scale = norms_mosco[(k, alpha)]
            if scale < 1e-9:
                row += f"  {'(zero)':>10}"
                continue
            bridge_final = bm / scale
            bridge_full = np.kron(bridge_final, I2)
            comm = D_a @ bridge_full - bridge_full @ D_a
            r = op_norm(comm) / target
            ratios[(k, alpha)] = r
            row += f"  {r:>10.4f}"
        print(row)
    print()
    return ratios


def main():
    print("=" * 90)
    print("  Computation 40  --  L_comm convergence at larger (D, N): asymptotic closure")
    print("=" * 90)
    print()

    alphas = [0.0, 1.0, 5.0]

    # Case (i): D = 4, N = 10 (larger Bergman truncation)
    r_D4_N10 = run_at(4, 10, alphas, "Case (i)")

    # Case (ii): D = 6, N = 12 (larger substrate)
    r_D6_N12 = run_at(6, 12, alphas, "Case (ii)")

    # Comparison: ratios across configurations at alpha = 0
    print("=" * 90)
    print("  Convergence summary at alpha = 0")
    print("=" * 90)
    print()
    print(f"  {'k':>3}  {'D=4, N=5 (Comp 39)':>22}  {'D=4, N=10':>14}  {'D=6, N=12':>14}")
    comp39 = {1: 0.6043, 2: 0.7148, 3: 0.8551, 4: 0.0}
    for k in range(1, 7):
        if k not in comp39 and (k, 0.0) not in r_D4_N10 and (k, 0.0) not in r_D6_N12:
            continue
        row = f"  {k:>3}  "
        prev = comp39.get(k, None)
        row += f"  {prev:>14.4f}  " if prev is not None else f"  {'-':>14}  "
        d4n10 = r_D4_N10.get((k, 0.0), None)
        row += f"  {d4n10:>12.4f}  " if d4n10 is not None else f"  {'-':>12}  "
        d6n12 = r_D6_N12.get((k, 0.0), None)
        row += f"  {d6n12:>12.4f}" if d6n12 is not None else f"  {'-':>12}"
        print(row)
    print()

    print("=" * 90)
    print("  Verdict")
    print("=" * 90)
    print()
    print("  Trend across (D, N) configurations indicates whether the L_comm")
    print("  ratios converge toward 1 at fixed weight k.  Specifically:")
    print()
    print("    - At fixed k, do ratios increase or decrease with (D, N)?")
    print("    - At fixed (D, N), do ratios increase with k toward 1?")
    print()
    print("  If both trends hold (ratios increasing in (D, N) toward 1 at fixed k,")
    print("  AND increasing in k toward 1 at fixed (D, N)), the L_comm gap is")
    print("  asymptotically closing in the operational framework.")


# Script entry point: run the computation only when executed directly.
if __name__ == "__main__":
    main()
