#!/usr/bin/env python3
"""
research_a_round_trip.py  --  self-contained for Pyodide

PST analog of Bhattacharyya-Singla 2022 Lemma 3.2 (the L_round
section-round-trip lemma).
"""

# === INLINED HELPERS ===
#!/usr/bin/env python3
"""
walsh_helpers.py
================
Common helpers for the Walsh-bridge research:

  - fast Walsh-Hadamard transform (WHT) for converting between diagonal
    functions f(x) on the Boolean hypercube {0,1}^D and their Walsh
    expansion coefficients f_S, with O(D * 2^D) cost;
  - enumeration of Walsh modes by weight;
  - diagonal-vector representation of chi_S for any S;
  - operator norm and Hilbert-Schmidt norm via diagonal-only arithmetic
    (valid because all chi_S commute and are diagonal in the computational
    basis, so any polynomial in {chi_S} is also diagonal).
"""

import math
from itertools import combinations
from typing import Dict, Iterable, Tuple

import numpy as np


def walsh_diag(D: int, S: Tuple[int, ...]) -> np.ndarray:
    """Return the diagonal of chi_S on H_D as an int8 vector of length 2^D.

    Entry x is (-1)^popcount(x & mask), where mask is the sum of 1 << a over
    the elements a of S, so every entry is +1 or -1.
    """
    size = 1 << D
    mask = sum(1 << a for a in S)
    # Bits of each basis index x that are selected by S.
    selected = np.arange(size, dtype=np.int64) & mask
    parity = np.zeros(size, dtype=np.int8)
    # XOR the selected bits together, lowest bit first, until none remain.
    while selected.any():
        parity ^= (selected & 1).astype(np.int8)
        selected >>= 1
    # parity 0 -> +1, parity 1 -> -1
    return (1 - 2 * parity).astype(np.int8)


