Source code for lrdbenchmark.analysis.temporal.dma_estimator

#!/usr/bin/env python3
"""
Unified DMA (Detrended Moving Average) Estimator for Long-Range Dependence Analysis.

This module implements the DMA estimator with automatic optimization framework
selection (JAX, Numba, NumPy) for the best performance on the available hardware.
"""

import os
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# Import optimization frameworks
try:
    import jax
    import jax.numpy as jnp
    from jax import jit, vmap

    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False

try:
    import numba
    from numba import jit as numba_jit
    from numba import prange

    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

from lrdbenchmark.analysis.base_estimator import BaseEstimator


def _ensure_non_interactive_backend() -> None:
    """Switch to a headless-friendly Matplotlib backend when running without DISPLAY."""
    if os.environ.get("LRDBENCHMARK_FORCE_INTERACTIVE", "").lower() in {
        "1",
        "true",
        "yes",
    }:
        return
    backend = plt.get_backend().lower()
    interactive_markers = ("gtk", "qt", "wx", "tk")
    if any(marker in backend for marker in interactive_markers):
        try:
            plt.switch_backend("Agg")
        except Exception:
            pass


# Runs once at import time, before any estimator in this module creates a figure.
_ensure_non_interactive_backend()


class DMAEstimator(BaseEstimator):
    """
    Unified DMA (Detrended Moving Average) Estimator for Long-Range Dependence Analysis.

    DMA analyzes the root-mean-square fluctuation of detrended time series
    data using moving average detrending to estimate the Hurst parameter.

    Features:
    - Automatic optimization framework selection (JAX, Numba, NumPy)
    - JAX code path used when JAX is importable
    - Numba selection currently reuses the NumPy code path and only relabels
      the result
    - Graceful fallback to NumPy when the selected framework fails

    Parameters
    ----------
    min_scale : int, optional (default=None)
        Minimum scale for analysis. ``None`` is replaced by 4. Must be at
        least 3.
    max_scale : int, optional (default=None)
        Maximum scale for analysis. If None, ``max(min_scale + 1, n // 4)``
        is used at estimation time, where ``n`` is the data length.
    num_scales : int, optional (default=None)
        Number of scales to test. ``None`` is replaced by 10. Must be at
        least 3 unless ``window_sizes`` is given.
    use_optimization : str, optional (default='auto')
        Optimization framework to use: 'auto', 'jax', 'numba', 'numpy'.
        Unavailable or unknown choices fall back to 'numpy'.
    min_window_size, max_window_size, num_windows : int, optional
        Keyword-only legacy aliases for ``min_scale``, ``max_scale`` and
        ``num_scales``; when given they take precedence.
    window_sizes : sequence of int, optional
        Keyword-only explicit list of scales. When given it replaces the
        log-spaced grid; it must be strictly ascending, contain at least
        three entries and no value below 3.
    overlap : bool, optional (default=True)
        Keyword-only. True uses a sliding moving average over the integrated
        series; False uses non-overlapping segments with a linear fit
        removed from each segment.
    """
