#!/usr/bin/env python3
"""
walsh_bridge_higher_weights.py
================================
Reconnaissance for the rate gamma_D in the PST analog of Bhattacharyya-Singla
2022 Lemma 3.2.

This script measures DISCRETE-SIDE quantities only.  Full closure of the
lemma rate requires the continuum-side bridge image b_D([gamma_a, f]) on
spinor-valued functions on S^3, which in turn needs explicit spinor
derivatives of spherical harmonics on the round S^3.  That continuum
construction is the next deliverable.

What this script DOES measure (cheap, exact at any D up to D~10):

  1. The operator norm ||[chi_a, chi_S]||_op as a function of D and |S|.
     Used to confirm the algebraic claim that the discrete Clifford
     derivation gives 0 or 2 on pure Walsh modes (Pauli-string algebra).

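  2. The Pauli-weight distribution of [chi_a, chi_S].  In the Walsh basis,
     chi_S sits at Walsh-weight |S|.  After commuting with chi_a, the
     result is a NON-Walsh Pauli string (it contains sigma_y at site a).
     The point: the bridge has to absorb this leakage when mapping the
     discrete commutator to spinor-valued spherical-harmonic operators.
     The PAULI WEIGHT (number of non-identity Pauli factors) of
     [chi_a, chi_S] is not simply |S| + 1: in the Jordan-Wigner frame
     chi_a carries the string sigma_z on sites 0..a-1, so for a in S the
     weight is a + |S| - 2 |{b in S : b < a}| (equal to |S| when a = 0,
     and as large as a + |S| <= D in general; the commutator vanishes
     when a is not in S).  Tracking how that compares to the Walsh
     weight |S| of the input tells us at which weight the bridge defect
     first appears.  The script body does not yet tabulate this weight;
     its §2 performs only the overlap check of item 3.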

  3. The Hilbert-Schmidt overlap of [chi_a, chi_S] with each Walsh mode
     chi_T.  By the Pauli-string algebra above, this overlap should be 0
     for all T (the commutator is never a pure Walsh mode), confirming
     that the bridge defect lives entirely in the non-Walsh Pauli
     subspace.

The actual gamma_D rate (Walsh-cube to S^3 operator convergence) requires
the continuum-side construction and is left to a follow-up.

Run:
    python3 walsh_bridge_higher_weights.py
"""
import math
from itertools import combinations
import numpy as np
import numpy.linalg as la


# ---------- Sparse-friendly diagonal Walsh modes + Clifford generators ----------
def walsh_diag(D, S):
    """Return the Walsh mode chi_S as a length-2^D vector of +-1 (it is diagonal).

    Entry x is (-1)^{|x & S|}, i.e. the parity of the bits of the basis index
    x that lie in S.  S is treated as a SET of site indices: a repeated index
    names the same site once (a plain sum of bit masks would instead carry
    into the next bit and silently produce a different mode).

    Args:
        D: number of sites (bits); the returned vector has 2^D entries.
        S: iterable of site indices, each in range(D).

    Returns:
        np.ndarray of dtype int8 whose entries are +1 / -1.

    Raises:
        ValueError: if any index in S lies outside range(D) (such an index
            would otherwise be silently ignored).
    """
    sites = set(S)
    bad = sorted(b for b in sites if not 0 <= b < D)
    if bad:
        raise ValueError(f"Walsh indices {bad} out of range for D={D}")
    x = np.arange(1 << D)
    # Accumulate the parity of the selected bits of every basis index at once.
    parity = np.zeros(1 << D, dtype=np.int64)
    for b in sites:
        parity ^= (x >> b) & 1
    return (1 - 2 * parity).astype(np.int8)


def chi_a_apply(v, a, D):
    """Apply the JW Clifford generator chi_a = c_a + c_a^† to a column vector v.

    For each basis index x in range(2^D), the amplitude v[x] is moved to
    index x XOR 2^a (bit a flipped) and multiplied by the Jordan-Wigner sign
    (-1)^{number of set bits of x among sites 0..a-1}.  The result has the
    same shape and dtype as v.
    """
    flip = 1 << a
    string_mask = flip - 1  # bits strictly below site a
    result = np.zeros_like(v)
    for basis in range(1 << D):
        parity = bin(basis & string_mask).count("1") % 2
        jw_sign = -1 if parity else 1
        result[basis ^ flip] += jw_sign * v[basis]
    return result


def commutator_chi_a_chi_S(a, S, D):
    """Return [chi_a, chi_S] as a dense 2^D x 2^D complex matrix.

    chi_S is diagonal with entries diag = walsh_diag(D, S).  chi_a sends |x>
    to sign_a(x) |y>, where y = x XOR 2^a and
    sign_a(x) = (-1)^{popcount(x & (2^a - 1))} is the Jordan-Wigner string
    sign (the same convention as chi_a_apply).  Therefore

        chi_a chi_S |x> = diag[x] * sign_a(x) |y>
        chi_S chi_a |x> = sign_a(x) * diag[y] |y>

    so column x of the commutator has a single entry, in row y:
    out[y, x] = sign_a(x) * (diag[x] - diag[y]).  Because x and y differ only
    in bit a, that entry is 0 for every x when a is not in S and +-2 for
    every x when a is in S.

    Args:
        a: site index of the Clifford generator, in range(D).
        S: iterable of site indices defining the Walsh mode chi_S.
        D: number of sites.

    Returns:
        np.ndarray of shape (2^D, 2^D), dtype complex.

    Raises:
        ValueError: if a is not in range(D).
    """
    if not 0 <= a < D:
        raise ValueError(f"site a={a} out of range for D={D}")
    N = 1 << D
    diag = walsh_diag(D, S).astype(np.int64)
    x = np.arange(N)
    y = x ^ (1 << a)
    # JW string sign: parity of the bits of x strictly below site a.
    parity = np.zeros(N, dtype=np.int64)
    for b in range(a):
        parity ^= (x >> b) & 1
    sign_a = 1 - 2 * parity
    out = np.zeros((N, N), dtype=complex)
    # One (row y, column x) entry per basis state; diag[x] is just diag.
    out[y, x] = sign_a * (diag - diag[y])
    return out


# ---------- §1: confirm ||[chi_a, chi_S]||_op = 0 (a notin S) or 2 (a in S) ----------
# For each D and weight k, take the representative S = {0, ..., k-1} and
# print the spectral norm of the commutator for one site inside S and (when
# one exists) one site outside S.
print("=" * 78)
print("  walsh_bridge_higher_weights.py")
print("  Reconnaissance for gamma_D rate (PST analog of Bhattacharyya-Singla 3.2)")
print("=" * 78)
print()
print("  §1. Operator norm of [chi_a, chi_S], pure Walsh modes")
print()
print(f"  {'D':>3} {'|S|':>4} {'a in S':>8} {'||[chi_a, chi_S]||':>22}")
print(f"  {'-'*3} {'-'*4} {'-'*8} {'-'*22}")

for D in (4, 6, 8):
    for k in (1, 2, 3, 4):
        if k > D:
            continue
        S_in = list(range(k))                          # e.g. {0, 1, ..., k-1}
        # Site inside S: use the LAST element so that, for k >= 2, chi_a
        # carries a non-trivial Jordan-Wigner string over sites 0..a-1
        # (a = 0 would never exercise the JW sign).
        a_in = S_in[-1]
        norm_in = la.norm(commutator_chi_a_chi_S(a_in, S_in, D), ord=2)
        print(f"  {D:>3} {k:>4} {'in':>8} {norm_in:>22.10f}")
        # Smallest site outside S; None when S already covers every site.
        a_out = next((b for b in range(D) if b not in S_in), None)
        if a_out is not None:
            norm_out = la.norm(commutator_chi_a_chi_S(a_out, S_in, D), ord=2)
            print(f"  {D:>3} {k:>4} {'out':>8} {norm_out:>22.10f}")

print()
print("  Expected: 0 when a not in S, exactly 2 when a in S.")
print("  This is exact at any D and any |S|; the rate gamma_D is NOT in this")
print("  scalar — it is in the bridge image on the continuum side, which")
print("  this script does NOT yet compute.")

# ---------- §2: Walsh-mode (Hilbert-Schmidt) overlap of [chi_a, chi_S] ----------
print()
print("  §2. Walsh-mode overlap of [chi_a, chi_S]  (Hilbert-Schmidt inner product)")
print()
print("  HS-overlap <chi_T | [chi_a, chi_S]> / 2^D is 0 for every Walsh mode chi_T,")
print("  because the commutator contains sigma_y at site a and Walsh modes are")
print("  pure sigma_z products.  We verify numerically at D = 6.")

D = 6
S_in = [0, 2, 4]
a_in = 0
comm = commutator_chi_a_chi_S(a_in, S_in, D)
norm_factor = 1 << D
# Every chi_T is real and diagonal, so <chi_T | comm>_HS = tr(chi_T^dagger comm)
# reduces to sum_x chi_T[x] * comm[x, x]: only the diagonal of comm matters.
comm_diag = np.diag(comm)
max_overlap = max(
    abs(np.dot(walsh_diag(D, T), comm_diag))
    for k in range(D + 1)
    for T in combinations(range(D), k)
)
print(f"    max |<chi_T | comm>_HS| / 2^D over all T  =  {max_overlap / norm_factor:.3e}")
print("    -> 0 to machine precision.  The commutator is perpendicular to the")
print("       entire Walsh subspace; all the bridge content sits in the spinor")
print("       (sigma_y / sigma_x) directions, not in the abelian Walsh sector.")
# comm only flips bit a, so its diagonal (the only part a diagonal chi_T can
# see) vanishes by construction.  Report the HS norm as well, so the zero
# overlap is not mistaken for comm itself being zero.
hs_norm_sq = float(np.vdot(comm, comm).real) / norm_factor
print(f"    ||comm||_HS^2 / 2^D  =  {hs_norm_sq:.6f}   (expected 4 since a in S:")
print("       comm = 2 chi_a chi_S, so the commutator itself is non-zero)")

# ---------- §3: What the script does NOT measure (sets up the next deliverable) ----------
# Text-only section: states the target bound and the missing ingredients.
print()
print("  §3. What this script cannot yet measure")
print()
print("  The lemma rate gamma_D in PST Lemma 3.2 is:")
print()
print("      || T - sigma_D(pi_D(T)) ||_op  <=  gamma_D * L_D(T)")
print()
print("  where pi_D is the bridge to C(S^3) and sigma_D is the section lift.")
print("  Measuring gamma_D from this bound requires:")
print()
print("    (a) An explicit Lip-norm L_D on the discrete CAR algebra (candidate")
print("        Walsh-weight-graded norm given in the companion analysis).")
print("    (b) The continuum-side commutator [gamma_a, f] for f in C^infty(S^3),")
print("        which is the spinor derivative of f.  This requires explicit")
print("        spinor calculus on the round S^3 (Friedrich 1980, Baer 1996).")
print("    (c) The bridge b_D applied to a SPINOR-valued spherical harmonic, not")
print("        just a scalar one.  Computation 9 §7 builds b_D for scalars; the")
print("        spinor extension is a separate construction.")
print()
print("  Once (a)-(c) are in hand, the gamma_D measurement reduces to:")
print()
print("      gamma_D  =  sup_S  || [chi_a, chi_S] - b_D([gamma_a, b_D^{-1}(chi_S)]) ||")
print("                       / L_D(chi_S)")
print()
print("  at sample weights |S| = 1, 2, 3, 4 and D = 4, 6, 8, 10.  That is the next")
print("  numerical deliverable for the closure attempt.")
print()
print("=" * 78)