def diag_to_walsh(D: int, diag: np.ndarray) -> np.ndarray:
    """Forward fast Walsh-Hadamard transform: diagonal values -> coefficients.

    The transform runs on a float64 copy; the caller's ``diag`` is never
    modified.

    Input:  length-2^D real diagonal vector f, where f[x] = sum_S c_S chi_S(x).
            Any array-like is accepted.  If extra trailing axes are present,
            the transform is applied along axis 0 for each of them.
    Output: length-2^D vector of coefficients c_S, indexed by S as integer S = sum 2^a.

    Normalization: c_S = (1/2^D) sum_x f(x) chi_S(x).  The division by 2^D is
    applied once, after all butterfly stages.

    Raises:
        ValueError: if the leading dimension of ``diag`` is not 2^D.
    """
    n = 1 << D
    # astype always copies; order="C" guarantees the reshapes below are views
    # into `a`, so writing through them updates `a` itself.
    a = np.asarray(diag).astype(np.float64, order="C")
    if a.ndim == 0 or a.shape[0] != n:
        raise ValueError(
            f"expected leading dimension {n} (= 2^{D}), got shape {a.shape}"
        )
    trailing = a.shape[1:]
    h = 1
    while h < n:
        # One butterfly stage for every block of 2h entries at once:
        # blocks[b, 0] is the first half of block b, blocks[b, 1] the second.
        blocks = a.reshape((n // (2 * h), 2, h) + trailing)
        lower, upper = blocks[:, 0], blocks[:, 1]
        # Both right-hand sides are fresh arrays, built before either write.
        blocks[:, 0], blocks[:, 1] = lower + upper, lower - upper
        h *= 2
    a /= n
    return a


def walsh_to_diag(D: int, coefs: np.ndarray) -> np.ndarray:
    """Inverse WHT (and inverse normalisation): coefficients -> diagonal.

    Evaluates f[x] = sum_S c_S chi_S(x) for all x with the same butterfly
    stages as the forward transform but without the final division by 2^D.
    The transform runs on a float64 copy; ``coefs`` is never modified.

    Input:  length-2^D coefficient vector indexed by mask (any array-like).
            If extra trailing axes are present, the transform is applied
            along axis 0 for each of them.
    Output: length-2^D float64 vector of diagonal values.

    Raises:
        ValueError: if the leading dimension of ``coefs`` is not 2^D.
    """
    n = 1 << D
    # astype always copies; order="C" guarantees the reshapes below are views
    # into `a`, so writing through them updates `a` itself.
    a = np.asarray(coefs).astype(np.float64, order="C")
    if a.ndim == 0 or a.shape[0] != n:
        raise ValueError(
            f"expected leading dimension {n} (= 2^{D}), got shape {a.shape}"
        )
    trailing = a.shape[1:]
    h = 1
    while h < n:
        # One butterfly stage for every block of 2h entries at once:
        # blocks[b, 0] is the first half of block b, blocks[b, 1] the second.
        blocks = a.reshape((n // (2 * h), 2, h) + trailing)
        lower, upper = blocks[:, 0], blocks[:, 1]
        # Both right-hand sides are fresh arrays, built before either write.
        blocks[:, 0], blocks[:, 1] = lower + upper, lower - upper
        h *= 2
    return a


def coef_dict_to_array(D: int, coefs_by_S: Dict[Tuple[int, ...], float]) -> np.ndarray:
    """Build the length-2^D float64 coefficient array from a {S -> c_S} mapping.

    Each key S is turned into the integer index sum(1 << a for a in S);
    positions with no entry in the mapping stay 0.
    """
    packed = np.zeros(1 << D, dtype=np.float64)
    for subset, value in coefs_by_S.items():
        index = 0
        for bit in subset:
            index += 1 << bit
        packed[index] = value
    return packed


def array_to_coef_dict(D: int, arr: np.ndarray, tol: float = 1e-12) -> Dict[Tuple[int, ...], float]:
    """Collect the entries of ``arr`` whose magnitude exceeds ``tol``.

    Scans masks 0 .. 2^D - 1 and returns {S -> float(arr[mask])}, where S is
    the ascending tuple of bit positions set in the mask.
    """
    found: Dict[Tuple[int, ...], float] = {}
    for mask in range(1 << D):
        value = arr[mask]
        if not abs(value) > tol:
            continue
        subset = tuple(a for a in range(D) if mask & (1 << a))
        found[subset] = float(value)
    return found


def weight_of_mask(mask: int) -> int:
    """Return the number of '1' digits in the binary representation of ``mask``."""
    return sum(1 for digit in bin(mask) if digit == "1")


def truncate_walsh(D: int, coefs_arr: np.ndarray, k_max: int) -> np.ndarray:
    """Return a copy of ``coefs_arr`` keeping only Walsh modes with |S| <= k_max.

    Entries whose mask has more than ``k_max`` set bits are set to 0.0; the
    input array is left untouched.
    """
    kept = coefs_arr.copy()
    too_heavy = (m for m in range(1 << D) if weight_of_mask(m) > k_max)
    for mask in too_heavy:
        kept[mask] = 0.0
    return kept


def op_norm_diag(diag: np.ndarray) -> float:
    """Return the largest absolute entry of ``diag`` as a Python float.

    For a diagonal operator this is its operator norm.
    """
    magnitudes = np.abs(diag)
    return float(magnitudes.max())


def hs_norm_diag(diag: np.ndarray) -> float:
    """Return sqrt(sum |entry|^2) of ``diag`` as a Python float.

    For a diagonal operator this is its Hilbert-Schmidt norm.
    """
    squared = np.square(np.abs(diag))
    return float(np.sqrt(squared.sum()))



# === MAIN SCRIPT ===
#!/usr/bin/env python3
"""
walsh_round_trip2.py
=====================
Faster, deeper version of walsh_round_trip.py.

Uses the fast Walsh-Hadamard transform in walsh_helpers.py to compute the
section round-trip residual in O(D * 2^D) time instead of O(2^(2D)).

Adds:
  - Exponential-weight Lip-norm candidates L_exp_c (weight e^{c|S|}),
  - A random-Walsh-coefficient stress test (uniform |c| <= 1, signs random),
  - Sweeps D in {4, 6, 8, 10, 12, 14, 16, 18, 20}.

Outputs:
  - per-D ratios for each Lip-norm candidate, across three test classes,
  - log-log slope of worst-case ratio (gamma_D vs D) per Lip-norm.

Targets:
  - Polynomial Lip-norms (L_sup1, ..., L_sob_2): test whether ANY polynomial
    weight kills the stress-test growth.  Based on walsh_round_trip.py
    partial run: no polynomial weight wins.
  - Exponential Lip-norms (L_supe_c / L_sobe_c for c = 0.5, 1, 2): test whether
    EXPONENTIAL weighting fixes the class-(b) blow-up.

  If exponential weighting works, L_round is provable but requires a
  Lip-norm that grows faster than any polynomial in the Walsh weight.
"""

import math
import sys
import os
import time
from itertools import combinations
import numpy as np



# Hypercube dimensions swept by main(): the even values 4, 6, ..., 20.
D_LIST = list(range(4, 21, 2))


# ---------- Lip-norms acting on a length-2^D coefficient ARRAY (mask-indexed) ----------
def L_sup_poly(coefs_arr: np.ndarray, D: int, alpha: float) -> float:
    """Polynomially weighted sup-norm: max over S of (1+|S|)^alpha * |c_S|.

    Scans masks 0 .. 2^D - 1, skipping zero coefficients; returns 0.0 when
    every coefficient is zero.
    """
    best = 0.0
    for mask in range(1 << D):
        coef = coefs_arr[mask]
        if coef == 0:
            continue
        weighted = (1 + weight_of_mask(mask)) ** alpha * abs(coef)
        best = max(best, weighted)
    return best


def L_sob_poly(coefs_arr: np.ndarray, D: int, p: float) -> float:
    """Polynomially weighted l2 norm: sqrt( sum_S (1+|S|)^{2p} |c_S|^2 ).

    Scans masks 0 .. 2^D - 1 in increasing order, skipping zero coefficients.
    """
    exponent = 2 * p
    acc = 0.0
    for mask in range(1 << D):
        coef = coefs_arr[mask]
        if coef == 0:
            continue
        acc += (1 + weight_of_mask(mask)) ** exponent * coef * coef
    return math.sqrt(acc)


def L_sup_exp(coefs_arr: np.ndarray, D: int, c: float) -> float:
    """Exponentially weighted sup-norm: max over S of e^{c|S|} * |c_S|.

    Scans masks 0 .. 2^D - 1, skipping zero coefficients; returns 0.0 when
    every coefficient is zero.
    """
    best = 0.0
    for mask in range(1 << D):
        coef = coefs_arr[mask]
        if coef == 0:
            continue
        weighted = math.exp(c * weight_of_mask(mask)) * abs(coef)
        best = max(best, weighted)
    return best


def L_sob_exp(coefs_arr: np.ndarray, D: int, c: float) -> float:
    """Exponentially weighted l2 norm: sqrt( sum_S e^{2c|S|} |c_S|^2 ).

    Scans masks 0 .. 2^D - 1 in increasing order, skipping zero coefficients.
    """
    rate = 2 * c
    acc = 0.0
    for mask in range(1 << D):
        coef = coefs_arr[mask]
        if coef == 0:
            continue
        acc += math.exp(rate * weight_of_mask(mask)) * coef * coef
    return math.sqrt(acc)


# ---------- Tail diagonal via WHT ----------
def tail_diag(coefs_arr: np.ndarray, D: int, k_D: int) -> np.ndarray:
    """Return the diagonal of T_tail = sum_{|S|>k_D} c_S chi_S.

    Works on a copy of ``coefs_arr``: every coefficient whose mask has at most
    ``k_D`` set bits is zeroed, then the remainder goes through walsh_to_diag.
    """
    tail_only = coefs_arr.copy()
    low_weight = (m for m in range(1 << D) if weight_of_mask(m) <= k_D)
    for mask in low_weight:
        tail_only[mask] = 0.0
    return walsh_to_diag(D, tail_only)


# ---------- Test classes ----------
def class_a_single_mode(D: int, k: int) -> np.ndarray:
    """Coefficient array of the single mode T = chi_S with S = {0, ..., k-1}.

    The result has length 2^D, with 1.0 at index 2^k - 1 (the mask whose k
    lowest bits are set) and 0.0 elsewhere.
    """
    lowest_k_bits = (1 << k) - 1
    unit = np.zeros(1 << D, dtype=np.float64)
    unit[lowest_k_bits] = 1.0
    return unit


def class_b_constant_tail(D: int, k_D: int) -> np.ndarray:
    """Coefficient array of T = sum_{|S|>k_D} chi_S.

    The result has length 2^D, with 1.0 at every mask that has more than
    ``k_D`` set bits and 0.0 elsewhere.
    """
    indicator = [
        1.0 if weight_of_mask(mask) > k_D else 0.0
        for mask in range(1 << D)
    ]
    return np.array(indicator, dtype=np.float64)


def class_c_random(D: int, k_D: int, rng: np.random.Generator) -> np.ndarray:
    """Draw 2^D Walsh coefficients uniformly from [-1, 1) using ``rng``.

    ``k_D`` is accepted but not used by this function.
    """
    return rng.uniform(low=-1.0, high=1.0, size=1 << D)


# ---------- All Lip-norms in a dict ----------
# Column order for every results table; each label is a key of all_lips().
LIP_LABELS = [
    # sup-type, polynomial weight (1+|S|)^alpha with alpha = 1, 2, 3
    "L_sup1",
    "L_sup2",
    "L_sup3",
    # l2-type, polynomial weight with p = 1, 2
    "L_sob1",
    "L_sob2",
    # sup-type, exponential weight e^{c|S|} with c = 0.5, 1, 2
    "L_supe_05",
    "L_supe_1",
    "L_supe_2",
    # l2-type, exponential weight with c = 0.5, 1, 2
    "L_sobe_05",
    "L_sobe_1",
    "L_sobe_2",
]


def all_lips(coefs_arr: np.ndarray, D: int) -> dict:
    """Evaluate all eleven Lip-norm candidates on one coefficient array.

    Returns a dict mapping each label to the value of the corresponding norm
    function called with (coefs_arr, D, parameter).
    """
    # (label, norm function, weight parameter), evaluated in this order.
    candidates = (
        ("L_sup1", L_sup_poly, 1.0),
        ("L_sup2", L_sup_poly, 2.0),
        ("L_sup3", L_sup_poly, 3.0),
        ("L_sob1", L_sob_poly, 1.0),
        ("L_sob2", L_sob_poly, 2.0),
        ("L_supe_05", L_sup_exp, 0.5),
        ("L_supe_1", L_sup_exp, 1.0),
        ("L_supe_2", L_sup_exp, 2.0),
        ("L_sobe_05", L_sob_exp, 0.5),
        ("L_sobe_1", L_sob_exp, 1.0),
        ("L_sobe_2", L_sob_exp, 2.0),
    )
    return {label: norm(coefs_arr, D, param) for label, norm, param in candidates}


# ---------- Main sweep ----------
def main():
    """Run the sweep over D_LIST and print ratio tables plus log-log slopes.

    For each D (with k_D = floor(sqrt(D))) the ratio
    op_norm(tail diagonal) / Lip-norm is recorded per label for three classes:
    (a) the single modes chi_{0..k-1} for k = 0..D (worst over k),
    (b) the constant tail,
    (c) five random coefficient draws (worst over draws).
    It then prints one table per class, a table of the worst ratio across the
    classes, and the fitted slope of log(worst ratio) against log(D).
    """
    rule = "=" * 100
    classes = ("a", "b", "c")
    col_widths = [3] + [10] * len(LIP_LABELS)
    results = {}

    def banner(title):
        print(rule)
        print(title)
        print(rule)

    def absorb(worst, op_norm, lips):
        # Raise each label's running maximum to op_norm / L when L is positive.
        for label in LIP_LABELS:
            lip = lips[label]
            if lip > 0:
                worst[label] = max(worst[label], op_norm / lip)

    def emit_row(cells):
        print("  " + " ".join(cell.rjust(w) for cell, w in zip(cells, col_widths)))

    def emit_table(value_at):
        emit_row(["D"] + LIP_LABELS)
        print("  " + " ".join("-" * w for w in col_widths))
        for D in D_LIST:
            emit_row([str(D)] + [f"{value_at(D, label):.3e}" for label in LIP_LABELS])

    def overall(D, label):
        # Worst ratio across the three classes for one (D, label) cell.
        return max(results[D][cls][label] for cls in classes)

    banner("  walsh_round_trip2.py  --  empirical L_round via fast WHT")
    print()
    rng = np.random.default_rng(20260531)

    for D in D_LIST:
        k_D = math.floor(math.sqrt(D))
        started = time.time()

        # Class a: single modes; modes with a vanishing tail are skipped.
        worst_a = dict.fromkeys(LIP_LABELS, 0.0)
        for k in range(D + 1):
            coefs = class_a_single_mode(D, k)
            op_n = op_norm_diag(tail_diag(coefs, D, k_D))
            if op_n != 0:
                absorb(worst_a, op_n, all_lips(coefs, D))

        # Class b: constant tail, a single sample per D.
        coefs_b = class_b_constant_tail(D, k_D)
        op_b = op_norm_diag(tail_diag(coefs_b, D, k_D))
        lips_b = all_lips(coefs_b, D)
        worst_b = {}
        for label in LIP_LABELS:
            lip = lips_b[label]
            worst_b[label] = op_b / lip if lip > 0 else 0.0

        # Class c: 5 random samples; track worst-case per Lip-norm.
        worst_c = dict.fromkeys(LIP_LABELS, 0.0)
        for _ in range(5):
            coefs_c = class_c_random(D, k_D, rng)
            op_c = op_norm_diag(tail_diag(coefs_c, D, k_D))
            absorb(worst_c, op_c, all_lips(coefs_c, D))

        results[D] = {"a": worst_a, "b": worst_b, "c": worst_c}

        elapsed = time.time() - started
        print(f"  D = {D:>2}, k_D = {k_D}: {elapsed:.2f} s")

    print()
    banner("  Worst-case ratio per class, per Lip-norm, per D")

    for cls in classes:
        print()
        print(f"  Class ({cls})")
        emit_table(lambda D, label, cls=cls: results[D][cls][label])

    print()
    banner("  Worst-case ratio across ALL classes (a, b, c) per D, per Lip-norm")
    emit_table(overall)

    print()
    banner("  Log-log slope of worst-case ratio vs D (gamma_D ~ D^slope)")
    print()
    print(f"  {'Lip-norm':>10} {'slope':>10}     meaning")
    print(f"  {'-'*10} {'-'*10}    " + "-" * 70)
    log_D = np.log(np.array(D_LIST, dtype=float))
    for label in LIP_LABELS:
        # Floor at 1e-30 so a zero ratio does not produce log(0).
        floored = [max(overall(D, label), 1e-30) for D in D_LIST]
        s, _ = np.polyfit(log_D, np.log(np.array(floored)), 1)
        if s <= -0.45:
            verdict = "OK: gamma_D -> 0 at rate <= D^{-0.45}"
        elif s <= 0:
            verdict = "OK: gamma_D -> 0 but slower than D^{-0.45}"
        else:
            verdict = f"FAIL: gamma_D grows with D (slope +{s:.2f})"
        print(f"  {label:>10} {s:>10.3f}    {verdict}")

    print()

# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
