#!/usr/bin/env python3
"""
Computation 45 -- Refined symmetric-monomial bridge (Option C exploration)
==========================================================================
The current Mosco-averaged bridge of Computations 38-42 produces
mixed-degree polynomials F_{D, k} whose Dirac commutator op-norm
ratio jumps at the Walsh cutoff k_D transitions (Computation 44).

This script explores an alternative SYMMETRIC bridge construction:
    bridge_sym(chi_S) := T_{(z_0 z_1)^{m(|S|)}}
where the exponent m(k) depends only on the Walsh weight |S| = k,
not on the specific subset S.  After Mosco averaging:
    b_Mosco,k = (1/sqrt(C(D,k))) sum_{|S|=k} T_{(z_0 z_1)^{m(k)}}
              = sqrt(C(D,k)) * T_{(z_0 z_1)^{m(k)}}
and after rescaling:
    bridge_final,k = T_{(z_0 z_1)^{m(k)}} / sup |(z_0 z_1)^{m(k)}|

For F = (z_0 z_1)^m, analytical closed form (r = |z_0|, s = |z_1|):
    sup |F|^2 over r^2 + s^2 = 1 is (1/4)^m at r = s = 1/sqrt(2)
    sup |M|^2 / sup |F|^2 = (m+1)^2 * (1 - 1/m^2)^{m-1}
The closure ratio is therefore exactly (m+1) sqrt((1 - 1/m^2)^{m-1});
at m = 1 the factor is 0^0 = 1, so the ratio is 2.

Goal: tune m(k) so that ratio matches 2 sqrt(k).  Test whether this
gives uniform closure independent of D, smoothing the cutoff transitions.
"""
import math
import numpy as np


def analytical_ratio_symmetric(m):
    """Closed-form sup |M| / sup |F| for the monomial F = (z_0 z_1)^m.

    Both suprema are over the sphere |z_0|^2 + |z_1|^2 = 1, and the result is
    (m + 1) * sqrt((1 - 1/m^2)^(m - 1)).  Any m below 1 yields 0.0.  At m = 1
    the inner power is 0.0 ** 0, which Python evaluates to 1.0, so the ratio
    is exactly 2.
    """
    # The guard must come first: m = 0 would divide by zero in 1/m^2.
    return 0.0 if m < 1 else (m + 1) * math.sqrt((1.0 - 1.0 / m**2) ** (m - 1))


def best_m_for_k(k, *, m_max=29, ratio_fn=None):
    """Find the integer degree m whose bridge ratio is closest to 2 sqrt(k).

    Scans m = 1, ..., m_max and minimizes |ratio_fn(m) - 2 sqrt(k)|.  When
    several m tie, the smallest one is returned.

    Args:
        k: Walsh weight |S|.  Must be non-negative (math.sqrt raises
            ValueError otherwise).
        m_max: largest degree tried (inclusive).  The default 29 reproduces
            the historical search range m = 1..29.  If the target exceeds
            ratio_fn(m_max), the result saturates at m_max.
        ratio_fn: callable mapping an integer m to its ratio.  Defaults to
            analytical_ratio_symmetric.

    Returns:
        Tuple (best_m, best_err) with best_err = |ratio_fn(best_m) - 2 sqrt(k)|.

    Raises:
        ValueError: if m_max < 1, i.e. the search range is empty.
    """
    if ratio_fn is None:
        ratio_fn = analytical_ratio_symmetric
    if m_max < 1:
        raise ValueError(f"m_max must be >= 1, got {m_max}")
    target = 2 * math.sqrt(k)
    # min() keeps the first of several equal keys, so ties resolve to the
    # smallest m, exactly like a strict-< scan starting at m = 1.
    candidates = ((m, abs(ratio_fn(m) - target)) for m in range(1, m_max + 1))
    return min(candidates, key=lambda pair: pair[1])


def _phase_factors(n_phi):
    """Return [exp(i * phi_j)] for phi_j = (j / n_phi) * 2 pi, j = 0, ..., n_phi.

    Both endpoints (phi = 0 and phi = 2 pi) are included.  The factors are
    computed once per sup estimate instead of once per grid point.
    """
    return [np.exp(1j * ((j / n_phi) * (2 * np.pi))) for j in range(n_phi + 1)]


