#!/usr/bin/env python3
"""
Computation 44 -- Phase 2 analytical L_comm closure at k = k_D
================================================================
Building on Computation 43's validation of the analytical sup-formula
framework, this script computes the L_comm closure gap

    gamma_D(k) = | sup ||M_{D, k}||_{op, 2x2} / sup |F_{D, k}|
                  / (2 sqrt(k)) - 1 |

at the Walsh-weight cutoff k = k_D = floor(sqrt(D)) for D = 4, 6, 8, 10, 12,
using ONLY the analytical sup-norm formulas (no Bergman truncation), with the
sups evaluated by grid search over the unit sphere partial B^2 of C^2.

The matrix-valued pointwise op-norm uses the CORRECTED formula:
    ||M||_op^2 = (|a|^2 + |b|^2 + |c|^2) + sqrt(X^2 + |Y|^2)
where a = J_1 F, b = J_2 F, c = J_z F and, writing J_+- F = a +- i b,
    X = (|J_+ F|^2 - |J_- F|^2) / 2,    Y = conj(c) J_- F - conj(J_+ F) c.

If log(gamma_D) at k_D falls off linearly in D with the exponential decay
rate c ~ 0.93 found in Computation 42, the Phase 2 framework reproduces the
empirical rate from pure analysis on partial B^2.
"""
import math
import numpy as np
from itertools import combinations


def site_exponent(a):
    """Map site index ``a`` to ``(coordinate, exponent)``.

    The coordinate is ``a mod 2`` (which of z_0 / z_1 the site contributes
    to) and the exponent is ``floor(a / 2) + 1``.
    """
    half, parity = divmod(a, 2)
    return parity, half + 1


def mosco_symbol_coeffs(D, k):
    """
    Build the Mosco-averaged bridge symbol as a coefficient dict {(p, q): c}:

        F_{D, k}(z_0, z_1) = (1/sqrt(C(D, k))) sum_{|S|=k} prod_{a in S} z_{a mod 2}^{floor(a/2)+1}

    Every k-subset S of the D sites adds 1/sqrt(C(D, k)) to the coefficient
    of the monomial z_0^p z_1^q it produces; subsets that land on the same
    (p, q) accumulate.
    """
    weight = 1.0 / math.sqrt(math.comb(D, k))
    coeffs = {}
    for subset in combinations(range(D), k):
        degrees = [0, 0]  # [power of z_0, power of z_1]
        for site in subset:
            axis, power = site_exponent(site)
            degrees[axis] += power
        monomial = (degrees[0], degrees[1])
        coeffs[monomial] = coeffs.get(monomial, 0.0) + weight
    return coeffs


def Jp_coeffs(F):
    """Apply J_+ = z_0 d/dz_1 to a coefficient dict.

    Each term c z_0^p z_1^q with q >= 1 becomes (c q) z_0^(p+1) z_1^(q-1);
    terms with q = 0 are annihilated and dropped.
    """
    raised = {}
    for (p, q), coeff in F.items():
        if q < 1:
            continue
        target = (p + 1, q - 1)
        raised[target] = raised.get(target, 0.0) + coeff * q
    return raised


def Jm_coeffs(F):
    """Apply J_- = z_1 d/dz_0 to a coefficient dict.

    Each term c z_0^p z_1^q with p >= 1 becomes (c p) z_0^(p-1) z_1^(q+1);
    terms with p = 0 are annihilated and dropped.
    """
    lowered = {}
    for (p, q), coeff in F.items():
        if p < 1:
            continue
        target = (p - 1, q + 1)
        lowered[target] = lowered.get(target, 0.0) + coeff * p
    return lowered


def Jz_coeffs(F):
    """Apply J_z = (1/2)(z_0 d/dz_0 - z_1 d/dz_1) to a coefficient dict.

    J_z is diagonal on monomials: z_0^p z_1^q has eigenvalue (p - q)/2.
    Every key of F is kept, including those whose coefficient becomes 0.
    """
    return {(p, q): c * 0.5 * (p - q) for (p, q), c in F.items()}