def __init__(
    self,
    min_scale: Optional[int] = None,
    max_scale: Optional[int] = None,
    num_scales: Optional[int] = None,
    use_optimization: str = "auto",
    *,
    min_window_size: Optional[int] = None,
    max_window_size: Optional[int] = None,
    num_windows: Optional[int] = None,
    window_sizes: Optional[Sequence[int]] = None,
    overlap: bool = True,
):
    """Configure the estimator and validate its parameters.

    Parameters
    ----------
    min_scale, max_scale, num_scales : int, optional
        Bounds and size of the scale grid. ``min_scale`` falls back to 4 and
        ``num_scales`` to 10 when left as None; ``max_scale`` stays None.
    use_optimization : str
        Requested framework, resolved by ``_select_optimization_framework``.
    min_window_size, max_window_size, num_windows : int, optional
        Legacy aliases; a non-None alias overrides its canonical argument.
    window_sizes : sequence of int, optional
        Explicit scales. When non-empty, the stored minimum, maximum and
        count are derived from this sequence.
    overlap : bool
        Stored as ``parameters["overlap"]``.

    Raises
    ------
    ValueError
        If ``window_sizes`` contains non-positive values or
        ``_validate_parameters`` rejects the configuration.
    """
    # NOTE(review): the base initializer is called twice (bare here, with the
    # parameters below); confirm against BaseEstimator whether both are needed.
    super().__init__()

    # A supplied legacy alias wins over the canonical argument.
    min_scale = min_scale if min_window_size is None else min_window_size
    max_scale = max_scale if max_window_size is None else max_window_size
    num_scales = num_scales if num_windows is None else num_windows

    lower = int(min_scale) if min_scale is not None else 4
    count = int(num_scales) if num_scales is not None else 10

    explicit = None
    if window_sizes is not None:
        explicit = self._sanitize_window_sizes(window_sizes)

    upper = None if max_scale is None else int(max_scale)

    # Canonical names and their legacy aliases are stored side by side.
    settings = {
        "min_scale": lower,
        "max_scale": upper,
        "num_scales": count,
        "min_window_size": lower,
        "max_window_size": upper,
        "num_windows": count,
        "window_sizes": None if explicit is None else explicit.tolist(),
        "overlap": bool(overlap),
    }

    if explicit is not None and len(explicit) > 0:
        # Explicit scales dictate the reported bounds and count.
        first = int(explicit[0])
        last = int(explicit[-1])
        total = len(explicit)
        settings.update(
            min_scale=first,
            min_window_size=first,
            max_scale=last,
            max_window_size=last,
            num_scales=total,
            num_windows=total,
        )

    super().__init__(**settings)
    self.parameters = settings

    self.optimization_framework = self._select_optimization_framework(
        use_optimization
    )
    self.results = {}
    self._validate_parameters()
[docs] def _select_optimization_framework(self, use_optimization: str) -> str: """Select the optimal optimization framework.""" if use_optimization == "auto": if JAX_AVAILABLE: return "jax" # Best for GPU acceleration elif NUMBA_AVAILABLE: return "numba" # Good for CPU optimization else: return "numpy" # Fallback elif use_optimization == "jax" and JAX_AVAILABLE: return "jax" elif use_optimization == "numba" and NUMBA_AVAILABLE: return "numba" else: return "numpy"
[docs] def _validate_parameters(self) -> None: """Validate estimator parameters.""" if self.parameters["window_sizes"] is not None: windows = np.asarray(self.parameters["window_sizes"], dtype=int) if np.any(windows < 3): raise ValueError("All window sizes must be at least 3") if not np.all(np.diff(windows) > 0): raise ValueError("Window sizes must be in ascending order") if len(windows) < 3: raise ValueError("Need at least 3 window sizes") if self.parameters["min_scale"] < 3: raise ValueError("min_window_size must be at least 3") max_scale = self.parameters["max_scale"] if max_scale is not None and max_scale <= self.parameters["min_scale"]: raise ValueError("max_window_size must be greater than min_window_size") if ( self.parameters["num_scales"] < 3 and self.parameters["window_sizes"] is None ): raise ValueError("Need at least 3 window sizes")
def _sanitize_window_sizes(self, window_sizes: Sequence[int]) -> np.ndarray: windows = np.array(window_sizes, dtype=int) if np.any(windows <= 0): raise ValueError("Window sizes must be positive integers") return windows def _resolve_scales(self, n: int) -> np.ndarray: if self.parameters["window_sizes"] is not None: windows = np.asarray(self.parameters["window_sizes"], dtype=int) else: max_scale = self.parameters["max_scale"] if max_scale is None: max_scale = max(self.parameters["min_scale"] + 1, n // 4) windows = np.logspace( np.log10(self.parameters["min_scale"]), np.log10(max_scale), self.parameters["num_scales"], dtype=int, ) windows = np.unique(windows) valid = windows[(windows >= self.parameters["min_scale"]) & (windows <= n // 2)] if len(valid) < 3: raise ValueError("Need at least 3 window sizes") return valid def _confidence_interval( self, hurst: float, std_err: float, sample_size: int, confidence_level: float = 0.95, ) -> List[float]: if not np.isfinite(std_err) or std_err <= 0 or sample_size < 3: return [float("nan"), float("nan")] alpha = 1 - confidence_level dof = max(sample_size - 2, 1) critical = stats.t.ppf(1 - alpha / 2, dof) margin = critical * std_err return [float(hurst - margin), float(hurst + margin)]
def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]:
    """
    Estimate the Hurst parameter using DMA with the selected framework.

    The JAX or Numba path is tried first when selected and importable; any
    exception from it triggers a retry with the NumPy implementation. A
    warning is emitted for that fallback unless
    ``_should_suppress_fallback_warning`` recognises the error.

    Parameters
    ----------
    data : array-like
        Input time series data

    Returns
    -------
    dict
        Dictionary containing estimation results including:
        - hurst_parameter: Estimated Hurst parameter
        - r_squared: R-squared value of the fit
        - scales: Scales used in the analysis
        - fluctuation_values: Fluctuation values for each scale
        - log_scales: Log of scales
        - log_fluctuations: Log of fluctuation values

    Raises
    ------
    ValueError
        If the series has fewer than 10 samples.
    """
    series = np.asarray(data)
    if len(series) < 10:
        raise ValueError("Data length must be at least 10")

    framework = self.optimization_framework

    if framework == "jax" and JAX_AVAILABLE:
        try:
            return self._estimate_jax(series)
        except Exception as error:
            if not self._should_suppress_fallback_warning(error):
                warnings.warn(
                    f"JAX implementation failed: {error}, falling back to NumPy"
                )
            return self._estimate_numpy(series)

    if framework == "numba" and NUMBA_AVAILABLE:
        try:
            return self._estimate_numba(series)
        except Exception as error:
            if not self._should_suppress_fallback_warning(error):
                warnings.warn(
                    f"Numba implementation failed: {error}, falling back to NumPy"
                )
            return self._estimate_numpy(series)

    return self._estimate_numpy(series)
