#!/usr/bin/env python3
"""
Computation 63 -- Lemma 2 reduction: off-diagonal block of T_tail in D_sub eigenbasis
======================================================================================
Per research/dirac_proof.md section 2.2, Lemma 2 of the L_round closure
requires a lower bound on L_F(T_tail) where T_tail = sum_{|S|>k_D} chi_S.

Computation 62 established the rigorous closure for SINGLE Walsh modes:
    ||ad_D^j(chi_S)|| <= (2 sqrt(D))^j  via  D_sub^2 = D * I.

A sharper observation underlies the proof strategy for the tail sum:

  Because D_sub^2 = D * I, the operator D_sub has spectrum {+sqrt(D), -sqrt(D)}
  each with multiplicity 2^(D-1).  Let P_+, P_- be the spectral projectors.
  Every operator T decomposes as
      T = T_++ + T_+- + T_-+ + T_--
  with T_{eps,delta} := P_eps T P_delta.

  Direct calculation:
      ad_{D_sub}(T_{eps,delta}) = (eps - delta) * sqrt(D) * T_{eps,delta}.
  So T_++ and T_-- are KILLED by ad_{D_sub}; T_+- has eigenvalue +2 sqrt(D);
  T_-+ has eigenvalue -2 sqrt(D).

  Hence ad_{D_sub} acts as a diagonal operator with eigenvalues {0, +2sqrt(D), -2sqrt(D)}
  on the four-block decomposition.

  For any operator T:
      ||ad_D^j(T)||_op = (2 sqrt(D))^j * max(||T_+-||_op, ||T_-+||_op),  (j >= 1).
  (In the (+, -) block basis, ad_D^j(T) = (2 sqrt(D))^j * [[0, T_+-]; [(-1)^j T_-+, 0]].
  This operator is block-off-diagonal, and the op-norm of a block-off-diagonal
  operator equals the max of its block op-norms.)

  Therefore L_F(T) = max(||T_+-||, ||T_-+||) * M(D)
  where M(D) := sup_{j>=1} (2 sqrt(D))^j / j!  ~  e^(2 sqrt(D)) / sqrt(2 pi * 2 sqrt(D)).

This script:
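  (1) Forms P_+, P_- for D = 4, 5, 6, 7, 8 in closed form, P_+/- = (I +/- D_sub / sqrt(D)) / 2
      (valid because D_sub^2 = D * I; no numerical diagonalization is needed)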
  (2) Computes the off-diagonal blocks (T_tail)_+- and (T_tail)_-+ explicitly
  (3) Verifies the identity ||ad_D^j(T_tail)|| = (2 sqrt(D))^j * max(||T_+-||, ||T_-+||)
  (4) Reduces Lemma 2 to: bound max(||T_+-||, ||T_-+||) for T_tail explicitly

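The KEY new analytical question (Lemma 2-prime):
  Since gamma_D := ||T_tail||_op / L_F(T_tail) and L_F(T_tail) = max(||T_+-||, ||T_-+||) * M(D),
  a bound gamma_D <= C * e^{-c_1 D} is equivalent to the LOWER bound
      max(||(T_tail)_+-||_op, ||(T_tail)_-+||_op)  >=  ||T_tail||_op * e^{c_1 D} / (C * M(D))
  for some explicit C, c_1 > 0.  Caveat: P_+ and P_- are orthogonal projectors, so the left
  side is at most ||T_tail||_op.  Hence gamma_D >= 1/M(D) >= e^{-2 sqrt(D)}, and this route
  can give at best sub-exponential decay of gamma_D, not O(e^{-c_1 D}).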
"""
import math
import numpy as np
import numpy.linalg as la
from itertools import combinations


# Single-qubit building blocks, all complex128 2x2 matrices:
# the identity, Pauli X (bit flip) and Pauli Z (phase flip).
I2 = np.eye(2, dtype=complex)
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sz = np.diag(np.array([1, -1], dtype=complex))


def kron_chain(ops):
    """Return the left-folded Kronecker product ((ops[0] (x) ops[1]) (x) ops[2]) (x) ...

    ``ops`` must be a non-empty, sliceable sequence of 2-D arrays.  A one-element
    sequence returns its single entry unchanged (same object, no copy).
    """
    product = ops[0]
    for factor in ops[1:]:
        product = np.kron(product, factor)
    return product


def chi_S(D, S):
    """Return the Walsh operator chi_S on D qubits as a dense (2^D, 2^D) matrix.

    Site a carries Pauli Z when a is a member of S and the 2x2 identity otherwise;
    the per-site factors are combined with kron_chain in site order 0..D-1.
    """
    per_site = [(sz if site in S else I2) for site in range(D)]
    return kron_chain(per_site)


def chi_a_Cliff(D, a):
    """Return the Jordan-Wigner-type generator Z^{(x)a} (x) X (x) I^{(x)(D-1-a)}.

    Sites 0..a-1 carry Pauli Z, site a carries Pauli X, and the remaining
    D-1-a sites carry the identity.
    """
    factors = [sz] * a
    factors.append(sx)
    factors.extend([I2] * (D - 1 - a))
    return kron_chain(factors)


def D_sub_matrix(D):
    """Return D_sub = sum_{a=0}^{D-1} chi_a^Cliff as a dense matrix.

    The terms are accumulated left to right, starting from chi_0^Cliff.
    """
    first = chi_a_Cliff(D, 0)
    remaining = (chi_a_Cliff(D, a) for a in range(1, D))
    return sum(remaining, first)


def op_norm(M):
    """Return the spectral (operator 2-) norm of M, i.e. its largest singular value, as a float."""
    spectral = la.norm(M, ord=2)
    return float(spectral)


def build_T_tail(D, kD):
    """T_tail = sum_{|S| > kD} chi_S.

    Sums chi_S over every subset S of {0, ..., D-1} of size kD+1 through D,
    returning a dense complex (2^D, 2^D) matrix (all zeros when kD >= D).
    """
    side = 1 << D
    tail = np.zeros((side, side), dtype=complex)
    high_order_subsets = (
        subset
        for order in range(kD + 1, D + 1)
        for subset in combinations(range(D), order)
    )
    for subset in high_order_subsets:
        tail += chi_S(D, frozenset(subset))
    return tail


def spectral_projectors(D_mat, D):
    """Return (P_+, P_-), the projectors onto the +sqrt(D) and -sqrt(D) eigenspaces of D_sub.

    Uses the closed form P_{+/-} = (I +/- D_sub / sqrt(D)) / 2, which relies on
    D_sub^2 = D I (so D_sub / sqrt(D) squares to the identity).  D_mat is assumed square.
    """
    identity = np.eye(D_mat.shape[0], dtype=complex)
    normalized = D_mat / math.sqrt(D)
    plus = 0.5 * (identity + normalized)
    minus = 0.5 * (identity - normalized)
    return plus, minus


def stirling_M(D, j_max=60):
    """M(D) = sup_{j>=1} (2 sqrt(D))^j / j! .

    The terms t_j = (2 sqrt(D))^j / j! satisfy t_j / t_{j-1} = 2 sqrt(D) / j, so they
    increase while j < 2 sqrt(D) and decrease afterwards.  The sup is therefore
    attained at j = max(1, floor(2 sqrt(D)))  (Stirling: M(D) ~ e^{2 sqrt(D)} / sqrt(4 pi sqrt(D))).

    Args:
        D: dimension parameter (D >= 0).
        j_max: minimum number of terms scanned.  Scanning continues beyond j_max
            until j >= 2 sqrt(D), so the true sup is returned even when the peak
            lies past j_max.  When the peak lies within j_max, exactly
            j = 1..j_max are scanned.

    Returns:
        The sup as a float.  It is built from a running product, so for very
        large D it may overflow to inf.
    """
    base = 2.0 * math.sqrt(D)
    best = 0.0
    term = 1.0
    j = 0
    # Scan at least j_max terms, then keep going until we are past the peak j ~ base.
    while j < j_max or j < base:
        j += 1
        term *= base / j
        if term > best:
            best = term
    return best


def main():
    """Run the Lemma 2 block-decomposition checks for D = 4..8 and print a report.

    For each D (with k_D = floor(sqrt(D))) this:
      * asserts D_sub^2 = D I and the projector identities,
      * splits T_tail into its four blocks P_eps T_tail P_delta,
      * checks ad(T_+-) = 2 sqrt(D) T_+- and ad(T_++) = 0,
      * compares ||ad^j(T_tail)||_op for j = 1..n_check with
        (2 sqrt(D))^j * max(||T_+-||, ||T_-+||),
      * compares the predicted L_F = off_max * M(D) with a direct
        sup_j ||ad^j(T_tail)||_op / j! (j <= j_cap).
    It then tabulates the growth of off_max with D and prints the conclusion.
    The identity is reported as verified only if the worst relative error over
    all D is below verify_tol.
    """
    n_check = 7        # ad^j orders (j = 1..n_check) compared with the closed form
    j_cap = 60         # largest j used in the direct L_F measurement
    verify_tol = 1e-8  # worst relative error still reported as "VERIFIED"

    print("=" * 90)
    print("  Computation 63 -- Lemma 2 reduction via D_sub eigenbasis")
    print("=" * 90)
    print()

    Ds = [4, 5, 6, 7, 8]
    results = []
    worst_match_err = 0.0

    for D in Ds:
        kD = int(math.floor(math.sqrt(D)))
        sqrtD = math.sqrt(D)
        identity = np.eye(1 << D)
        D_mat = D_sub_matrix(D)

        # Verify D_sub^2 = D I
        sq = D_mat @ D_mat
        assert np.allclose(sq, D * identity), f"D_sub^2 != D I at D={D}"

        # Spectral projectors.  Sanity: P_+ + P_- = I, P_+ P_- = 0, P_+ D_sub = sqrt(D) P_+
        Pp, Pm = spectral_projectors(D_mat, D)
        assert np.allclose(Pp + Pm, identity)
        assert np.allclose(Pp @ Pm, 0)
        assert np.allclose(Pp @ D_mat, sqrtD * Pp)

        # T_tail and its four blocks T_{eps,delta} = P_eps T_tail P_delta
        T_tail = build_T_tail(D, kD)
        T_norm = op_norm(T_tail)
        T_pp = Pp @ T_tail @ Pp
        T_pm = Pp @ T_tail @ Pm
        T_mp = Pm @ T_tail @ Pp
        T_mm = Pm @ T_tail @ Pm
        assert np.allclose(T_pp + T_pm + T_mp + T_mm, T_tail)

        T_pp_n = op_norm(T_pp)
        T_pm_n = op_norm(T_pm)
        T_mp_n = op_norm(T_mp)
        T_mm_n = op_norm(T_mm)
        off_max = max(T_pm_n, T_mp_n)

        # ad(T_pm) should equal 2 sqrt(D) T_pm  (relative Frobenius error)
        ad_T_pm = D_mat @ T_pm - T_pm @ D_mat
        rel_err = la.norm(ad_T_pm - 2 * sqrtD * T_pm) / max(la.norm(ad_T_pm), 1e-12)

        # ad(T_pp) should be 0  (absolute Frobenius residual)
        ad_T_pp = D_mat @ T_pp - T_pp @ D_mat
        diag_err = la.norm(ad_T_pp)

        # A single pass over j yields ||ad^j(T_tail)||_op for both the closed-form
        # check (j <= n_check) and the direct L_F = sup_j ||ad^j(T_tail)||_op / j!.
        ad_norms = []
        best_LF = 0.0
        cur = T_tail.copy()
        for j in range(1, j_cap + 1):
            cur = D_mat @ cur - cur @ D_mat
            norm_j = op_norm(cur)
            ad_norms.append(norm_j)
            ratio = norm_j / math.factorial(j)
            best_LF = max(best_LF, ratio)
            # Stop once the ratio is negligible, but only after the check window is filled.
            if j >= n_check and ratio < 1e-20:
                break

        # Predicted: (2 sqrt(D))^j * max(||T_pm||, ||T_mp||); worst relative mismatch over j
        predicted = [(2 * sqrtD) ** j * off_max for j in range(1, n_check + 1)]
        max_match_err = max(
            abs(measured - pred) / max(pred, 1e-12)
            for measured, pred in zip(ad_norms, predicted)
        )
        worst_match_err = max(worst_match_err, max_match_err)

        # Stirling sup and L_F prediction
        M_D = stirling_M(D)
        L_F_pred = off_max * M_D

        # gamma_D = ||T_tail|| / L_F  (inf when L_F vanishes)
        gamma_pred = T_norm / L_F_pred if L_F_pred > 0 else float('inf')
        gamma_direct = T_norm / best_LF if best_LF > 0 else float('inf')
        # off_max <= ||T_tail|| because P_+/- are orthogonal projectors, so gamma_D >= 1/M(D)
        gamma_floor = 1.0 / M_D

        # Empirical fit: 0.37 * e^(-0.28 D)
        gamma_emp = 0.37 * math.exp(-0.28 * D)

        results.append((D, kD, T_norm, T_pm_n, T_mp_n, off_max,
                        gamma_pred, gamma_direct, gamma_emp, max_match_err, diag_err))

        print(f"  D = {D},  k_D = {kD}:")
        print(f"    ||T_tail||_op = {T_norm:.4f}")
        print(f"    ||T_++||_op   = {T_pp_n:.4f}  (diagonal block)")
        print(f"    ||T_--||_op   = {T_mm_n:.4f}  (diagonal block)")
        print(f"    ||T_+-||_op   = {T_pm_n:.4f}  (off-diagonal)")
        print(f"    ||T_-+||_op   = {T_mp_n:.4f}  (off-diagonal)")
        print(f"    max(||T_+-||, ||T_-+||) = {off_max:.4f}")
        print(f"    ratio off/total = {off_max / T_norm:.4f}")
        print()
        print(f"    ad(T_+-) eigenvalue = 2 sqrt(D) = {2 * sqrtD:.4f}, ", end="")
        print(f"rel err = {rel_err:.2e}")
        print(f"    ad(T_++) = 0, residual = {diag_err:.2e}")
        print(f"    ad^j(T_tail) prediction match (max rel err) = {max_match_err:.2e}")
        print()
        print(f"    M(D) = sup (2 sqrt(D))^j / j! = {M_D:.4f}")
        print(f"    L_F prediction = off_max * M(D) = {L_F_pred:.4f}")
        print(f"    L_F directly measured = {best_LF:.4f}")
        print(f"    gamma_D (predicted) = {gamma_pred:.4e}")
        print(f"    gamma_D (direct)    = {gamma_direct:.4e}")
        print(f"    gamma_D (floor 1/M) = {gamma_floor:.4e}")
        print(f"    gamma_D (empirical 0.37 e^(-0.28 D)) = {gamma_emp:.4e}")
        print()

    # Tabulate off-diagonal scaling
    print("=" * 90)
    print("  Off-diagonal norm scaling (Lemma 2 prime)")
    print("=" * 90)
    print()
    print(f"  {'D':>3}  {'||T_tail||':>11}  {'off_max':>10}  {'ratio':>8}  "
          f"{'log_2(off_max)':>14}  {'rate / D':>10}")
    for D, kD, T_norm, T_pm_n, T_mp_n, off_max, *_ in results:
        log2_off = math.log2(off_max) if off_max > 0 else float('-inf')
        rate = log2_off * math.log(2) / D  # natural log per D
        print(f"  {D:>3}  {T_norm:>11.4f}  {off_max:>10.4f}  {off_max / T_norm:>8.4f}  "
              f"{log2_off:>14.4f}  {rate:>10.4f}")

    print()
    print("  The 'rate' column is (ln off_max) / D, which should approach a constant")
    print("  c_off such that off_max ~ exp(c_off * D).  Since gamma_D = ||T_tail|| /")
    print("  (off_max * M(D)), ||T_tail|| ~ 2^D and ln M(D) = 2 sqrt(D) - O(log D),")
    print("  the empirical gamma_D ~ e^{-0.28 D} would require")
    print("    c_off = log(2) + 0.28 - 2/sqrt(D) + O(log(D)/D)  ->  ~0.97.")
    print("  But P_+ and P_- are orthogonal projectors, so off_max <= ||T_tail|| <= 2^D,")
    print("  i.e. c_off <= log(2) ~ 0.69 and gamma_D >= 1/M(D) >= e^{-2 sqrt(D)}:")
    print("  the e^{-0.28 D} fit cannot persist for large D.")
    print()

    print("=" * 90)
    print("  Conclusion -- Lemma 2 strategy")
    print("=" * 90)
    print()
    print("  The block-decomposition identity")
    print("    ||ad_D^j(T_tail)||_op = (2 sqrt(D))^j * max(||T_pm||, ||T_mp||)")
    if worst_match_err < verify_tol:
        print("  is VERIFIED to machine precision for all tested D.  Lemma 2")
    else:
        print(f"  FAILED numerical verification (worst rel err = {worst_match_err:.2e});")
        print("  IF it holds analytically, Lemma 2")
    print("  therefore reduces to the explicit estimation of the off-diagonal")
    print("  block norm max(||(T_tail)_pm||_op, ||(T_tail)_mp||_op).")
    print()
    print("  This is a concrete operator-norm bound on a specific structured operator:")
    print("    T_tail = sum_{|S| > k_D} chi_S   (Walsh-mode tail)")
    print("  and P_+ = (I + D_sub / sqrt(D)) / 2  with  D_sub = sum_a chi_a^Cliff.")
    print()
    print("  This eliminates the open analytical question 'do higher-order")
    print("  commutators interfere?' -- the interference is EXACTLY captured")
    print("  by the off-diagonal block.  The remaining task is closed-form or")
    print("  Fourier-on-Boolean-cube estimation of the off-diagonal block norm.")


# Script entry point: run the full report only when executed directly, not on import.
if __name__ == "__main__":
    main()
