#!/usr/bin/env python3
"""
Computation 42 -- L_comm relaxed-criterion convergence at D = 8, N = 20
========================================================================
With Path B committed (paper v17.94, equation eq:lcomm-relaxed), the
L_comm closure target is the relaxed norm-equivalence criterion with
gamma_D -> 0 at exponential rate.  Computation 40 established the
norm-match convergence empirically at (D, N) = (4, 5), (4, 10), (6, 12).
This script extends to (D, N) = (8, 20) to test the rate.
"""
import math
from itertools import combinations
import numpy as np
import numpy.linalg as la


# Single-qubit Pauli matrices and the 2x2 identity, all stored as complex128 so
# every Kronecker product assembled from them stays complex.
I2 = np.identity(2, dtype=complex)
sx = np.array([[0, 1],
               [1, 0]], dtype=complex)
sy = np.array([[0, -1j],
               [1j, 0]], dtype=complex)
sz = np.diag([1, -1]).astype(complex)


def kron_chain(ops):
    """Return the Kronecker product ops[0] (x) ops[1] (x) ... (x) ops[-1].

    Factors are combined left to right, so ops[0] occupies the leftmost
    (most significant) tensor slot.  An empty sequence yields the 1x1
    complex identity -- the neutral element of the Kronecker product -- so
    zero-factor products (e.g. ``chi_walsh(0, ...)``) are well defined
    instead of raising IndexError.

    Parameters
    ----------
    ops : sequence of 2-D arrays
        The factors, in tensor order.

    Returns
    -------
    numpy.ndarray
        The full Kronecker product.
    """
    # len() rather than truthiness so a stacked ndarray of factors also works.
    if len(ops) == 0:
        return np.ones((1, 1), dtype=complex)
    out = ops[0]
    for op in ops[1:]:
        out = np.kron(out, op)
    return out


def chi_walsh(D, S):
    """Return the Walsh character on D qubits: sigma_z on every site in S.

    Site ``a`` (0 <= a < D) carries ``sz`` when ``a`` is a member of S and
    the 2x2 identity otherwise; the factors are tensored in site order.
    """
    factors = [I2 if site not in S else sz for site in range(D)]
    return kron_chain(factors)


def chi_clifford(D, a):
    """Return the a-th Clifford-type generator on D qubits.

    The tensor string is ``sz`` on sites 0..a-1, ``sx`` on site a, and the
    identity on sites a+1..D-1 (a Jordan-Wigner-style string).
    """
    left = [sz] * a
    right = [I2] * (D - 1 - a)
    return kron_chain(left + [sx] + right)


def op_norm(M):
    """Return the operator 2-norm (largest singular value) of M as a float."""
    largest_singular_value = la.norm(M, ord=2)
    return float(largest_singular_value)


def basis_indices(N):
    """Return all exponent pairs (m1, m2) with m1 + m2 <= N.

    Pairs are ordered by m1 first, then m2, both increasing.
    """
    pairs = []
    for first in range(N + 1):
        for second in range(N + 1 - first):
            pairs.append((first, second))
    return pairs


def log_norm_sq(m1, m2, alpha):
    """Return log of the squared weighted norm of the monomial z1^m1 z2^m2.

    The squared norm is m1! m2! Gamma(alpha + 3) / Gamma(m1 + m2 + alpha + 3),
    evaluated in log space via lgamma so large exponents do not overflow.
    """
    lg = math.lgamma
    numerator = lg(m1 + 1) + lg(m2 + 1) + lg(alpha + 3)
    return numerator - lg(m1 + m2 + alpha + 3)


def Tz_matrix(a, basis, alpha):
    """Return the matrix of multiplication by z_a on the truncated monomial basis.

    ``a`` indexes the coordinate slot of each basis tuple (0 or 1 for the
    pairs produced by basis_indices).  In the normalised basis
    e_J = z^J / ||z^J||, z_a maps e_J to (||z^(J + e_a)|| / ||z^J||) e_(J + e_a).
    Columns whose image J + e_a lies outside ``basis`` are left zero.
    """
    position = {J: i for i, J in enumerate(basis)}
    size = len(basis)
    T = np.zeros((size, size), dtype=complex)
    for col, J in enumerate(basis):
        shifted = list(J)
        shifted[a] += 1
        target = tuple(shifted)
        row = position.get(target)
        if row is None:
            continue  # raised degree falls outside the truncation
        half_log = 0.5 * (log_norm_sq(target[0], target[1], alpha)
                          - log_norm_sq(J[0], J[1], alpha))
        T[row, col] = math.exp(half_log)
    return T