[docs] def _estimate_numpy(self, data: np.ndarray) -> Dict[str, Any]: """NumPy implementation of DMA estimation.""" n = len(data) # Set max scale if not provided scales = self._resolve_scales(n) if len(scales) < 3: raise ValueError("Need at least 3 window sizes") # Calculate fluctuation values for each scale fluctuation_values = [] for scale in scales: fluct_val = self._calculate_fluctuation_numpy( data, scale, overlap=self.parameters["overlap"] ) fluctuation_values.append(fluct_val) fluctuation_values = np.array(fluctuation_values) # Filter out invalid values valid_mask = (fluctuation_values > 0) & ~np.isnan(fluctuation_values) if np.sum(valid_mask) < 3: raise ValueError("Need at least 3 window sizes") valid_scales = scales[valid_mask] valid_fluctuations = fluctuation_values[valid_mask] # Log-log regression log_scales = np.log(valid_scales) log_fluctuations = np.log(valid_fluctuations) # Linear regression slope, intercept, r_value, p_value, std_err = stats.linregress( log_scales, log_fluctuations ) # Calculate R-squared r_squared = r_value**2 # Hurst parameter is the slope hurst_parameter = slope confidence_interval = self._confidence_interval( hurst_parameter, std_err, len(log_scales) ) self.results = { "hurst_parameter": float(hurst_parameter), "r_squared": float(r_squared), "slope": float(slope), "intercept": float(intercept), "p_value": float(p_value), "std_error": float(std_err), "scales": valid_scales.tolist(), "window_sizes": valid_scales.tolist(), "fluctuation_values": valid_fluctuations.tolist(), "log_scales": log_scales.tolist(), "log_fluctuations": log_fluctuations.tolist(), "confidence_interval": confidence_interval, "method": "numpy", "optimization_framework": "numpy", } return self.results
[docs] def _estimate_numba(self, data: np.ndarray) -> Dict[str, Any]: """Numba-optimized implementation of DMA estimation.""" result = self._estimate_numpy(data) result["method"] = "numba" result["optimization_framework"] = "numba" return result
def _estimate_jax(self, data: np.ndarray) -> Dict[str, Any]:
    """Estimate the Hurst parameter with the JAX fluctuation kernel.

    The mean-removed cumulative sum is built with NumPy, handed to
    ``_dma_fluctuation_jax`` once per scale, and the log-log regression is
    done with SciPy. Falls back to ``_estimate_numpy`` when JAX was not
    imported. The result dictionary is also stored in ``self.results``.

    Raises
    ------
    ValueError
        If fewer than 3 usable fluctuation values are obtained, or the
        scales cannot be resolved.
    """
    if not JAX_AVAILABLE:
        return self._estimate_numpy(data)

    samples = np.asarray(data, dtype=float)
    # NOTE(review): float64 is requested explicitly; whether JAX honours it
    # depends on its x64 configuration, which is not set in this file.
    profile = jnp.asarray(np.cumsum(samples - np.mean(samples)), dtype=jnp.float64)

    scales = self._resolve_scales(len(data))
    use_overlap = bool(self.parameters.get("overlap", True))

    fluctuations = np.asarray(
        [float(_dma_fluctuation_jax(profile, int(scale), use_overlap)) for scale in scales],
        dtype=float,
    )

    # Only strictly positive, non-NaN values can enter the log-log fit.
    usable = (fluctuations > 0) & ~np.isnan(fluctuations)
    if np.sum(usable) < 3:
        raise ValueError("Insufficient valid fluctuation values for analysis")

    kept_scales = scales[usable]
    kept_fluctuations = fluctuations[usable]
    log_scales = np.log(kept_scales)
    log_fluctuations = np.log(kept_fluctuations)

    fit = stats.linregress(log_scales, log_fluctuations)
    interval = self._confidence_interval(
        float(fit.slope),
        float(fit.stderr),
        len(log_scales),
    )

    self.results = {
        "hurst_parameter": float(fit.slope),
        "r_squared": float(fit.rvalue**2),
        "slope": float(fit.slope),
        "intercept": float(fit.intercept),
        "p_value": float(fit.pvalue),
        "std_error": float(fit.stderr),
        "scales": kept_scales.tolist(),
        "window_sizes": kept_scales.tolist(),
        "fluctuation_values": kept_fluctuations.tolist(),
        "log_scales": log_scales.tolist(),
        "log_fluctuations": log_fluctuations.tolist(),
        "confidence_interval": interval,
        "method": "jax",
        "optimization_framework": self.optimization_framework,
    }
    return self.results