def _sup_f_superposition(m, alpha, n_theta=40, n_phi=30):
    """Grid estimate of sup |F| on |z_0|^2 + |z_1|^2 = 1 for
    F = (z_0 z_1)^m + alpha (z_0 z_1)^(m+1).

    Samples z_0 = cos(theta) e^{i phi_0} and z_1 = sin(theta) e^{i phi_1}, with
    theta on an (n_theta + 1)-point grid over [0, pi/2] and phi_0 and phi_1 on
    (n_phi + 1)-point grids over [0, 2 pi].  Because it is a grid maximum,
    the result is a lower bound on the true supremum.

    NOTE(review): |F| depends only on r, s and the phase of z_0 z_1, so the two
    phase loops could be collapsed into one.  They are kept as-is so that the
    sampled values, and therefore the printed results, do not change.
    """
    phases = _phase_factors(n_phi)
    best = 0.0
    for i in range(n_theta + 1):
        theta = (i / n_theta) * (np.pi / 2)
        r = np.cos(theta)
        s = np.sin(theta)
        for p0 in phases:
            z0 = r * p0
            for p1 in phases:
                z1 = s * p1
                w = z0 * z1
                val = abs(w**m + alpha * w**(m + 1))
                if val > best:
                    best = val
    return best


def _sup_m_superposition(m, alpha, n_theta=30, n_phi=24):
    """Grid estimate of the sup, over the same sphere grid as
    _sup_f_superposition, of the op-norm expression used for the Dirac
    commutator ratio.

    At each point it evaluates sqrt(naive + sqrt(X^2 + Y^2)), where
        naive = (|J_+F|^2 + |J_-F|^2) / 2 + |J_zF|^2,
        X     = (|J_+F|^2 - |J_-F|^2) / 2,
        Y     = 2 Im(conj(J_zF) * J_-F),
    for F = (z_0 z_1)^m + alpha (z_0 z_1)^(m+1).  J_zF is identically zero
    for this symmetric F.
    """
    phases = _phase_factors(n_phi)
    jz_val = 0.0 + 0.0j  # J_z F = 0 (F is symmetric in (p, q))
    best = 0.0
    for i in range(n_theta + 1):
        theta = (i / n_theta) * (np.pi / 2)
        r = np.cos(theta)
        s = np.sin(theta)
        for p0 in phases:
            z0 = r * p0
            for p1 in phases:
                z1 = s * p1
                # F = z_0^m z_1^m + alpha z_0^{m+1} z_1^{m+1}
                # J_+ F = m z_0^{m+1} z_1^{m-1} + alpha (m+1) z_0^{m+2} z_1^m
                # J_- F = m z_0^{m-1} z_1^{m+1} + alpha (m+1) z_0^m z_1^{m+2}
                if m >= 1:
                    jp = m * z0**(m+1) * z1**(m-1) + alpha * (m+1) * z0**(m+2) * z1**m
                    jm = m * z0**(m-1) * z1**(m+1) + alpha * (m+1) * z0**m * z1**(m+2)
                else:
                    # For m = 0 the leading terms carry coefficient m = 0.
                    # Dropping them also avoids evaluating z**(-1), which is
                    # singular at z = 0.
                    jp = alpha * (m+1) * z0**(m+2) * z1**m
                    jm = alpha * (m+1) * z0**m * z1**(m+2)
                naive = 0.5 * (abs(jp)**2 + abs(jm)**2) + abs(jz_val)**2
                X = 0.5 * (abs(jp)**2 - abs(jm)**2)
                Y_imag = (jz_val.conjugate() * jm).imag
                Y_sq = 4.0 * Y_imag**2
                op2 = naive + math.sqrt(X**2 + Y_sq)
                val = math.sqrt(op2)
                if val > best:
                    best = val
    return best


def _print_analytical_table():
    """Print the closed-form ratio(m) for m = 1..12."""
    print("  Symmetric bridge: chi_S -> T_{(z_0 z_1)^{m(k)}}")
    print("  Analytical closed-form ratio at integer m:")
    print(f"    ratio(m) = sup|M|/sup|F| = (m+1) * sqrt((1 - 1/m^2)^{{m-1}})")
    print()
    print(f"  {'m':>4}  {'analytical ratio':>20}")
    for m in range(1, 13):
        r = analytical_ratio_symmetric(m)
        print(f"  {m:>4}  {r:>20.6f}")
    print()


