Source code for lrdbenchmark.analysis.wavelet.cwt_estimator

#!/usr/bin/env python3
"""
Unified Continuous Wavelet Transform (CWT) Estimator for Long-Range Dependence Analysis.

This module implements the CWT estimator with automatic optimization framework
selection (JAX, Numba, NumPy) for the best performance on the available hardware.
"""

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pywt
from scipy import stats

# Import optimization frameworks
try:
    import jax
    import jax.numpy as jnp
    from jax import vmap

    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False

try:
    import numba
    from numba import jit as numba_jit
    from numba import prange

    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

# Import base estimator (single source of truth)
from lrdbenchmark.analysis.base_estimator import BaseEstimator
from lrdbenchmark.analysis.calibration_utils import apply_srd_bias_correction


class CWTEstimator(BaseEstimator):
    """
    Unified Continuous Wavelet Transform (CWT) Estimator for Long-Range Dependence Analysis.

    This estimator uses continuous wavelet transforms to analyze the scaling
    behavior of time series data and estimate the Hurst parameter for
    fractional processes.

    Features:
    - Automatic optimization framework selection (JAX, Numba, NumPy)
    - JAX backend when available (implemented for the 'morl' wavelet only)
    - Numba backend, which currently delegates to the NumPy implementation
    - Graceful fallback to NumPy when an optimized backend fails

    Parameters
    ----------
    wavelet : str, optional (default='morl')
        Wavelet type for continuous transform
    scales : np.ndarray, optional (default=None)
        Array of scales for analysis. If None, uses automatic scale selection
    confidence : float, optional (default=0.95)
        Confidence level for confidence intervals
    use_optimization : str, optional (default='auto')
        Optimization framework to use: 'auto', 'jax', 'numba', 'numpy'
    robust : bool, optional (default=False)
        Request Huber regression for the log-log fit instead of weighted
        least squares
    scale_range : tuple of (float, float), optional (default=None)
        Inclusive (lo, hi) bounds used to restrict the analysed scales;
        must satisfy 0 < lo < hi
    trim_ends : int, optional (default=0)
        Number of scales dropped from each end of the scale array before
        fitting; negative values are treated as 0
    """
def __init__(
    self,
    wavelet: str = "morl",
    scales: Optional[np.ndarray] = None,
    confidence: float = 0.95,
    use_optimization: str = "auto",
    robust: bool = False,
    scale_range: Optional[Tuple[float, float]] = None,
    trim_ends: int = 0,
):
    """
    Configure the estimator and pick the computation backend.

    Parameters
    ----------
    wavelet : str
        Wavelet name stored under ``parameters["wavelet"]``.
    scales : np.ndarray, optional
        Explicit analysis scales; ``None`` defers the choice to estimation time.
    confidence : float
        Confidence level, validated to lie strictly between 0 and 1.
    use_optimization : str
        Requested backend ('auto', 'jax', 'numba', 'numpy'); resolved once
        here and stored as ``optimization_framework``.
    robust : bool
        Flag stored under ``parameters["robust"]``.
    scale_range : tuple of (float, float), optional
        (lo, hi) bounds, validated to satisfy 0 < lo < hi.
    trim_ends : int
        Number of scales to drop from each end; negative values become 0.

    Raises
    ------
    ValueError
        If any parameter fails ``_validate_parameters``.
    """
    super().__init__()

    # Negative trims make no sense, so they are clamped to zero up front.
    non_negative_trim = int(max(0, trim_ends))
    self.parameters = dict(
        wavelet=wavelet,
        scales=scales,
        confidence=confidence,
        robust=robust,
        scale_range=scale_range,
        trim_ends=non_negative_trim,
    )

    # The backend is resolved once, at construction time.
    self.optimization_framework = self._select_optimization_framework(
        use_optimization
    )

    # Filled in by the estimate methods.
    self.results = {}

    self._validate_parameters()
