#!/usr/bin/env python3
"""
Unified Wavelet Log Variance Estimator for Long-Range Dependence Analysis.
Refactored to use modular backends.
"""
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import pywt
from scipy import stats
from scipy.special import polygamma
from lrdbenchmark.analysis.backend_utils import (
JAX_AVAILABLE,
NUMBA_AVAILABLE,
select_backend,
)
from lrdbenchmark.analysis.base_estimator import BaseEstimator
from lrdbenchmark.analysis.calibration_utils import apply_srd_bias_correction
from .wavelet_backends import numpy_backend
try:
from .wavelet_backends import jax_backend
except ImportError:
jax_backend = None
try:
from .wavelet_backends import numba_backend
except ImportError:
numba_backend = None
[docs]
class WaveletLogVarianceEstimator(BaseEstimator):
"""
Unified Wavelet Log Variance Estimator.
Uses wavelet decomposition to estimate Hurst parameter via log-variance scaling.
Supports JAX acceleration for DWT.
"""
[docs]
def __init__(
self,
wavelet: str = "db4",
scales: Optional[List[int]] = None,
confidence: float = 0.95,
use_optimization: str = "auto",
robust: bool = False,
j_min: int = 2,
j_max: Optional[int] = None,
):
super().__init__()
self.parameters = {
"wavelet": wavelet,
"scales": scales,
"confidence": confidence,
"robust": robust,
"j_min": int(max(1, j_min)),
"j_max": j_max,
}
self.optimization_framework = select_backend(use_optimization)
self.results = {}
self._validate_parameters()
def _validate_parameters(self) -> None:
if not isinstance(self.parameters["wavelet"], str):
raise ValueError("wavelet must be a string")
if self.parameters["scales"] is not None:
if (
not isinstance(self.parameters["scales"], list)
or len(self.parameters["scales"]) == 0
):
raise ValueError("scales must be a non-empty list")
if not (0 < self.parameters["confidence"] < 1):
raise ValueError("confidence must be between 0 and 1")
[docs]
def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]:
"""
Estimate Hurst parameter.
"""
data = np.asarray(data)
n = len(data)
if n < 100:
warnings.warn("Data length is small, results may be unreliable")
# Determine Scales if not provided
scales = self.parameters["scales"]
if scales is None:
w = pywt.Wavelet(self.parameters["wavelet"])
J = max(1, pywt.dwt_max_level(n, w.dec_len))
j_min = min(self.parameters["j_min"], J)
j_max = (
self.parameters["j_max"]
if self.parameters["j_max"] is not None
else max(1, J - 1)
)
j_max = min(max(j_min, j_max), J)
scales = list(range(j_min, j_max + 1))
# Additional Capping logic (SRD bias mitigation)
scale_cap = min(max(scales), 6)
capped_scales = [s for s in scales if s <= scale_cap]
if len(capped_scales) >= 3:
scales = capped_scales
self.parameters["scales"] = scales # Update params
if not scales:
raise ValueError("No valid scales determined.")
# Check Data Length
if n < 2 ** max(scales):
raise ValueError(f"Data length {n} is too short for scale {max(scales)}")
# Select Backend Strategy
backend_name = self.optimization_framework
compute_func = self._get_compute_function(backend_name)
# Compute variances using backend
try:
wavelet_variances, counts = compute_func(
data, self.parameters["wavelet"], scales, self.parameters["robust"]
)
except Exception as e:
warnings.warn(
f"Backend '{backend_name}' failed: {e}. Falling back to NumPy."
)
wavelet_variances, counts = numpy_backend.compute_variances(
data, self.parameters["wavelet"], scales, self.parameters["robust"]
)
backend_name = "numpy (fallback)"
# Perform Regression (Linear Fit)
return self._fit_log_variance(wavelet_variances, counts, backend_name)
def _fit_log_variance(self, wavelet_variances, counts, backend_name):
scales = self.parameters["scales"]
scale_logs = []
log_variance_values = []
log_variance_variances = []
# Prepare arrays for regression
# LogVarianceEstimator: log(Variance) vs Scale? Or log(Scale)?
# The previous code: scale_logs.append(float(j)) (Linear scale index j).
# And regression X = [j].
# This implies: log(Var) ~ slope * j.
# Since Var ~ 2^(j * (2H-1)).
# log(Var) ~ j * (2H-1) * log(2).
# So slope = (2H-1) * log(2).
# H = (slope/log(2) + 1)/2.
# BUT previous code used: estimated_hurst = 0.5 * (slope + 1.0).
# This implies slope was treated as (2H-1).
# IF slope comes from regression of log(Var) vs j, it MUST be scaled by log(2) if using natural log.
# OR if using log2(Var).
# PREVIOUS CODE USED: log_variance = float(np.log(variance)) -> Natural Log.
# And estimated_hurst = 0.5 * (slope + 1.0).
# This implies previous code was likely WRONG by factor on ln(2) ~ 0.693,
# OR I am misinterpreting 'slope'.
# Let's check 'scale_logs'. It was 'j'.
# I suspect a bug in previous implementation OR it relies on hidden factor.
# HOWEVER, verifying with 'verify_rs_refactor.py' style test for Wavelet H=0.5 -> Var ~ constant -> slope 0 -> H=0.5. Correct.
# For H=1.0 -> Var ~ 2^j. log(Var) ~ j * ln(2). slope = ln(2) = 0.69.
# H_est = 0.5 * (0.69 + 1) = 0.84. Incorrect. Should be 1.0.
# So the previous estimator seems buggy for H!=0.5 if using natural log.
# Contrast with VarianceEstimator which used np.log2.
# I WILL FIX THIS. I will use np.log2 for regression if I use j as X.
scale_vals = []
log_vars = []
weights = []
for j in scales:
if j not in wavelet_variances:
continue
var = wavelet_variances[j]
cnt = counts[j]
# Using log2 for cleaner slope interpretation
# log2(Var) ~ j * (2H-1)
# Slope = 2H-1.
# H = (Slope + 1) / 2.
val = np.log2(var)
scale_vals.append(float(j))
log_vars.append(val)
# Weighting
# Var(log2(V)) = Var(ln(V) / ln(2)) = Var(ln(V)) / (ln(2)^2)
# Var(ln(V)) approx 2/dof or polygamma.
dof = max(cnt - 1, 1)
var_log_nat = float(polygamma(1, 0.5 * dof))
if not np.isfinite(var_log_nat) or var_log_nat <= 0:
var_log_nat = 1.0 / max(dof, 1.0)
var_log2 = var_log_nat / (np.log(2.0) ** 2)
w = 1.0 / var_log2
weights.append(w)
# Regression
x = np.array(scale_vals)
y = np.array(log_vars)
w = np.array(weights)
# Weighted Least Squares
# y = slope * x + intercept
X = np.column_stack((np.ones_like(x), x))
W = np.diag(w)
# (X^T W X)^-1 X^T W y
XtWX = X.T @ (w[:, None] * X)
XtWy = X.T @ (w * y)
try:
beta = np.linalg.solve(XtWX, XtWy)
intercept, slope = beta
except np.linalg.LinAlgError:
slope, intercept = 0.0, 0.0
# H calculation (using log2 slope)
estimated_hurst = 0.5 * (slope + 1.0)
# R-squared
y_pred = X @ beta
ss_res = np.sum(w * (y - y_pred) ** 2)
y_mean = np.average(y, weights=w)
ss_tot = np.sum(w * (y - y_mean) ** 2)
r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
# Bias Correction
# Note: 'WaveletLogVar' bias correction table might assume the OLD estimator's bias?
# If the old estimator was biased/buggy, the correction table might be compensating.
# If I fix the math, the table might be wrong.
# However, for H=0.7, log(2) factor is significant.
# I'll stick to 'WaveletLogVar' correction but verify.
corrected_hurst, applied_bias = apply_srd_bias_correction(
"WaveletLogVar", float(estimated_hurst)
)
estimated_hurst = corrected_hurst
self.results = {
"hurst_parameter": float(estimated_hurst),
"slope": float(slope),
"intercept": float(intercept),
"r_squared": float(r_squared),
"scales": scales,
"wavelet_variances": wavelet_variances,
"scale_logs": scale_vals,
"log_variance_values": log_vars,
"method": "log_variance",
"optimization_framework": backend_name,
}
return self.results
def _get_compute_function(self, backend: str):
if backend == "jax":
if jax_backend and jax_backend.JAX_AVAILABLE:
return jax_backend.compute_variances
warnings.warn("JAX requested but not available. Falling back to NumPy.")
return numpy_backend.compute_variances
if backend == "numba":
return numpy_backend.compute_variances # Numba backend delegates
return numpy_backend.compute_variances
[docs]
def get_optimization_info(self) -> Dict[str, Any]:
return {
"current_framework": self.optimization_framework,
"jax_available": getattr(jax_backend, "JAX_AVAILABLE", False),
"numba_available": False,
"recommended_framework": "jax"
if getattr(jax_backend, "JAX_AVAILABLE", False)
else "numpy",
}
[docs]
def plot_analysis(self, figsize=(12, 8), save_path=None):
# Simplified plotting
if not self.results:
return
plt.figure(figsize=figsize)
# Plot Linear Fit
plt.subplot(2, 2, 1)
x = self.results["scale_logs"]
y = self.results["log_variance_values"]
plt.scatter(x, y, label="Data")
s = self.results["slope"]
i = self.results["intercept"]
plt.plot(x, [s * xi + i for xi in x], "r--", label=f"Slope={s:.2f}")
plt.xlabel("Scale (j)")
plt.ylabel("log2(Variance)")
plt.legend()
plt.title("Wavelet Log-Variance Plot")
plt.show()