def J_plus_matrix(basis, alpha):
    """Return the raising operator z1 d/dz2 in the normalised monomial basis.

    On monomials, z1^m1 z2^m2 -> m2 z1^(m1+1) z2^(m2-1); the matrix entry
    carries the extra norm ratio ||z^(m1+1, m2-1)|| / ||z^(m1, m2)||.
    """
    position = {J: i for i, J in enumerate(basis)}
    size = len(basis)
    raising = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m2 <= 0:
            continue  # d/dz2 annihilates monomials without z2
        target = (m1 + 1, m2 - 1)
        row = position.get(target)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha))
        raising[row, col] = m2 * math.exp(half_log)
    return raising


def J_minus_matrix(basis, alpha):
    """Return the lowering operator z2 d/dz1 in the normalised monomial basis.

    On monomials, z1^m1 z2^m2 -> m1 z1^(m1-1) z2^(m2+1); the matrix entry
    carries the extra norm ratio ||z^(m1-1, m2+1)|| / ||z^(m1, m2)||.
    """
    position = {J: i for i, J in enumerate(basis)}
    size = len(basis)
    lowering = np.zeros((size, size), dtype=complex)
    for col, (m1, m2) in enumerate(basis):
        if m1 <= 0:
            continue  # d/dz1 annihilates monomials without z1
        target = (m1 - 1, m2 + 1)
        row = position.get(target)
        if row is None:
            continue
        half_log = 0.5 * (log_norm_sq(*target, alpha) - log_norm_sq(m1, m2, alpha))
        lowering[row, col] = m1 * math.exp(half_log)
    return lowering


def J_z_matrix(basis, alpha):
    """Return the diagonal J_z with entry (m1 - m2) / 2 for each basis pair.

    ``alpha`` is accepted for signature parity with the other J matrices
    and is not used.
    """
    weights = np.array([0.5 * (m1 - m2) for m1, m2 in basis], dtype=complex)
    return np.diag(weights)


def dirac_alpha(basis, alpha):
    """Return D_alpha = J1 (x) sx + J2 (x) sy + Jz (x) sz on span(basis) (x) C^2.

    J1 = (J+ + J-) / 2 and J2 = (J+ - J-) / (2i) are the Cartesian
    components formed from the ladder matrices; Jz is diagonal.
    """
    raising = J_plus_matrix(basis, alpha)
    lowering = J_minus_matrix(basis, alpha)
    j_x = (raising + lowering) / 2
    j_y = (raising - lowering) / (2j)
    j_z = J_z_matrix(basis, alpha)
    x_part = np.kron(j_x, sx)
    y_part = np.kron(j_y, sy)
    z_part = np.kron(j_z, sz)
    return x_part + y_part + z_part