def eval_poly(F, z0, z1):
    """Evaluate the coefficient dict F = {(p, q): c} at the point (z_0, z_1).

    Returns sum over the terms of c * z_0**p * z_1**q as a complex scalar
    (0j when F is empty).
    """
    acc = complex(0.0, 0.0)
    for exponents, coeff in F.items():
        p, q = exponents
        acc += coeff * (z0 ** p) * (z1 ** q)
    return acc


def _sphere_grid(n_theta, n_phi):
    """Lazily yield grid points (z_0, z_1) on the unit sphere of C^2.

    |z_0| = cos(theta), |z_1| = sin(theta) with theta on n_theta + 1 nodes of
    [0, pi/2]; both phases run over n_phi + 1 nodes of [0, 2 pi] (both
    endpoints included), phase of z_1 varying fastest.
    """
    for i in range(n_theta + 1):
        theta = (i / n_theta) * (np.pi / 2)
        rad0, rad1 = np.cos(theta), np.sin(theta)
        for j in range(n_phi + 1):
            phase0 = np.exp(1j * ((j / n_phi) * (2 * np.pi)))
            for m in range(n_phi + 1):
                phase1 = np.exp(1j * ((m / n_phi) * (2 * np.pi)))
                yield rad0 * phase0, rad1 * phase1


def F_sup(F, n_theta=40, n_phi=40):
    """Grid estimate of sup |F| over the unit sphere partial B^2 of C^2.

    F is a coefficient dict {(p, q): c}; the estimate is the largest |F(z)|
    seen on the (n_theta + 1) x (n_phi + 1) x (n_phi + 1) grid, or 0.0 if
    the grid is empty.
    """
    peak = 0.0
    for z0, z1 in _sphere_grid(n_theta, n_phi):
        peak = max(peak, abs(eval_poly(F, z0, z1)))
    return peak


def M_sup(Jpf, Jmf, Jzf, n_theta=40, n_phi=40):
    """Sup over partial B^2 of the pointwise op-norm of M = sigma_a (J_a F).

    At each grid point let p = J_+ F, m = J_- F, c = J_z F (evaluated from
    the coefficient dicts Jpf, Jmf, Jzf). With J_+- = J_1 +- i J_2 the symbol
    is the 2x2 matrix M = [[c, m], [p, -c]] (its transpose has the same
    op-norm), and

        ||M||_op^2 = lambda_max(M^* M)
                   = |c|^2 + (|p|^2 + |m|^2)/2 + sqrt(X^2 + |Y|^2),
        X = (|p|^2 - |m|^2)/2,   Y = conj(c) m - conj(p) c.

    The sup is estimated on the grid |z_0| = cos(theta), |z_1| = sin(theta),
    theta on n_theta + 1 nodes of [0, pi/2], both phases on n_phi + 1 nodes
    of [0, 2 pi]. Returns 0.0 if the grid is empty.
    """
    # Phase factors depend only on the phase index; compute them once
    # instead of inside the innermost loop.
    phases = [np.exp(1j * ((j / n_phi) * (2 * np.pi))) for j in range(n_phi + 1)]
    best = 0.0
    for i in range(n_theta + 1):
        theta = (i / n_theta) * (np.pi / 2)
        r = np.cos(theta)
        s = np.sin(theta)
        for ph0 in phases:
            z0 = r * ph0
            for ph1 in phases:
                z1 = s * ph1
                p = eval_poly(Jpf, z0, z1)
                m = eval_poly(Jmf, z0, z1)
                c = eval_poly(Jzf, z0, z1)
                p2 = abs(p) ** 2
                m2 = abs(m) ** 2
                half_trace = 0.5 * (p2 + m2) + abs(c) ** 2
                X = 0.5 * (p2 - m2)
                # Off-diagonal entry of M^* M. It couples c to BOTH p and m;
                # keeping only 2 Im(conj(c) m) is exact only when p == m and
                # otherwise underestimates the norm.
                Y = c.conjugate() * m - p.conjugate() * c
                op2 = half_trace + math.sqrt(X ** 2 + abs(Y) ** 2)
                best = max(best, math.sqrt(op2))
    return best