[docs] def _calculate_fluctuation_numpy( self, data: np.ndarray, scale: int, overlap: bool ) -> float: """Calculate fluctuation value for a given scale using NumPy.""" n = len(data) if scale >= n: return np.nan # Step 1: Calculate cumulative sum (integration) - this is the key fix! y = np.cumsum(data - np.mean(data)) if overlap: # Moving average detrending on cumulative sum moving_avg = np.convolve(y, np.ones(scale) / scale, mode="valid") detrended = y[scale - 1 :] - moving_avg return float(np.sqrt(np.mean(detrended**2))) # Non-overlapping case n_segments = n // scale if n_segments == 0: return np.nan trimmed = y[: n_segments * scale] segments = trimmed.reshape(n_segments, scale) # Detrend each segment detrended_segments = [] for segment in segments: x = np.arange(scale) # Linear detrending coeffs = np.polyfit(x, segment, 1) trend = np.polyval(coeffs, x) detrended = segment - trend detrended_segments.append(detrended) detrended_segments = np.array(detrended_segments) fluctuation = np.sqrt(np.mean(detrended_segments**2)) return float(fluctuation)
def get_optimization_info(self) -> Dict[str, Any]:
    """Report the selected framework and which optional ones were imported.

    NOTE(review): ``_get_recommended_framework`` is not defined in the
    visible part of this class; presumably it is inherited from
    ``BaseEstimator`` - confirm.
    """
    info: Dict[str, Any] = {}
    info["current_framework"] = self.optimization_framework
    info["jax_available"] = JAX_AVAILABLE
    info["numba_available"] = NUMBA_AVAILABLE
    info["recommended_framework"] = self._get_recommended_framework()
    return info
def get_confidence_intervals(
    self, confidence_level: float = 0.95
) -> Dict[str, Tuple[float, float]]:
    """Return the confidence interval of the last Hurst estimate.

    The interval is recomputed from the stored slope, its standard error and
    the number of scales used, at the requested ``confidence_level``.

    Raises
    ------
    ValueError
        If no estimation has been run yet.
    """
    stored = self.results
    if not stored:
        raise ValueError("No estimation results available")

    bounds = self._confidence_interval(
        stored["hurst_parameter"],
        stored["std_error"],
        len(stored["scales"]),
        confidence_level,
    )
    return {"hurst_parameter": tuple(bounds)}