def site_to_holo(D, basis, alpha, Tz):
    """Map each site 0..D-1 to a power of a coordinate-multiplication matrix.

    Even sites use Tz[0] and odd sites use Tz[1].  Site s gets the matrix
    power Tz[s % 2] ** (s // 2 + 1).  ``basis`` and ``alpha`` are accepted
    for signature parity and are not used.
    """
    def _power(mat, exponent):
        # Right-multiply step by step so rounding matches a plain
        # left-to-right product (no repeated squaring).
        acc = mat
        for _ in range(exponent - 1):
            acc = acc @ mat
        return acc

    return {site: _power(Tz[site % 2], site // 2 + 1) for site in range(D)}


def bridge_holo(D, S, basis, alpha, Tz):
    """Return the ordered product of the site operators for the sites in S.

    Starting from the identity on span(basis), the matrices produced by
    site_to_holo are right-multiplied in increasing site order.  An empty S
    gives the identity.
    """
    ops_by_site = site_to_holo(D, basis, alpha, Tz)
    result = np.eye(len(basis), dtype=complex)
    ordered_sites = sorted(S)
    for site in ordered_sites:
        result = result @ ops_by_site[site]
    return result


def mosco_bridge(D, k, basis, alpha, Tz):
    """Return the Mosco-averaged weight-k bridge operator on span(basis).

    Sums the ordered site products prod_{a in S} site_to_holo(...)[a] (the
    same product bridge_holo forms) over every k-subset S of range(D) and
    divides by sqrt(C(D, k)).

    Returns the zero matrix when there are no k-subsets (k > D).  A negative
    k raises ValueError from itertools.combinations.
    """
    n = len(basis)
    # The site operators do not depend on S: build them once per call rather
    # than once per subset (bridge_holo rebuilds them on every invocation).
    site_map = site_to_holo(D, basis, alpha, Tz)
    out = np.zeros((n, n), dtype=complex)
    count = 0
    for S in combinations(range(D), k):
        # combinations() yields sites in increasing order, matching the
        # sorted(S) multiplication order used by bridge_holo.
        term = np.eye(n, dtype=complex)
        for a in S:
            term = term @ site_map[a]
        out = out + term
        count += 1
    return out / math.sqrt(count) if count > 0 else out


def _alpha_header(alphas):
    """Return per-alpha header cells, each right-aligned to a 12-char data cell."""
    return "".join(f"  {f'a={alpha:5.1f}':>10}" for alpha in alphas)


def _cell(value):
    """Format one 12-char table cell: a 4-decimal number, or '-' for None."""
    if value is None:
        return f"  {'-':>10}"
    return f"  {value:>10.4f}"


def _print_rate_fit(k_D_data):
    """Print gamma_D = |1 - ratio| per D and a log-linear fit of gamma_D vs D.

    ``k_D_data`` is a list of (D, ratio-or-None) pairs evaluated at
    k = k_D = floor(sqrt(D)).  Entries with ratio None are skipped.  Entries
    with gamma_D <= 1e-12 are printed with log = -inf and excluded from the
    fit, because their logarithm is not finite (math.log(0) would raise).
    """
    print("  Exponential-rate sketch at k = k_D = floor(sqrt(D))")
    print(f"  {'D':>3}  {'k_D':>3}  {'ratio @alpha=0':>16}  {'gamma_D':>10}  {'log(gamma_D)':>14}")
    fit_points = []
    for D_val, r in k_D_data:
        if r is None:
            continue
        gamma = abs(1 - r)
        if gamma > 1e-12:
            log_g = math.log(gamma)
            fit_points.append((D_val, log_g))
        else:
            log_g = float('-inf')
        print(f"  {D_val:>3}  {math.isqrt(D_val):>3}  {r:>16.4f}  {gamma:>10.4f}  {log_g:>14.4f}")
    print()

    if len(fit_points) >= 2:
        Ds = np.array([p[0] for p in fit_points], dtype=float)
        log_gammas = np.array([p[1] for p in fit_points])
        slope, intercept = np.polyfit(Ds, log_gammas, 1)
        print(f"  Linear fit log(gamma_D) = {slope:+.4f} * D + {intercept:+.4f}")
        print(f"  => gamma_D ~ {math.exp(intercept):.4f} * exp({slope:+.4f} * D)")
        if slope < -0.05:
            print(f"  => empirical exponential decay rate c ~ {-slope:.4f}")
        elif slope > 0.05:
            # A positive slope means gamma_D grows with D: not merely "shallow".
            print(f"  => gamma_D grows with D (slope {slope:+.4f}); no exponential decay observed")
        else:
            print("  => slope is shallow; more data needed to confirm exponential")
    else:
        print("  (fewer than two finite gamma_D values; no fit)")
    print()


def main():
    """Run Computation 42 at (D, N) = (8, 20) and print every result table.

    In order: Mosco-averaged bridge norms, commutator ratios
    ||[D_alpha, b_k (x) I]||_op / (2 sqrt(k)), the gaps gamma_D(k) = |1 - ratio|,
    a cross-(D, N) summary against hard-coded earlier alpha = 0 results, and
    a log-linear fit of gamma_D at k = k_D = floor(sqrt(D)).
    """
    D = 8
    N = 20
    basis = basis_indices(N)
    alphas = [0.0, 1.0, 2.0, 5.0]
    k_D = math.isqrt(D)  # floor(sqrt(D)), computed exactly

    print("=" * 90)
    print(f"  Computation 42  --  L_comm relaxed-criterion convergence at D = {D}, N = {N}")
    print("=" * 90)
    print()
    print(f"  Substrate dim: 2^{D} = {1 << D}")
    print(f"  Bergman dim:   {len(basis)}")
    print(f"  Full Hilbert:  {len(basis) * 2}")
    print(f"  Walsh-weight cutoff:  k_D = floor(sqrt(D)) = {k_D}")
    print()

    # Tz depends only on alpha; build it once per alpha, not once per (k, alpha)
    # per table.  Bridges are cached so the ratio table does not rebuild them.
    tz_by_alpha = {alpha: [Tz_matrix(0, basis, alpha), Tz_matrix(1, basis, alpha)]
                   for alpha in alphas}
    bridges = {}
    norms_mosco = {}

    print("  Mosco-averaged bridge norms || b_Mosco_k ||_op")
    print(f"  {'k':>3}  {'C(D,k)':>7}  " + _alpha_header(alphas))
    for k in range(D + 1):
        row = f"  {k:>3}  {math.comb(D, k):>7}  "
        for alpha in alphas:
            bm = mosco_bridge(D, k, basis, alpha, tz_by_alpha[alpha])
            bridges[(k, alpha)] = bm
            norms_mosco[(k, alpha)] = op_norm(bm)
            row += _cell(norms_mosco[(k, alpha)])
        print(row)
    print()

    print("  Bergman-side ratios || [D_alpha, bridge_final_k tensor I] ||_op / (2*sqrt(k))")
    print()
    print(f"  {'k':>3}  {'2 sqrt(k)':>11}  " + _alpha_header(alphas))
    dirac_by_alpha = {alpha: dirac_alpha(basis, alpha) for alpha in alphas}
    ratios = {}
    for k in range(1, D + 1):
        target = 2 * math.sqrt(k)
        row = f"  {k:>3}  {target:>11.4f}  "
        for alpha in alphas:
            scale = norms_mosco[(k, alpha)]
            if scale < 1e-9:
                row += f"  {'(zero)':>10}"
                continue
            # Normalise the bridge to unit operator norm, then lift it to the
            # spinor factor so it acts on the same space as D_alpha.
            bridge_full = np.kron(bridges[(k, alpha)] / scale, I2)
            D_a = dirac_by_alpha[alpha]
            comm = D_a @ bridge_full - bridge_full @ D_a
            ratios[(k, alpha)] = op_norm(comm) / target
            row += _cell(ratios[(k, alpha)])
        print(row)
    print()

    print("  gamma_D(k) = |1 - ratio|  (the L_comm relaxed-criterion gap)")
    print()
    print(f"  {'k':>3}  " + _alpha_header(alphas))
    for k in range(1, D + 1):
        cells = []
        for alpha in alphas:
            r = ratios.get((k, alpha))
            cells.append(_cell(None if r is None else abs(1 - r)))
        print(f"  {k:>3}  " + "".join(cells))
    print()

    print("=" * 90)
    print("  Convergence summary across (D, N) configurations at alpha = 0")
    print("=" * 90)
    print()
    this_run = f"D={D}, N={N}"
    print(f"  {'k':>3}  {'D=4, N=5':>12}  {'D=4, N=10':>12}  {'D=6, N=12':>12}  {this_run:>12}")
    # Earlier alpha = 0 ratios by k; columns are (D, N) = (4, 5), (4, 10), (6, 12).
    ref = {
        1: (0.6043, 0.6211, 0.7219),
        2: (0.7148, 0.6671, 0.7975),
        3: (0.8551, 0.8073, 0.9625),
        4: (None, 0.8874, 1.1187),
        5: (None, None, 1.2273),
        6: (None, None, 1.3229),
    }
    for k in range(1, D + 1):
        d4n5, d4n10, d6n12 = ref.get(k, (None, None, None))
        cells = [d4n5, d4n10, d6n12, ratios.get((k, 0.0))]
        print(f"  {k:>3}  " + "  ".join(_cell(v) for v in cells))
    print()

    # NOTE(review): the D = 4 point uses the N = 5 column although N = 10 data
    # (ref[2][1]) exists; D = 6 and this run use their largest N.  Confirm intent.
    _print_rate_fit([
        (4, ref[math.isqrt(4)][0]),
        (6, ref[math.isqrt(6)][2]),
        (D, ratios.get((k_D, 0.0))),
    ])

    print("=" * 90)
    print("  Verdict")
    print("=" * 90)
    print()
    print("  Read the convergence-summary table: ratios at fixed k should")
    print("  approach 1 as (D, N) grows.  At k = k_D = floor(sqrt(D)) the trend")
    print("  determines whether the relaxed L_comm criterion closes exponentially.")


if __name__ == "__main__":
    # Run the computation and print the tables only when executed as a
    # script, not when this module is imported.
    main()
