#!/usr/bin/env python3
"""
walsh_frechet_round_trip.py
============================
Directly measures the section round-trip ratio under the Frechet smooth
Lip-norm

    L_F(T) := max_{j=0..K_MAX} ||ad_D^j(T)||_op / j!

(the K_MAX-truncated supremum), and compares it to the round-trip
residual ||T_tail||_op for T's of various Walsh-weight structure.

Goal: verify that L_F gives gamma_D -> 0 at rate exp(-O(sqrt D)) on
the constant-tail stress test.  This is the natural spectral-triple
Lip-norm; if it works, then L_round closes under the standard smooth
subalgebra hypothesis without any artificial weight choice.
"""
import math
import numpy as np
import numpy.linalg as la

sz = np.array([[1, 0], [0, -1]], dtype=complex)
sx = np.array([[0, 1], [1, 0]], dtype=complex)
I2 = np.eye(2, dtype=complex)


def kron_chain(ops):
    """Return the left-to-right Kronecker product of the matrices in ``ops``.

    ``ops`` must be an indexable, sliceable, non-empty sequence.  With a
    single element that element itself is returned (no copy is made).
    """
    product = ops[0]
    remaining = ops[1:]
    for factor in remaining:
        # Each new factor is attached on the right of the running product.
        product = np.kron(product, factor)
    return product


def chi_clifford(D, a):
    """Return the ``a``-th Clifford-type generator on ``D`` two-level sites.

    The factor list is ``a`` copies of ``sz``, then one ``sx``, then
    ``D - 1 - a`` copies of ``I2``; the result is their Kronecker product.
    """
    factors = [sz] * a
    factors.append(sx)
    factors.extend([I2] * (D - 1 - a))
    return kron_chain(factors)


def chi_walsh(D, S):
    """Return the Walsh operator chi_S on ``D`` sites.

    Site ``a`` carries ``sz`` when ``a in S`` and ``I2`` otherwise; the
    factors are combined with a Kronecker product in site order 0..D-1.
    """
    factors = []
    for site in range(D):
        if site in S:
            factors.append(sz)
        else:
            factors.append(I2)
    return kron_chain(factors)


def dirac(D):
    """Return the sum of ``chi_clifford(D, a)`` over ``a = 0..D-1``.

    The result is a ``2^D x 2^D`` complex matrix; for ``D == 0`` it is the
    ``1 x 1`` zero matrix.
    """
    dim = 1 << D
    zero = np.zeros((dim, dim), dtype=complex)
    # ``sum`` with an explicit start adds the generators left to right,
    # starting from the zero matrix.
    return sum((chi_clifford(D, site) for site in range(D)), zero)


def op_norm(M):
    """Return the matrix 2-norm of ``M`` as a plain Python float."""
    spectral = la.norm(M, ord=2)
    return float(spectral)


def L_frechet(D_op, T, K_max):
    """Return max over j = 0..K_max of ||ad_D^j(T)||_op / j!.

    ``ad_D(X) = D_op @ X - X @ D_op`` is applied repeatedly, starting from
    ``T`` itself for j = 0.  For ``K_max < 1`` only the j = 0 term is used
    and the result is ``op_norm(T)``.
    """
    commutator = T
    factorial = 1.0          # running j!, kept as a float
    running_max = op_norm(T)  # j = 0 term
    for order in range(1, K_max + 1):
        commutator = D_op @ commutator - commutator @ D_op
        factorial *= order
        running_max = max(running_max, op_norm(commutator) / factorial)
    return running_max