def get_estimation_quality(self) -> Dict[str, Any]:
    """Return goodness-of-fit figures of the last estimation.

    The dictionary holds ``r_squared``, ``p_value`` and ``std_error`` from
    the stored results plus ``n_windows``, the number of scales used.

    Raises
    ------
    ValueError
        If no estimation has been run yet.
    """
    stored = self.results
    if not stored:
        raise ValueError("No estimation results available")

    quality = {key: stored[key] for key in ("r_squared", "p_value", "std_error")}
    quality["n_windows"] = len(stored["scales"])
    return quality
def plot_scaling(self, **kwargs) -> None:
    """Plot the last estimation by forwarding ``kwargs`` to ``plot_analysis``.

    Raises
    ------
    ValueError
        If no estimation has been run yet.
    """
    if self.results:
        self.plot_analysis(**kwargs)
        return
    raise ValueError("No estimation results available")
[docs] def _calculate_fluctuation( self, data: Union[np.ndarray, list], window_size: int ) -> float: """Backward-compatible helper for direct fluctuation calculation.""" return float( self._calculate_fluctuation_numpy( np.asarray(data), int(window_size), self.parameters.get("overlap", True) ) )
[docs] @staticmethod def _should_suppress_fallback_warning(error: Exception) -> bool: """Return True when a fallback is expected and shouldn't raise a warning.""" message = str(error).lower() suppressed_fragments = ( "need at least 3 window sizes", "insufficient valid", "insufficient valid fluctuation values", ) return any(fragment in message for fragment in suppressed_fragments)
def plot_analysis(
    self, figsize: Tuple[int, int] = (12, 8), save_path: Optional[str] = None
) -> None:
    """Draw a 2x2 summary figure of the last estimation.

    Panels: the log-log points with the fitted line, the raw fluctuation
    against scale on log axes, a bar for the Hurst estimate with the H=0.5
    reference, and a bar for R-squared.

    Parameters
    ----------
    figsize : tuple of int
        Figure size passed to ``plt.subplots``.
    save_path : str, optional
        When truthy, the figure is written there at 300 dpi.

    Raises
    ------
    ValueError
        If no estimation has been run yet.
    """
    stored = self.results
    if not stored:
        raise ValueError("No estimation results available")

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle("DMA Analysis Results", fontsize=16)
    (fit_ax, raw_ax), (hurst_ax, quality_ax) = axes

    # Panel 1: points used for the regression and the fitted line.
    log_x = stored["log_scales"]
    log_y = stored["log_fluctuations"]
    fit_ax.scatter(log_x, log_y, s=60, alpha=0.7, label="Data points")
    slope = stored["slope"]
    intercept = stored["intercept"]
    line_x = np.linspace(min(log_x), max(log_x), 100)
    fit_ax.plot(
        line_x,
        slope * line_x + intercept,
        "r--",
        label=f"Linear fit (slope={slope:.3f})",
    )
    fit_ax.set_xlabel("log(Scale)")
    fit_ax.set_ylabel("log(Fluctuation)")
    fit_ax.set_title("DMA Scaling")
    fit_ax.legend()
    fit_ax.grid(True, alpha=0.3)

    # Panel 2: untransformed values on logarithmic axes.
    raw_ax.scatter(stored["scales"], stored["fluctuation_values"], s=60, alpha=0.7)
    raw_ax.set_xscale("log")
    raw_ax.set_yscale("log")
    raw_ax.set_xlabel("Scale")
    raw_ax.set_ylabel("Fluctuation")
    raw_ax.set_title("Fluctuation vs Scale (log-log)")
    raw_ax.grid(True, which="both", ls=":", alpha=0.3)

    # Panel 3: Hurst estimate against the H=0.5 reference line.
    hurst = stored["hurst_parameter"]
    hurst_ax.bar(["Hurst Parameter"], [hurst], alpha=0.7, color="skyblue")
    hurst_ax.axhline(
        y=0.5, color="red", linestyle="--", alpha=0.7, label="H=0.5 (no memory)"
    )
    hurst_ax.set_ylabel("Hurst Parameter")
    hurst_ax.set_title(f"Hurst Parameter Estimate: {hurst:.3f}")
    hurst_ax.legend()
    hurst_ax.grid(True, alpha=0.3)

    # Panel 4: goodness of fit.
    r_squared = stored["r_squared"]
    quality_ax.bar(["R²"], [r_squared], alpha=0.7, color="lightgreen")
    quality_ax.set_ylabel("R²")
    quality_ax.set_title(f"Goodness of Fit: R² = {r_squared:.3f}")
    quality_ax.set_ylim(0, 1)
    quality_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    # Show the figure only when a GUI/notebook backend or interactive mode
    # is active; otherwise release it.
    backend_name = plt.get_backend().lower()
    gui_markers = ("qt", "gtk", "wx", "tk", "nbagg", "webagg")
    wants_window = plt.isinteractive() or any(
        marker in backend_name for marker in gui_markers
    )
    if wants_window:
        plt.show()
    else:
        plt.close(fig)
