#!/usr/bin/env python3
"""
Computation 47 -- Option C: 3-monomial bridge for odd-k closure
================================================================
Computation 46 showed that the 2-monomial bridge
    f(z) = (z_0 z_1)^m + alpha (z_0 z_1)^{m+1}
gives essentially exact L_comm closure at k = 1, 2, 4, 6, 8 (even k
and 1), but a residual gap at odd k = 3, 5, 7.

This script tests a 3-monomial bridge with two free parameters:
    f(z) = (z_0 z_1)^{m-1} + beta (z_0 z_1)^m + alpha (z_0 z_1)^{m+1}
to see if exact closure is achievable at odd k.
"""
import math
import numpy as np


def F_sup_3mon(m, alpha, beta, n_theta=30, n_phi=18):
    """Return the largest sampled |f| for the 3-monomial symbol.

    The symbol is f = u^{m-1} + beta u^m + alpha u^{m+1} with u = z_0 z_1.
    It is sampled at z_0 = cos(theta) e^{i phi0}, z_1 = sin(theta) e^{i phi1}.

    Args:
        m: Middle exponent of the symbol.  At theta = 0 the sample has
            z_1 = 0 and hence u = 0, so m is expected to be >= 1 to avoid a
            negative power of zero.
        alpha: Coefficient of u^{m+1}.
        beta: Coefficient of u^m.
        n_theta: Number of theta intervals; theta takes the n_theta + 1
            evenly spaced values from 0 to pi/2 inclusive.
        n_phi: Number of phase intervals; phi0 and phi1 each take the
            n_phi + 1 evenly spaced values from 0 to 2 pi inclusive.

    Returns:
        The maximum of |f| over all sampled (theta, phi0, phi1) triples.
    """
    # The unit phase factors depend only on the phase index, so evaluate
    # np.exp once per index rather than twice per innermost iteration.
    two_pi = 2 * np.pi
    phases = [np.exp(1j * ((idx / n_phi) * two_pi)) for idx in range(n_phi + 1)]

    best = 0.0
    for i in range(n_theta + 1):
        theta = (i / n_theta) * (np.pi / 2)
        r = np.cos(theta)
        s = np.sin(theta)
        # z_1 samples are shared by every z_0 at this theta.
        z1_samples = [s * phase for phase in phases]
        for phase0 in phases:
            z0 = r * phase0
            for z1 in z1_samples:
                u = z0 * z1
                val = abs(u**(m-1) + beta * u**m + alpha * u**(m+1))
                if val > best:
                    best = val
    return best


def M_sup_3mon(m, alpha, beta, n_theta=25, n_phi=14):
    """
    Largest sampled value of the pointwise norm built from J_+ F, J_- F, J_z F.

    F = u^{m-1} + beta u^m + alpha u^{m+1}, u = z_0 z_1.
    F = z_0^{m-1} z_1^{m-1} + beta z_0^m z_1^m + alpha z_0^{m+1} z_1^{m+1}
    J_+ F = z_0 d/dz_1 F
          = (m-1) z_0^m z_1^{m-2} + beta * m * z_0^{m+1} z_1^{m-1}
            + alpha * (m+1) * z_0^{m+2} z_1^m
    J_- F = z_1 d/dz_0 F
          = (m-1) z_0^{m-2} z_1^m + beta * m * z_0^{m-1} z_1^{m+1}
            + alpha * (m+1) * z_0^m z_1^{m+2}
    J_z F = (p - q)/2 * each term, but all terms have p = q, so J_z F = 0.

    Samples are z_0 = cos(theta) e^{i phi0}, z_1 = sin(theta) e^{i phi1} with
    n_theta + 1 values of theta from 0 to pi/2 and n_phi + 1 values of each
    phase from 0 to 2 pi (both endpoints included).

    At every sample the code evaluates
        naive = (|J_+ F|^2 + |J_- F|^2) / 2 + |J_z F|^2
        X     = (|J_+ F|^2 - |J_- F|^2) / 2
        Y^2   = 4 * Im(conj(J_z F) * J_- F)^2
        val   = sqrt(naive + sqrt(X^2 + Y^2))
    and returns the maximum val.  Because J_z F = 0 here, Y^2 vanishes; the
    general expression is kept as written.

    The (m-1)-terms are dropped when m < 2 and the beta-terms when m < 1, so
    no negative power of a vanishing z_0 or z_1 is evaluated in those terms.
    """
    # Unit phase factors depend only on the phase index; evaluate np.exp once
    # per index instead of twice per innermost iteration.
    two_pi = 2 * np.pi
    phases = [np.exp(1j * ((idx / n_phi) * two_pi)) for idx in range(n_phi + 1)]

    # Coefficients and term guards are constant over the whole grid.
    keep_low = m >= 2
    keep_mid = m >= 1
    c_low = m - 1
    c_mid = beta * m
    c_high = alpha * (m + 1)
    jz_val = 0.0 + 0.0j  # J_z F vanishes identically (see docstring)

    best = 0.0
    for i in range(n_theta + 1):
        theta = (i / n_theta) * (np.pi / 2)
        r = np.cos(theta)
        s = np.sin(theta)
        # z_1 samples are shared by every z_0 at this theta.
        z1_samples = [s * phase for phase in phases]
        for phase0 in phases:
            z0 = r * phase0
            for z1 in z1_samples:
                # J_+ F, term by term in increasing degree.
                t1 = c_low * z0**m * z1**(m-2) if keep_low else 0
                t2 = c_mid * z0**(m+1) * z1**(m-1) if keep_mid else 0
                t3 = c_high * z0**(m+2) * z1**m
                jp = t1 + t2 + t3
                # J_- F: same coefficients with the roles of z_0, z_1 swapped.
                s1 = c_low * z0**(m-2) * z1**m if keep_low else 0
                s2 = c_mid * z0**(m-1) * z1**(m+1) if keep_mid else 0
                s3 = c_high * z0**m * z1**(m+2)
                jm = s1 + s2 + s3

                jp_sq = abs(jp)**2
                jm_sq = abs(jm)**2
                naive = 0.5 * (jp_sq + jm_sq) + abs(jz_val)**2
                X = 0.5 * (jp_sq - jm_sq)
                Y_imag = (jz_val.conjugate() * jm).imag
                Y_sq = 4.0 * Y_imag**2
                op2 = naive + math.sqrt(X**2 + Y_sq)
                val = math.sqrt(op2)
                if val > best:
                    best = val
    return best