def _print_best_m_table():
    """Print, for k = 1..11, the best integer m(k) and its relative gap."""
    print("  Best integer m(k) for substrate target 2 sqrt(k):")
    print()
    # The last column is the RELATIVE gap |ratio - target| / target, so it is
    # labelled gap_rel, as in the superposition table.
    print(f"  {'k':>3}  {'2 sqrt(k)':>11}  {'best m(k)':>10}  {'ratio at best m':>16}  {'gap_rel':>10}")
    for k in range(1, 12):
        target = 2 * math.sqrt(k)
        m, _ = best_m_for_k(k)
        ratio = analytical_ratio_symmetric(m)
        rel_gap = abs(ratio - target) / target
        print(f"  {k:>3}  {target:>11.4f}  {m:>10}  {ratio:>16.4f}  {rel_gap:>10.4f}")
    print()


def _print_superposition_table():
    """Search (m, alpha) for T_{(z_0 z_1)^m} + alpha T_{(z_0 z_1)^{m+1}} and
    print, for k = 1..6, the candidate whose numerical ratio is closest to
    2 sqrt(k)."""
    print("=" * 90)
    print("  Interpolation: superposition T_{(z_0 z_1)^m} + alpha T_{(z_0 z_1)^{m+1}}")
    print("=" * 90)
    print()
    print("  For each k, find m and alpha so that the superposition ratio = 2 sqrt(k).")
    print("  This gives EXACT closure at each k via two integer-degree symmetric monomials.")
    print()
    # NOTE(review): alpha is searched in steps of 0.1 and both sups use coarse
    # grids, so the closure found is approximate.  The gap_rel column reports
    # the residual.
    #
    # For F = (z_0 z_1)^m + alpha (z_0 z_1)^{m+1}:
    #   At diagonal r = s = 1/sqrt(2): F = (1/2)^m * (1 + alpha/2)
    #   sup |F| includes contribution from non-diagonal points; superposition
    #   makes optimization non-trivial.  Hence the numerical grid search.
    print(f"  {'k':>3}  {'2 sqrt(k)':>11}  {'m':>3}  {'alpha':>8}  {'ratio':>10}  {'gap_rel':>10}")

    # The ratio sup|M| / sup|F| depends only on (m, alpha), not on k, so the
    # (m, alpha) grid is evaluated once and reused for every k.  The scan
    # order (m, then alpha, both ascending) is preserved so strict-< tie
    # breaking selects the same candidate a per-k search would.
    candidates = []
    for m in range(1, 8):
        for alpha_int in range(-20, 21):
            alpha = alpha_int / 10.0
            sF = _sup_f_superposition(m, alpha, n_theta=20, n_phi=12)
            if sF < 1e-9:
                continue
            sM = _sup_m_superposition(m, alpha, n_theta=15, n_phi=10)
            candidates.append((m, alpha, sM / sF))

    for k in range(1, 7):
        target = 2 * math.sqrt(k)
        best_gap = float('inf')
        best_m, best_alpha, best_ratio = 1, 0.0, 0.0
        for m, alpha, ratio in candidates:
            gap = abs(ratio - target)
            if gap < best_gap:
                best_gap, best_m, best_alpha, best_ratio = gap, m, alpha, ratio
        rel_gap = best_gap / target
        print(f"  {k:>3}  {target:>11.4f}  {best_m:>3}  {best_alpha:>8.2f}  {best_ratio:>10.4f}  {rel_gap:>10.4f}")


def main():
    """Print the Computation 45 report.

    The report has three sections: the analytical ratio table, the best
    integer m(k) per Walsh weight, and the two-monomial superposition search.
    """
    print("=" * 90)
    print("  Computation 45  --  Refined symmetric-monomial bridge")
    print("=" * 90)
    print()
    _print_analytical_table()
    _print_best_m_table()
    _print_superposition_table()


if __name__ == "__main__":
    # Print the report only when run as a script; importing the module just
    # defines the functions.
    main()