[docs] def _select_optimization_framework(self, use_optimization: str) -> str: """Select the optimal optimization framework.""" if use_optimization == "auto": if JAX_AVAILABLE: return "jax" # Best for GPU acceleration elif NUMBA_AVAILABLE: return "numba" # Good for CPU optimization else: return "numpy" # Fallback elif use_optimization == "jax" and JAX_AVAILABLE: return "jax" elif use_optimization == "numba" and NUMBA_AVAILABLE: return "numba" else: return "numpy"
[docs] def _validate_parameters(self) -> None: """Validate estimator parameters.""" if not isinstance(self.parameters["wavelet"], str): raise ValueError("wavelet must be a string") if self.parameters["scales"] is not None: if ( not isinstance(self.parameters["scales"], np.ndarray) or len(self.parameters["scales"]) == 0 ): raise ValueError("scales must be a non-empty numpy array") if not (0 < self.parameters["confidence"] < 1): raise ValueError("confidence must be between 0 and 1") if self.parameters["scale_range"] is not None: lo, hi = self.parameters["scale_range"] if not (lo > 0 and hi > lo): raise ValueError("scale_range must satisfy 0 < lo < hi")
def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]:
    """
    Estimate the Hurst parameter using Continuous Wavelet Transform analysis.

    The series is dispatched to the backend chosen at construction time; if
    the JAX or Numba implementation raises, a warning is issued and the
    NumPy implementation is used instead.

    Parameters
    ----------
    data : array-like
        Input time series data

    Returns
    -------
    dict
        Dictionary containing estimation results including:
        - hurst_parameter: Estimated Hurst parameter
        - confidence_interval: Confidence interval for the estimate
        - r_squared: R-squared value of the fit
        - scales: Scales used in the analysis
        - wavelet_type: Wavelet type used
        - slope: Slope of the log-log regression
        - intercept: Intercept of the log-log regression
        - scale_powers: Power at each scale

    Raises
    ------
    ValueError
        If the series has fewer than 50 samples.
    """
    data = np.asarray(data)
    n = len(data)

    if n < 50:
        raise ValueError("Data length must be at least 50 for CWT analysis")
    if n < 100:
        warnings.warn("Data length is small, results may be unreliable")

    # Work out which accelerated backend (if any) should be tried first.
    if self.optimization_framework == "jax" and JAX_AVAILABLE:
        backend_label, backend_method = "JAX", "_estimate_jax"
    elif self.optimization_framework == "numba" and NUMBA_AVAILABLE:
        backend_label, backend_method = "Numba", "_estimate_numba"
    else:
        return self._estimate_numpy(data)

    try:
        return getattr(self, backend_method)(data)
    except Exception as e:
        # Best-effort: any backend failure degrades to the NumPy path.
        warnings.warn(
            f"{backend_label} implementation failed: {e}, falling back to NumPy"
        )
        return self._estimate_numpy(data)