def analytical_gamma(D, k, n_grid=30):
    """Analytical L_comm closure data for the symbol F_{D, k}.

    Builds F_{D, k}, estimates sup |F| and sup ||sigma_a J_a F||_op on the
    unit sphere of C^2 with an n_grid-node grid per angle, and compares the
    ratio with the target 2 sqrt(k).

    Returns:
        (ratio, ratio / (2 sqrt(k)), |1 - ratio / (2 sqrt(k))|) where
        ratio = sup ||M|| / sup |F|; or None when the quantity is undefined,
        i.e. k == 0 (target 2 sqrt(k) is zero) or sup |F| < 1e-12.
    """
    target = 2 * math.sqrt(k)
    if target == 0:
        return None  # k == 0: normalising by 2 sqrt(k) is undefined

    F = mosco_symbol_coeffs(D, k)
    sF = F_sup(F, n_grid, n_grid)
    # Bail out before the (much more expensive) M_sup grid search when the
    # denominator is numerically zero.
    if sF < 1e-12:
        return None

    sM = M_sup(Jp_coeffs(F), Jm_coeffs(F), Jz_coeffs(F), n_grid, n_grid)
    ratio = sM / sF
    return ratio, ratio / target, abs(1 - ratio / target)


def main():
    """Sweep D, tabulate the analytical gamma_D at k_D, and fit its decay.

    For each D in (4, 6, 8, 10, 12) the cutoff is k_D = floor(sqrt(D)) and
    analytical_gamma supplies (ratio, ratio / (2 sqrt(k_D)), gamma_D).
    A line is then fitted to log(gamma_D) vs D, using only entries with
    gamma_D > 1e-9 because the log of a numerically-zero gamma is
    meaningless. The slope is compared with the Computation 42 numerical
    rate 0.93.
    """
    print("=" * 90)
    print("  Computation 44  --  Analytical L_comm closure at k = k_D = floor(sqrt(D))")
    print("=" * 90)
    print()
    print("  Grid search over 3-real-parameter sphere in C^2.")
    print("  All values are ANALYTICAL sup-norms (no Bergman truncation).")
    print()

    print(f"  {'D':>4}  {'k_D':>4}  {'sup |M|/sup |F|':>16}  {'/(2 sqrt(k_D))':>16}  {'gamma_D':>10}")
    results = []
    for D in [4, 6, 8, 10, 12]:
        k_D = math.isqrt(D)  # exact floor(sqrt(D)) for integer D
        # Finer grid for larger D: the symbols have higher degree there, so
        # the sup needs denser sampling to be accurate.
        n_grid = 36 if D >= 10 else 30
        row = analytical_gamma(D, k_D, n_grid=n_grid)
        if row is None:
            continue
        ratio, target_ratio, gamma = row
        results.append((D, gamma))
        print(f"  {D:>4}  {k_D:>4}  {ratio:>16.6f}  {target_ratio:>16.6f}  {gamma:>10.6f}")
    print()

    # log(gamma_D) is only meaningful for gamma_D bounded away from zero.
    fit = [(d, g) for d, g in results if g > 1e-9]
    if len(fit) < 2:
        return
    Ds = np.array([d for d, _ in fit], dtype=float)
    log_gs = np.array([math.log(g) for _, g in fit])
    slope, intercept = np.polyfit(Ds, log_gs, 1)
    print("  Analytical exponential fit:")
    print(f"     log(gamma_D) = {slope:+.4f} * D + {intercept:+.4f}")
    print(f"     gamma_D ~ {math.exp(intercept):.4f} * exp({slope:+.4f} * D)")
    print("     Compare to Computation 42 numerical fit:")
    print("     gamma_D ~ 19.31 * exp(-0.9277 * D)  (D = 4, 6, 8)")
    print()
    if slope < -0.5:
        print(f"  => analytical exponential decay rate c ~ {-slope:.4f}")
        if abs(-slope - 0.93) < 0.2:
            print("  => MATCHES the numerical rate 0.93 within tolerance!")
        else:
            print("  => rate differs from numerical 0.93; investigate")
    else:
        print("  => slope is shallow; analytical does NOT reproduce exponential")
        print("     (possible: grid too coarse, OR Phase 2 closure is finite-N-dependent)")


if __name__ == "__main__":
    main()
