# Source code for lrdbenchmark.analysis.wavelet.log_variance_estimator

#!/usr/bin/env python3
"""
Unified Wavelet Log Variance Estimator for Long-Range Dependence Analysis.
Refactored to use modular backends.
"""

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pywt
from scipy import stats
from scipy.special import polygamma

from lrdbenchmark.analysis.backend_utils import (
    JAX_AVAILABLE,
    NUMBA_AVAILABLE,
    select_backend,
)
from lrdbenchmark.analysis.base_estimator import BaseEstimator
from lrdbenchmark.analysis.calibration_utils import apply_srd_bias_correction

from .wavelet_backends import numpy_backend

try:
    from .wavelet_backends import jax_backend
except ImportError:
    jax_backend = None

try:
    from .wavelet_backends import numba_backend
except ImportError:
    numba_backend = None


class WaveletLogVarianceEstimator(BaseEstimator):
    """
    Unified Wavelet Log Variance Estimator.

    Uses wavelet decomposition to estimate the Hurst parameter via
    log-variance scaling: ``log2`` of the per-scale wavelet variance is
    regressed on the scale index ``j`` by weighted least squares, and the
    slope is mapped to ``H = (slope + 1) / 2``. Supports JAX acceleration
    for the DWT through pluggable backends.

    Parameters
    ----------
    wavelet : str
        Wavelet name handed to ``pywt.Wavelet`` and to the variance backend.
    scales : list of int, optional
        Scale indices to use. When ``None`` they are chosen per call from
        the data length (see ``_select_scales``).
    confidence : float
        Must lie strictly between 0 and 1; stored in ``parameters``.
    use_optimization : str
        Backend request, resolved through ``select_backend``.
    robust : bool
        Forwarded unchanged to the variance backend.
    j_min : int
        Smallest automatically selected scale (clamped to at least 1).
    j_max : int, optional
        Largest automatically selected scale; ``None`` means "one below the
        maximum decomposition level".
    """

    # Largest scale kept by the automatic selection (SRD bias mitigation).
    _MAX_AUTO_SCALE = 6

    def __init__(
        self,
        wavelet: str = "db4",
        scales: Optional[List[int]] = None,
        confidence: float = 0.95,
        use_optimization: str = "auto",
        robust: bool = False,
        j_min: int = 2,
        j_max: Optional[int] = None,
    ):
        super().__init__()
        self.parameters = {
            "wavelet": wavelet,
            "scales": scales,
            "confidence": confidence,
            "robust": robust,
            "j_min": int(max(1, j_min)),
            "j_max": j_max,
        }
        self.optimization_framework = select_backend(use_optimization)
        self.results = {}
        self._validate_parameters()

    def _validate_parameters(self) -> None:
        """Raise ``ValueError`` if a constructor argument is unusable."""
        if not isinstance(self.parameters["wavelet"], str):
            raise ValueError("wavelet must be a string")

        scales = self.parameters["scales"]
        if scales is not None:
            if not isinstance(scales, list) or len(scales) == 0:
                raise ValueError("scales must be a non-empty list")

        if not (0 < self.parameters["confidence"] < 1):
            raise ValueError("confidence must be between 0 and 1")

    def _select_scales(self, n: int) -> List[int]:
        """Choose the scale indices for a series of length ``n``."""
        w = pywt.Wavelet(self.parameters["wavelet"])
        max_level = max(1, pywt.dwt_max_level(n, w.dec_len))

        j_min = min(self.parameters["j_min"], max_level)
        j_max = self.parameters["j_max"]
        if j_max is None:
            j_max = max(1, max_level - 1)
        j_max = min(max(j_min, j_max), max_level)
        scales = list(range(j_min, j_max + 1))

        # Additional capping (SRD bias mitigation): drop the coarsest scales,
        # but only if at least three scales remain for the regression.
        capped = [s for s in scales if s <= self._MAX_AUTO_SCALE]
        return capped if len(capped) >= 3 else scales

    def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]:
        """
        Estimate the Hurst parameter of ``data``.

        Scales chosen automatically are recomputed on every call and are
        NOT written back to ``self.parameters``; otherwise an instance
        reused on series of different lengths would keep the scales of the
        first series. The scales actually used are reported in the returned
        dictionary under ``"scales"``.

        Returns
        -------
        dict
            The results dictionary produced by ``_fit_log_variance`` (also
            stored in ``self.results``).

        Raises
        ------
        ValueError
            If no scales are available, the series is shorter than
            ``2 ** max(scales)``, or fewer than two scales yield a usable
            variance.
        """
        data = np.asarray(data)
        n = len(data)
        if n < 100:
            warnings.warn("Data length is small, results may be unreliable")

        # NOTE(review): user-supplied scales are used exactly as given; the
        # capping in _select_scales is applied to automatic selection only.
        scales = self.parameters["scales"]
        if scales is None:
            scales = self._select_scales(n)

        if not scales:
            raise ValueError("No valid scales determined.")

        if n < 2 ** max(scales):
            raise ValueError(f"Data length {n} is too short for scale {max(scales)}")

        backend_name = self.optimization_framework
        compute_func = self._get_compute_function(backend_name)

        try:
            wavelet_variances, counts = compute_func(
                data, self.parameters["wavelet"], scales, self.parameters["robust"]
            )
        except Exception as e:
            # Deliberate best-effort: any accelerated-backend failure is
            # reported and retried once with the reference NumPy backend.
            warnings.warn(
                f"Backend '{backend_name}' failed: {e}. Falling back to NumPy."
            )
            wavelet_variances, counts = numpy_backend.compute_variances(
                data, self.parameters["wavelet"], scales, self.parameters["robust"]
            )
            backend_name = "numpy (fallback)"

        return self._fit_log_variance(
            wavelet_variances, counts, backend_name, scales=scales
        )

    def _fit_log_variance(self, wavelet_variances, counts, backend_name, scales=None):
        """
        Fit ``log2(variance)`` against scale index and derive ``H``.

        Model: ``Var_j ~ 2 ** (j * (2H - 1))``, hence
        ``log2(Var_j) = intercept + (2H - 1) * j`` and
        ``H = (slope + 1) / 2``. Base-2 logarithms are required for this
        mapping because the regressor is the scale index ``j`` itself.

        Parameters
        ----------
        wavelet_variances : mapping
            Scale index -> wavelet variance at that scale.
        counts : mapping
            Scale index -> count used to derive the degrees of freedom of
            the variance at that scale.
        backend_name : str
            Label recorded in the results as ``"optimization_framework"``.
        scales : list of int, optional
            Scales to fit. Defaults to ``self.parameters["scales"]`` and,
            if that is ``None`` too, to every key of ``wavelet_variances``.

        Raises
        ------
        ValueError
            If fewer than two scales have a positive, finite variance.
        """
        if scales is None:
            scales = self.parameters["scales"]
        if scales is None:
            scales = sorted(wavelet_variances)

        scale_vals = []
        log_vars = []
        weights = []
        ln2_squared = np.log(2.0) ** 2

        for j in scales:
            if j not in wavelet_variances:
                continue
            var = float(wavelet_variances[j])
            # log2 of a zero, negative or non-finite variance would inject
            # -inf/nan into the regression and corrupt every output.
            if not np.isfinite(var) or var <= 0.0:
                continue

            # Weight = 1 / Var(log2 V). Var(ln V) is taken as
            # trigamma(dof / 2); dividing by ln(2)^2 converts to base 2.
            dof = max(counts[j] - 1, 1)
            var_log_nat = float(polygamma(1, 0.5 * dof))
            if not np.isfinite(var_log_nat) or var_log_nat <= 0:
                var_log_nat = 1.0 / max(dof, 1.0)
            var_log2 = var_log_nat / ln2_squared

            scale_vals.append(float(j))
            log_vars.append(float(np.log2(var)))
            weights.append(1.0 / var_log2)

        if len(scale_vals) < 2:
            raise ValueError(
                "Need at least two scales with positive, finite variance "
                f"to fit a slope; got {len(scale_vals)}."
            )

        x = np.array(scale_vals)
        y = np.array(log_vars)
        w = np.array(weights)

        # Weighted least squares for y = intercept + slope * x, solved via
        # the normal equations (X^T W X) beta = X^T W y without building
        # the dense diagonal weight matrix.
        X = np.column_stack((np.ones_like(x), x))
        XtWX = X.T @ (w[:, None] * X)
        XtWy = X.T @ (w * y)
        try:
            beta = np.linalg.solve(XtWX, XtWy)
        except np.linalg.LinAlgError:
            # Singular system (e.g. every retained scale is identical):
            # fall back to a flat fit so the code below still has a
            # coefficient vector to work with.
            warnings.warn(
                "Singular weighted regression; using slope=0 and intercept=0."
            )
            beta = np.zeros(2)
        intercept, slope = beta

        estimated_hurst = 0.5 * (slope + 1.0)

        # Weighted R-squared of the fit.
        y_pred = X @ beta
        ss_res = np.sum(w * (y - y_pred) ** 2)
        y_mean = np.average(y, weights=w)
        ss_tot = np.sum(w * (y - y_mean) ** 2)
        r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0

        # NOTE(review): the "WaveletLogVar" correction table may have been
        # calibrated against an earlier natural-log version of this
        # estimator; confirm it is still appropriate for the log2 fit.
        corrected_hurst, _applied_bias = apply_srd_bias_correction(
            "WaveletLogVar", float(estimated_hurst)
        )
        estimated_hurst = corrected_hurst

        self.results = {
            "hurst_parameter": float(estimated_hurst),
            "slope": float(slope),
            "intercept": float(intercept),
            "r_squared": float(r_squared),
            "scales": scales,
            "wavelet_variances": wavelet_variances,
            "scale_logs": scale_vals,
            "log_variance_values": log_vars,
            "method": "log_variance",
            "optimization_framework": backend_name,
        }
        return self.results

    def _get_compute_function(self, backend: str):
        """Return the ``compute_variances`` callable for ``backend``.

        Only ``"jax"`` has a dedicated implementation; every other value,
        including ``"numba"``, resolves to the NumPy backend.
        """
        if backend == "jax":
            if jax_backend is not None and jax_backend.JAX_AVAILABLE:
                return jax_backend.compute_variances
            warnings.warn("JAX requested but not available. Falling back to NumPy.")
        return numpy_backend.compute_variances

    def get_optimization_info(self) -> Dict[str, Any]:
        """Describe the selected backend and which backends are usable."""
        # getattr tolerates jax_backend being None after a failed import.
        jax_ok = getattr(jax_backend, "JAX_AVAILABLE", False)
        return {
            "current_framework": self.optimization_framework,
            "jax_available": jax_ok,
            "numba_available": False,
            "recommended_framework": "jax" if jax_ok else "numpy",
        }

    def plot_analysis(self, figsize=(12, 8), save_path=None):
        """
        Plot ``log2(variance)`` against scale with the fitted line.

        Does nothing when no estimate has been computed yet. When
        ``save_path`` is given the figure is written there before it is
        shown.
        """
        if not self.results:
            return

        plt.figure(figsize=figsize)

        plt.subplot(2, 2, 1)
        x = self.results["scale_logs"]
        y = self.results["log_variance_values"]
        plt.scatter(x, y, label="Data")

        s = self.results["slope"]
        i = self.results["intercept"]
        plt.plot(x, [s * xi + i for xi in x], "r--", label=f"Slope={s:.2f}")

        plt.xlabel("Scale (j)")
        plt.ylabel("log2(Variance)")
        plt.legend()
        plt.title("Wavelet Log-Variance Plot")

        # Save before show(): the figure may no longer be available once
        # the interactive window has been closed.
        if save_path is not None:
            plt.savefig(save_path, bbox_inches="tight")
        plt.show()