def get_method_recommendation(self, n: int) -> Dict[str, Any]:
    """Return a static framework recommendation for a series of length ``n``.

    Sizes below 100 map to ``numpy``, sizes below 1000 to ``numba`` and
    everything else to ``jax``. The answer depends on ``n`` only; it does
    not check which frameworks are installed.
    """
    if n < 100:
        method = "numpy"
        reasoning = f"Data size n={n} is too small for optimization benefits"
        description = "NumPy implementation"
        best_for = "Small datasets (n < 100)"
        accuracy = "Medium"
    elif n < 1000:
        method = "numba"
        reasoning = f"Data size n={n} benefits from JIT compilation"
        description = "Numba JIT-compiled implementation"
        best_for = "Medium datasets (100 ≤ n < 1000)"
        accuracy = "High"
    else:
        method = "jax"
        reasoning = f"Data size n={n} benefits from GPU acceleration"
        description = "JAX GPU-accelerated implementation"
        best_for = "Large datasets (n ≥ 1000)"
        accuracy = "High"

    return {
        "recommended_method": method,
        "reasoning": reasoning,
        "method_details": {
            "description": description,
            "best_for": best_for,
            "complexity": "O(n²)",
            "memory": "O(n)",
            "accuracy": accuracy,
        },
    }
if JAX_AVAILABLE:
    from functools import partial

    @partial(jit, static_argnums=(1, 2))
    def _dma_fluctuation_jax(y: jnp.ndarray, scale: int, overlap: bool) -> jnp.ndarray:
        """Return the RMS fluctuation of profile ``y`` at one fixed scale.

        ``scale`` and ``overlap`` are static arguments of the jitted
        function. With ``overlap`` the trailing moving average of width
        ``scale`` is subtracted from ``y``; otherwise ``y`` is cut into
        non-overlapping segments of length ``scale`` and a least-squares
        line is removed from each. NaN is returned when no full segment
        fits.
        """
        if overlap:
            window = jnp.ones(scale, dtype=y.dtype) / scale
            trend = jnp.convolve(y, window, mode="valid")
            residual = y[scale - 1 :] - trend
            return jnp.sqrt(jnp.mean(residual**2))

        segment_count = y.shape[0] // scale
        if segment_count == 0:
            return jnp.nan

        blocks = y[: segment_count * scale].reshape((segment_count, scale))

        # Quantities of the regressor shared by every segment.
        positions = jnp.arange(scale, dtype=y.dtype)
        position_mean = jnp.mean(positions)
        spread = jnp.sum((positions - position_mean) ** 2)

        def residual_power(block: jnp.ndarray) -> jnp.ndarray:
            """Mean squared residual of one segment after removing its line."""
            block_mean = jnp.mean(block)
            gradient = (
                jnp.sum((positions - position_mean) * (block - block_mean)) / spread
            )
            offset = block_mean - gradient * position_mean
            leftover = block - (gradient * positions + offset)
            return jnp.mean(leftover**2)

        per_segment = vmap(residual_power)(blocks)
        return jnp.sqrt(jnp.mean(per_segment))