def find_best_3mon(k, m, alpha_grid, beta_grid):
    """Scan (alpha, beta) pairs for the sup-ratio closest to 2 sqrt(k).

    Every alpha from ``alpha_grid`` is combined with every beta from
    ``beta_grid``.  A pair is skipped when its F-supremum is below 1e-9
    (presumably to avoid dividing by a vanishing supremum).

    Returns:
        ``(gap, m, alpha, beta, ratio, F_sup)`` for the first pair that
        attains the smallest ``gap = |M_sup / F_sup - 2 sqrt(k)|``, or
        ``None`` if no pair produced a gap below infinity.
    """
    goal = 2 * math.sqrt(k)
    winner = None
    smallest_gap = math.inf
    for alpha in alpha_grid:
        for beta in beta_grid:
            sup_f = F_sup_3mon(m, alpha, beta)
            if sup_f < 1e-9:
                continue
            ratio = M_sup_3mon(m, alpha, beta) / sup_f
            distance = abs(ratio - goal)
            # Strict comparison: ties keep the earlier grid point.
            if distance < smallest_gap:
                smallest_gap = distance
                winner = (distance, m, alpha, beta, ratio, sup_f)
    return winner


def main():
    """Print the best 3-monomial bridge parameters for k = 3, 5, 7.

    For each k a coarse stage scans m = 2..6 with alpha and beta each on the
    13-point grid -3.0, -2.5, ..., 3.0 and keeps the candidate whose ratio is
    closest to 2 sqrt(k).  A refinement stage then re-centres a 9 x 9 grid on
    the current best (alpha, beta) at the winning m, using spacings 0.2, 0.05
    and 0.01 in turn.  One coarse row and one refined row are printed per k.
    """
    print("=" * 90)
    print("  Computation 47  --  3-monomial bridge for odd-k closure")
    print("=" * 90)
    print()
    print("  Symbol: f(z) = (z_0 z_1)^{m-1} + beta (z_0 z_1)^m + alpha (z_0 z_1)^{m+1}")
    print()
    print(f"  {'k':>3}  {'2 sqrt(k)':>11}  {'m':>3}  {'beta':>10}  {'alpha':>10}"
          f"  {'ratio':>10}  {'gap':>10}")

    # The coarse grid depends on neither k nor m, so build it once.
    coarse_grid = [-3.0 + 0.5 * i for i in range(13)]
    refine_offsets = range(-4, 5)

    for k in [3, 5, 7]:
        target = 2 * math.sqrt(k)
        # Search across multiple m values
        best_overall = None
        for m_candidate in range(2, 7):
            result = find_best_3mon(k, m_candidate, coarse_grid, coarse_grid)
            if result is not None and (best_overall is None
                                       or result[0] < best_overall[0]):
                best_overall = result
        if best_overall is None:
            # find_best_3mon returns None when it accepts no grid point;
            # report that instead of failing on the tuple unpack below.
            print(f"  {k:>3}  {target:>11.4f}  (no admissible grid point)")
            continue
        gap, m, alpha, beta, ratio, _ = best_overall
        print(f"  {k:>3}  {target:>11.4f}  {m:>3}  {beta:>10.3f}  {alpha:>10.3f}"
              f"  {ratio:>10.5f}  {gap:>10.5f}")
        # Refine around best: each pass is centred on the current optimum,
        # which only moves when a strictly smaller gap is found.
        for scale in [0.2, 0.05, 0.01]:
            alpha_grid = [alpha + scale * d for d in refine_offsets]
            beta_grid = [beta + scale * d for d in refine_offsets]
            result = find_best_3mon(k, m, alpha_grid, beta_grid)
            if result is not None and result[0] < gap:
                gap, _, alpha, beta, ratio, _ = result
        print(f"  {'(refined)':>15}     {m:>3}  {beta:>10.5f}  {alpha:>10.5f}"
              f"  {ratio:>10.5f}  {gap:>10.5f}")


# Run the parameter search only when this file is executed as a script,
# so importing the module does not start the computation.
if __name__ == "__main__":
    main()