def constant_tail_T(D, k_D):
    """Build T = sum_{|S|>k_D} chi_S as a 2^D x 2^D matrix.

    Every chi_S is a Kronecker product of ``sz`` and ``I2`` factors, so it
    is diagonal, and its entry at basis index x is (-1)^m where m counts the
    elements of S whose bit is set in x.  Summing over all S of a fixed size
    k gives a value that depends only on w = popcount(x):

        K_k(w) = sum_{j=0..min(k,w)} (-1)^j C(w, j) C(D - w, k - j)

    (choose j elements of S among the w set bits and k - j among the rest).
    The diagonal of T is therefore sum_{k>k_D} K_k(popcount(x)), which is
    computed directly instead of materialising one dense Kronecker product
    per subset.

    Parameters
    ----------
    D : int
        Number of two-level sites; the matrix has size 2^D.
    k_D : int
        Weight cutoff; only subsets with more than ``k_D`` elements are
        summed.  Values below -1 behave like -1 (all subsets included).

    Returns
    -------
    numpy.ndarray
        Complex diagonal matrix of shape (2^D, 2^D).
    """
    N = 1 << D
    # Subset sizes live in 0..D, so a cutoff below -1 excludes nothing more.
    k_min = max(k_D + 1, 0)

    # Tail sum for each possible popcount w = 0..D (exact integer arithmetic).
    tail_by_weight = np.array(
        [
            sum(
                (-1) ** j * math.comb(w, j) * math.comb(D - w, k - j)
                for k in range(k_min, D + 1)
                for j in range(min(k, w) + 1)
            )
            for w in range(D + 1)
        ],
        dtype=float,
    )

    # popcount of every basis index 0..N-1, accumulated one bit at a time.
    idx = np.arange(N)
    weight = np.zeros(N, dtype=np.int64)
    for a in range(D):
        weight += (idx >> a) & 1

    return np.diag(tail_by_weight[weight].astype(complex))


def tail_op_norm_from_coefs_diag(D, k_D):
    """Return sum_{k=k_D+1..D} C(D, k), the number of tail subsets.

    Each chi_S is diagonal with entries +/-1, so the constant-tail sum is
    diagonal as well.  At x = 0 (all-zero bits) every chi_S contributes +1,
    giving this count; it is used as the expected operator norm of the
    constant tail.
    """
    tail_count = 0
    for size in range(k_D + 1, D + 1):
        tail_count += math.comb(D, size)
    return tail_count


# ---------- Sweep ----------
# For each D the constant tail T = sum_{|S| > floor(sqrt(D))} chi_S is built,
# its operator norm is cross-checked against the closed-form count, and the
# ratio ||T|| / L_F(T) is tabulated next to exp(-2 sqrt(D)) for reference.
K_MAX = 12   # Frechet truncation; takes 0!..12! into account
print("=" * 100)
print("  walsh_frechet_round_trip.py  --  Frechet smooth Lip-norm vs L_round residual")
print("=" * 100)
print()
print(f"  Frechet Lip-norm:  L_F(T) = max_{{j=0..{K_MAX}}} ||ad_D^j(T)||_op / j!")
print()
print(f"  {'D':>3} {'k_D':>4} {'||T_tail||':>14} {'L_F(T_tail)':>14} {'ratio':>12} {'2sqrt(D)':>10} {'exp(-2sqrt(D))':>14}")
print(f"  {'-'*3} {'-'*4} {'-'*14} {'-'*14} {'-'*12} {'-'*10} {'-'*14}")

for D in (4, 6, 8, 10, 12):
    k_D = int(math.floor(math.sqrt(D)))
    D_op = dirac(D)
    T = constant_tail_T(D, k_D)
    norm_op = op_norm(T)
    # Sanity check: the numerically computed norm should match the
    # closed-form tail count up to relative rounding error.
    expected_op = tail_op_norm_from_coefs_diag(D, k_D)
    if abs(norm_op - expected_op) > 1e-9 * expected_op:
        print(f"  WARN: op norm {norm_op} != expected {expected_op}")
    lf = L_frechet(D_op, T, K_MAX)
    ratio = norm_op / lf if lf > 0 else float('nan')
    sd2 = 2 * math.sqrt(D)
    expsd2 = math.exp(-sd2)
    print(f"  {D:>3} {k_D:>4} {norm_op:>14.4e} {lf:>14.4e} {ratio:>12.4e} {sd2:>10.4f} {expsd2:>14.4e}")

print()
print("=" * 100)
print("  Notes")
print("=" * 100)
print()
print("  If the Frechet smooth Lip-norm closes L_round, the ratio in column 'ratio' should")
print("  decay roughly as exp(-2 sqrt D) (last column, for reference).")
print()
# This line needs the f-prefix so the actual truncation order is printed
# rather than the literal text "{K_MAX}".
print(f"  If the ratio decays SLOWER, then the Frechet Lip-norm with K_MAX = {K_MAX} truncation")
print("  is too weak; either increase K_MAX or use a stronger Lip-norm (such as the exponential")
print("  L_sup_e_c with c > log 2 found in walsh_round_trip2.py).")
print()
print("  The largest D here is 12 (4096 x 4096 matrices); higher D requires diagonal-only")
print("  arithmetic, which is harder for L_F because ad_D mixes the diagonal Walsh basis with")
print("  the non-diagonal Clifford generators.")