[docs] def _estimate_numpy(self, data: np.ndarray) -> Dict[str, Any]: """NumPy implementation of CWT estimation.""" n = len(data) # Set default scales if not provided if self.parameters["scales"] is None: # Geometric scales roughly covering [2, n/8] s_min = 2 s_max = max(8, int(n // 8)) self.parameters["scales"] = np.unique( (np.geomspace(s_min, s_max, num=24)).astype(float) ) # Adjust scales for shorter data if n < 100: # Use fewer scales for shorter data max_scale = min(max(self.parameters["scales"]), n // 4) self.parameters["scales"] = np.array( [s for s in self.parameters["scales"] if s <= max_scale] ) if len(self.parameters["scales"]) < 2: raise ValueError("Insufficient scales available for data length") # Optionally restrict scale range and trim ends scales = self.parameters["scales"].astype(float) if self.parameters["scale_range"] is not None: lo, hi = self.parameters["scale_range"] mask = (scales >= lo) & (scales <= hi) scales = scales[mask] if ( self.parameters["trim_ends"] > 0 and len(scales) > 2 * self.parameters["trim_ends"] ): t = self.parameters["trim_ends"] scales = scales[t:-t] # Cap maximum scale to reduce SRD bias scale_cap = min(np.max(scales), 64.0) scales = scales[scales <= scale_cap] if len(scales) < 2: raise ValueError("Insufficient scales after trimming/range selection") # Perform continuous wavelet transform wavelet_coeffs, frequencies = pywt.cwt(data, scales, self.parameters["wavelet"]) # Calculate power spectrum (squared magnitude of coefficients) power_spectrum = np.abs(wavelet_coeffs) ** 2 # Calculate average power at each scale scale_powers = {} scale_logs = [] power_logs = [] power_log_variances = [] n_time = power_spectrum.shape[1] for i, scale in enumerate(scales): coeff_row = power_spectrum[i, :] avg_power = np.mean(coeff_row) scale_powers[scale] = avg_power scale_logs.append(np.log2(scale)) power_logs.append(np.log2(avg_power)) var_power = np.var(coeff_row, ddof=1) if not np.isfinite(var_power) or var_power <= 0: var_power = 
(avg_power**2) / max(n_time, 1) var_mean = var_power / max(n_time, 1) var_log = var_mean / (avg_power**2 * (np.log(2.0) ** 2)) power_log_variances.append(max(var_log, 1e-12)) x = np.asarray(scale_logs, dtype=float) y = np.asarray(power_logs, dtype=float) weights = 1.0 / np.clip( np.asarray(power_log_variances, dtype=float), 1e-12, None ) X = np.column_stack((np.ones_like(x), x)) XtWX = X.T @ (weights[:, None] * X) XtWy = X.T @ (weights * y) if self.parameters["robust"]: slope, intercept = self._huber_regression(x, y) else: beta = np.linalg.solve(XtWX, XtWy) intercept, slope = beta r_squared, slope_se = self._regression_statistics( x, y, weights, slope, intercept, XtWX ) # Empirical mapping consistent with PyWavelets normalization: H ≈ (slope + 1)/2 # This provides low-bias estimates across tested FBM signals estimated_hurst = 0.5 * (slope + 1.0) # Calculate confidence interval confidence_interval = self._get_confidence_interval( estimated_hurst, slope_se, len(scale_logs), ) corrected_hurst, applied_bias = apply_srd_bias_correction( "CWT", float(estimated_hurst) ) if applied_bias != 0.0 and confidence_interval is not None: lower = max(0.01, min(0.99, confidence_interval[0] - applied_bias)) upper = max(0.01, min(0.99, confidence_interval[1] - applied_bias)) confidence_interval = (lower, upper) estimated_hurst = corrected_hurst # Store results self.results = { "hurst_parameter": float(estimated_hurst), "confidence_interval": confidence_interval, "r_squared": float(r_squared), "scales": scales.tolist(), "wavelet_type": self.parameters["wavelet"], "slope": float(slope), "intercept": float(intercept), "scale_powers": scale_powers, "scale_logs": scale_logs, "power_logs": power_logs, "regression_weights": weights.tolist(), "bias_correction": applied_bias, "wavelet_coeffs": wavelet_coeffs, "power_spectrum": power_spectrum, "method": "numpy", "optimization_framework": self.optimization_framework, } return self.results
[docs] def _estimate_numba(self, data: np.ndarray) -> Dict[str, Any]: """Numba-optimized implementation of CWT estimation.""" # For now, use NumPy implementation with Numba JIT compilation # This can be enhanced with custom Numba kernels for specific operations return self._estimate_numpy(data)
def _estimate_jax(self, data: np.ndarray) -> Dict[str, Any]:
    """
    JAX implementation of CWT estimation.

    The transform is computed in the frequency domain with a Morlet
    wavelet (w0 = 6) on the demeaned, zero-padded series; the per-scale
    power is then regressed in log-log space and the slope is mapped to a
    Hurst estimate via H = (slope + 1) / 2.

    Scale selection honours ``scale_range``, ``trim_ends`` and the
    64-scale cap, and ``robust`` switches the fit to Huber regression, so
    the backend choice does not change which scales are fitted. The scale
    grid is resolved per call and never written back to
    ``self.parameters``.

    Parameters
    ----------
    data : np.ndarray
        Input time series.

    Returns
    -------
    dict
        The results dictionary (also stored on ``self.results``). It
        includes ``scale_logs``, ``power_logs`` and ``power_spectrum`` so
        that ``plot_analysis`` can be used after a JAX estimate.

    Raises
    ------
    NotImplementedError
        If the configured wavelet is not 'morl'.
    ValueError
        If fewer than two scales remain after filtering, or if the wavelet
        power at some scale is zero or non-finite.
    """
    if not JAX_AVAILABLE:
        return self._estimate_numpy(data)

    params = self.parameters
    if params["wavelet"] != "morl":
        raise NotImplementedError(
            "JAX CWT currently supports the 'morl' wavelet only"
        )

    max_scale_cap = 64.0  # same upper scale limit as the NumPy path

    data_np = np.asarray(data, dtype=float)
    n = len(data_np)
    # NOTE(review): float64 is presumably only honoured when JAX's x64 mode
    # is enabled; otherwise JAX may fall back to float32 - confirm config.
    signal = jnp.asarray(data_np - np.mean(data_np), dtype=jnp.float64)

    # --- Resolve the scale grid locally ---------------------------------
    if params["scales"] is None:
        s_min = 2
        s_max = max(8, int(n // 8))
        scales = np.unique(np.geomspace(s_min, s_max, num=24).astype(float))
    else:
        scales = np.asarray(params["scales"], dtype=float)

    if n < 100 and scales.size:
        scales = scales[scales <= min(np.max(scales), n // 4)]
    if len(scales) < 2:
        raise ValueError("Insufficient scales available for data length")

    if params["scale_range"] is not None:
        lo, hi = params["scale_range"]
        scales = scales[(scales >= lo) & (scales <= hi)]
    trim = params["trim_ends"]
    if trim > 0 and len(scales) > 2 * trim:
        scales = scales[trim:-trim]
    scales = scales[scales <= max_scale_cap]
    if len(scales) < 2:
        raise ValueError("Insufficient scales after trimming")

    # --- FFT-based transform --------------------------------------------
    padded_len = int(2 ** np.ceil(np.log2(n)))
    padded = jnp.pad(signal, (0, padded_len - n))
    fft_data = jnp.fft.fft(padded)
    omega = 2 * jnp.pi * jnp.fft.fftfreq(padded_len, d=1.0)

    def morlet_fft(scale: float) -> jnp.ndarray:
        # Frequency response of the scaled Morlet wavelet.
        w0 = 6.0
        factor = jnp.sqrt(scale)
        return factor * jnp.exp(-0.5 * (scale * omega - w0) ** 2)

    def compute_coeff(scale: float) -> jnp.ndarray:
        # Multiply in the frequency domain, then drop the padding.
        coeff = jnp.fft.ifft(fft_data * morlet_fft(scale))
        return coeff[:n]

    coeffs = vmap(compute_coeff)(jnp.asarray(scales, dtype=jnp.float64))
    power_spectrum_jax = jnp.abs(coeffs) ** 2
    n_time = power_spectrum_jax.shape[1]
    n_eff = max(n_time, 1)

    # The regression is a 2x2 problem, so it is done on host in NumPy.
    scale_powers = np.asarray(jnp.mean(power_spectrum_jax, axis=1), dtype=float)
    variances = np.asarray(
        jnp.var(power_spectrum_jax, axis=1, ddof=1), dtype=float
    )

    if not np.all(np.isfinite(scale_powers) & (scale_powers > 0)):
        # log2(0) / NaN would otherwise propagate silently into the fit.
        raise ValueError(
            "Wavelet power is zero or non-finite at one or more scales; "
            "cannot fit the log-log regression"
        )

    scale_logs = np.log2(scales)
    power_logs = np.log2(scale_powers)

    usable = np.isfinite(variances) & (variances > 0)
    variances = np.where(usable, variances, scale_powers**2 / n_eff)
    var_mean = variances / n_eff
    var_log = var_mean / (scale_powers**2 * (np.log(2.0) ** 2))
    power_log_variances = np.maximum(var_log, 1e-12)
    weights = 1.0 / power_log_variances

    # --- Regression -----------------------------------------------------
    X = np.column_stack((np.ones_like(scale_logs), scale_logs))
    XtWX = X.T @ (weights[:, None] * X)
    if params["robust"]:
        slope, intercept = self._huber_regression(scale_logs, power_logs)
    else:
        XtWy = X.T @ (weights * power_logs)
        intercept, slope = np.linalg.solve(XtWX, XtWy)

    y_fit = slope * scale_logs + intercept
    y_mean = np.average(power_logs, weights=weights)
    ss_res = float(np.sum(weights * (power_logs - y_fit) ** 2))
    ss_tot = float(np.sum(weights * (power_logs - y_mean) ** 2))
    r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0

    estimated_hurst = 0.5 * (slope + 1.0)
    # Standard error taken from the inverse-variance weighted normal matrix.
    slope_se = float(np.sqrt(max(np.linalg.inv(XtWX)[1, 1], 1e-12)))

    confidence_interval = self._get_confidence_interval(
        float(estimated_hurst),
        slope_se,
        len(scale_logs),
    )

    corrected_hurst, applied_bias = apply_srd_bias_correction(
        "CWT", float(estimated_hurst)
    )
    if applied_bias != 0.0 and confidence_interval is not None:
        # Shift the interval by the same bias and keep it inside (0, 1).
        lower = max(0.01, min(0.99, confidence_interval[0] - applied_bias))
        upper = max(0.01, min(0.99, confidence_interval[1] - applied_bias))
        confidence_interval = (lower, upper)
    estimated_hurst = corrected_hurst

    self.results = {
        "hurst_parameter": float(estimated_hurst),
        "confidence_interval": confidence_interval,
        "r_squared": float(r_squared),
        "scales": scales.tolist(),
        "wavelet_type": params["wavelet"],
        "slope": float(slope),
        "intercept": float(intercept),
        "scale_powers": dict(zip(scales.tolist(), scale_powers.tolist())),
        "scale_logs": scale_logs.tolist(),
        "power_logs": power_logs.tolist(),
        "regression_weights": weights.tolist(),
        "power_log_variances": power_log_variances.tolist(),
        "wavelet_coeffs": np.asarray(coeffs),
        "power_spectrum": np.asarray(power_spectrum_jax),
        "method": "jax",
        "optimization_framework": self.optimization_framework,
        "bias_correction": applied_bias,
    }
    return self.results
[docs] def _regression_statistics( self, x: np.ndarray, y: np.ndarray, weights: np.ndarray, slope: float, intercept: float, XtWX: np.ndarray, ) -> Tuple[float, float]: """Compute weighted regression diagnostics for slope.""" residuals = y - (slope * x + intercept) dof = max(len(x) - 2, 1) ss_res = float(np.sum(weights * residuals**2)) y_mean = np.average(y, weights=weights) ss_tot = float(np.sum(weights * (y - y_mean) ** 2)) r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0 sigma2 = ss_res / dof if dof > 0 else 0.0 if sigma2 < 1e-10: sigma2 = np.mean(1.0 / weights) cov_beta = sigma2 * np.linalg.inv(XtWX) slope_se = float(np.sqrt(max(cov_beta[1, 1], 1e-12))) return r_squared, slope_se
[docs] def _get_confidence_interval( self, estimated_hurst: float, slope_se: float, n_points: int, ) -> Tuple[float, float]: """Calculate confidence interval for the Hurst parameter estimate.""" confidence = self.parameters["confidence"] hurst_se = 0.5 * slope_se dof = max(n_points - 2, 1) t_value = stats.t.ppf((1 + confidence) / 2, df=dof) margin = float(t_value * hurst_se) return (float(estimated_hurst - margin), float(estimated_hurst + margin))
def _huber_regression( self, X: np.ndarray, y: np.ndarray, c: float = 1.345, iters: int = 50, tol: float = 1e-8, ) -> Tuple[float, float]: X1 = np.vstack([X, np.ones_like(X)]).T beta, *_ = np.linalg.lstsq(X1, y, rcond=None) for _ in range(iters): r = y - X1 @ beta s = 1.4826 * np.median(np.abs(r - np.median(r)) + 1e-12) u = r / (s + 1e-12) w = np.clip(c / np.maximum(np.abs(u), 1e-12), 0.0, 1.0) W = np.diag(w) XtWX = X1.T @ W @ X1 XtWy = X1.T @ W @ y beta_new, *_ = np.linalg.lstsq(XtWX, XtWy, rcond=None) if np.linalg.norm(beta_new - beta) < tol: beta = beta_new break beta = beta_new return float(beta[0]), float(beta[1])
def get_optimization_info(self) -> Dict[str, Any]:
    """
    Report the selected backend and which optional backends imported.

    Returns
    -------
    dict
        Keys ``current_framework``, ``jax_available``, ``numba_available``
        and ``recommended_framework``.
    """
    info: Dict[str, Any] = {}
    info["current_framework"] = self.optimization_framework
    info["jax_available"] = JAX_AVAILABLE
    info["numba_available"] = NUMBA_AVAILABLE
    # NOTE(review): _get_recommended_framework is not defined in the visible
    # part of this file; presumably inherited from BaseEstimator - confirm.
    info["recommended_framework"] = self._get_recommended_framework()
    return info
def plot_analysis(
    self,
    data: np.ndarray,
    figsize: Tuple[int, int] = (15, 10),
    save_path: Optional[str] = None,
) -> None:
    """
    Plot the CWT analysis results as a 2x3 panel figure.

    Panels: original series, log-log scaling fit, power vs scale, Hurst
    estimate with its confidence interval, R-squared, and the scalogram.

    The log-log points are rebuilt from ``scales`` / ``scale_powers`` when
    the stored results lack ``scale_logs`` / ``power_logs``, and the
    scalogram panel is replaced by a note when no ``power_spectrum`` was
    stored, so results from any backend can be plotted.

    Parameters
    ----------
    data : np.ndarray
        The time series that was analysed (drawn in the first panel).
    figsize : tuple of (int, int)
        Figure size passed to matplotlib.
    save_path : str, optional
        If given, the figure is saved there at 300 dpi before being shown.

    Raises
    ------
    ValueError
        If no results are available yet.
    """
    if not self.results:
        raise ValueError("No results available. Run estimate() first.")

    results = self.results
    scale_list = results["scales"]
    scales = np.asarray(scale_list, dtype=float)
    powers = np.asarray(
        [results["scale_powers"][s] for s in scale_list], dtype=float
    )

    # Fall back to recomputing the log-log points when they were not stored.
    stored_scale_logs = results.get("scale_logs")
    stored_power_logs = results.get("power_logs")
    x = (
        np.log2(scales)
        if stored_scale_logs is None
        else np.asarray(stored_scale_logs, dtype=float)
    )
    y = (
        np.log2(powers)
        if stored_power_logs is None
        else np.asarray(stored_power_logs, dtype=float)
    )

    fig, axes = plt.subplots(2, 3, figsize=figsize)
    fig.suptitle(
        f'CWT Analysis - {self.parameters["wavelet"]} Wavelet', fontsize=16
    )

    # Plot 1: Original time series
    ax1 = axes[0, 0]
    ax1.plot(data, alpha=0.7)
    ax1.set_xlabel("Time")
    ax1.set_ylabel("Amplitude")
    ax1.set_title("Original Time Series")
    ax1.grid(True, alpha=0.3)

    # Plot 2: Log-log scaling relationship with the fitted line
    ax2 = axes[0, 1]
    ax2.scatter(x, y, s=60, alpha=0.7, label="Data points")
    slope = results["slope"]
    intercept = results["intercept"]
    x_fit = np.linspace(x.min(), x.max(), 100)
    ax2.plot(
        x_fit,
        slope * x_fit + intercept,
        "r--",
        label=f"Linear fit (slope={slope:.3f})",
    )
    ax2.set_xlabel("log₂(Scale)")
    ax2.set_ylabel("log₂(Power)")
    ax2.set_title("CWT Power Scaling")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Plot 3: Power vs Scale (log-log)
    ax3 = axes[0, 2]
    ax3.scatter(scales, powers, s=60, alpha=0.7)
    ax3.set_xscale("log")
    ax3.set_yscale("log")
    ax3.set_xlabel("Scale")
    ax3.set_ylabel("Power")
    ax3.set_title("Power vs Scale (log-log)")
    ax3.grid(True, which="both", ls=":", alpha=0.3)

    # Plot 4: Hurst parameter estimate
    ax4 = axes[1, 0]
    hurst = results["hurst_parameter"]
    conf_interval = results["confidence_interval"]
    bar_kwargs = {"capsize": 10, "alpha": 0.7, "color": "skyblue"}
    if conf_interval is not None:
        # The interval may have been clamped independently of the estimate,
        # so error-bar lengths are floored at zero (matplotlib rejects
        # negative values).
        bar_kwargs["yerr"] = [
            [max(0.0, hurst - conf_interval[0])],
            [max(0.0, conf_interval[1] - hurst)],
        ]
    ax4.bar(["Hurst Parameter"], [hurst], **bar_kwargs)
    ax4.axhline(
        y=0.5, color="red", linestyle="--", alpha=0.7, label="H=0.5 (no memory)"
    )
    ax4.set_ylabel("Hurst Parameter")
    ax4.set_title(f"Hurst Parameter Estimate: {hurst:.3f}")
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # Plot 5: R-squared
    ax5 = axes[1, 1]
    r_squared = results["r_squared"]
    ax5.bar(["R²"], [r_squared], alpha=0.7, color="lightgreen")
    ax5.set_ylabel("R²")
    ax5.set_title(f"Goodness of Fit: R² = {r_squared:.3f}")
    ax5.set_ylim(0, 1)
    ax5.grid(True, alpha=0.3)

    # Plot 6: Wavelet scalogram (power spectrum)
    ax6 = axes[1, 2]
    ax6.set_title("Wavelet Scalogram")
    power_spectrum = results.get("power_spectrum")
    if power_spectrum is None:
        ax6.text(
            0.5,
            0.5,
            "Scalogram not available",
            ha="center",
            va="center",
            transform=ax6.transAxes,
        )
        ax6.set_axis_off()
    else:
        spectrum = np.asarray(power_spectrum)
        time_axis = np.arange(spectrum.shape[1])
        # pcolormesh places each row at its actual scale value, so the rows
        # line up with the logarithmic scale axis.
        mesh = ax6.pcolormesh(time_axis, scales, spectrum, shading="auto")
        ax6.set_xlabel("Time")
        ax6.set_ylabel("Scale")
        ax6.set_yscale("log")
        plt.colorbar(mesh, ax=ax6, label="Power")

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
def get_method_recommendation(self, n: int) -> Dict[str, Any]:
    """
    Suggest a backend for a series of length ``n``.

    The recommendation depends only on ``n``: below 50 and below 100 map to
    'numpy', below 1000 to 'numba', and anything else to 'jax'.

    Returns
    -------
    dict
        Keys ``recommended_method``, ``reasoning`` and ``method_details``
        (description, best_for, complexity, memory, accuracy).
    """
    # (exclusive upper bound, method, reasoning suffix, description,
    #  best_for, accuracy) - checked in order, first match wins.
    tiers = (
        (
            50,
            "numpy",
            "is too small for CWT analysis",
            "NumPy implementation",
            "Small datasets (n < 50)",
            "Low (insufficient data)",
        ),
        (
            100,
            "numpy",
            "is too small for optimization benefits",
            "NumPy implementation",
            "Small datasets (50 ≤ n < 100)",
            "Medium",
        ),
        (
            1000,
            "numba",
            "benefits from JIT compilation",
            "Numba JIT-compiled implementation",
            "Medium datasets (100 ≤ n < 1000)",
            "High",
        ),
    )
    largest_tier = [
        "jax",
        "benefits from GPU acceleration",
        "JAX GPU-accelerated implementation",
        "Large datasets (n ≥ 1000)",
        "High",
    ]

    for upper_bound, *spec in tiers:
        if n < upper_bound:
            break
    else:
        spec = largest_tier

    method, reason, description, best_for, accuracy = spec
    return {
        "recommended_method": method,
        "reasoning": f"Data size n={n} {reason}",
        "method_details": {
            "description": description,
            "best_for": best_for,
            "complexity": "O(n log n)",
            "memory": "O(n)",
            "accuracy": accuracy,
        },
    }