Source code for lrdbenchmark.analysis.temporal.dfa_estimator

#!/usr/bin/env python3
"""
Unified DFA (Detrended Fluctuation Analysis) Estimator.
Refactored to use modular backends (NumPy, JAX, Numba).
"""

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

from lrdbenchmark.analysis.backend_utils import (
    JAX_AVAILABLE,
    NUMBA_AVAILABLE,
    select_backend,
)
from lrdbenchmark.analysis.base_estimator import BaseEstimator

from .dfa_backends import numpy_backend

# Optional backends
try:
    from .dfa_backends import jax_backend
except ImportError:
    jax_backend = None

try:
    from .dfa_backends import numba_backend
except ImportError:
    numba_backend = None


[docs] class DFAEstimator(BaseEstimator): """ Unified DFA Estimator with modular backend support. **Input convention.** Pass a **stationary increment series** (e.g. fractional Gaussian noise or first differences of fractional Brownian motion). For such processes the integrated profile scales as :math:`F(s) \\sim s^{H}` in the mid-scale range, and the ``hurst_parameter`` field is the slope of :math:`\\log F(s)` versus :math:`\\log s` (i.e. the DFA scaling exponent, identified with :math:`H` for fGn-like inputs). Feeding **fBm levels** without differencing does not satisfy the same scaling law and can mislead the estimate. Backends: - 'jax': GPU/TPU accelerated (if available) - 'numba': CPU JIT compiled - 'numpy': Reference implementation """
[docs] def __init__( self, min_scale: int = 10, max_scale: Optional[int] = None, num_scales: int = 10, order: int = 1, use_optimization: str = "auto", ): super().__init__() self.parameters = { "min_scale": min_scale, "max_scale": max_scale, "num_scales": num_scales, "order": order, } self.optimization_framework = select_backend(use_optimization) self.results = {} self._validate_parameters()
def _validate_parameters(self) -> None: if self.parameters["min_scale"] < 4: raise ValueError("min_scale must be at least 4") if ( self.parameters["max_scale"] is not None and self.parameters["max_scale"] <= self.parameters["min_scale"] ): raise ValueError("max_scale must be greater than min_scale") if self.parameters["num_scales"] < 3: raise ValueError("num_scales must be at least 3") if self.parameters["order"] < 0: raise ValueError("order must be non-negative")
[docs] def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]: """ Estimate Hurst parameter using DFA. Delegates calculation to the selected backend. The returned ``hurst_parameter`` is the log-log slope of the fluctuation function vs. scale; see the class docstring for the recommended input (increments / fGn vs. fBm levels). """ data = np.asarray(data) n = len(data) if n < 100: warnings.warn("Data length is small, results may be unreliable") # Set max_scale if None max_scale = self.parameters["max_scale"] if max_scale is None: max_scale = n // 4 # Generate scales scales = np.logspace( np.log10(self.parameters["min_scale"]), np.log10(max_scale), self.parameters["num_scales"], dtype=int, ) scales = np.unique(scales) scales = scales[scales <= n // 2] if len(scales) < 3: raise ValueError("Insufficient valid scales for analysis") # Select Backend Strategy backend_name = self.optimization_framework compute_func = self._get_compute_function(backend_name) # Execute Computation try: fluctuation_values = compute_func(data, scales, self.parameters["order"]) except Exception as e: warnings.warn( f"Backend '{backend_name}' failed: {e}. Falling back to NumPy." ) fluctuation_values = numpy_backend.compute_fluctuations( data, scales, self.parameters["order"] ) backend_name = "numpy (fallback)" # Post-Processing (Regression) # Ensure outputs are NumPy arrays (JAX might return jnp array on host) fluctuation_values = np.asarray(fluctuation_values) scales = np.asarray(scales) valid_mask = (fluctuation_values > 0) & ~np.isnan(fluctuation_values) if np.sum(valid_mask) < 3: raise ValueError("Insufficient valid fluctuation values for analysis") valid_scales = scales[valid_mask] valid_fluctuations = fluctuation_values[valid_mask] log_scales = np.log(valid_scales) log_fluctuations = np.log(valid_fluctuations) slope, intercept, r_value, p_value, std_err = stats.linregress( log_scales, log_fluctuations ) self.results = { "hurst_parameter": float(slope), "r_squared": float(r_value**2), "slope": float(slope), "intercept": float(intercept), "p_value": float(p_value), "std_error": float(std_err), "scales": valid_scales.tolist(), "fluctuation_values": valid_fluctuations.tolist(), "log_scales": log_scales.tolist(), "log_fluctuations": log_fluctuations.tolist(), "method": backend_name, "optimization_framework": self.optimization_framework, } return self.results
[docs] def _get_compute_function(self, backend: str): """Factory method for backend strategy.""" if backend == "jax": if jax_backend and jax_backend.JAX_AVAILABLE: return jax_backend.compute_fluctuations warnings.warn("JAX requested but not available. Falling back to NumPy.") return numpy_backend.compute_fluctuations if backend == "numba": if numba_backend and numba_backend.NUMBA_AVAILABLE: return numba_backend.compute_fluctuations warnings.warn("Numba requested but not available. Falling back to NumPy.") return numpy_backend.compute_fluctuations return numpy_backend.compute_fluctuations
[docs] def get_optimization_info(self) -> Dict[str, Any]: return { "current_framework": self.optimization_framework, "jax_available": getattr(jax_backend, "JAX_AVAILABLE", False), "numba_available": getattr(numba_backend, "NUMBA_AVAILABLE", False), "recommended_framework": self._get_recommended_framework(), }
def _get_recommended_framework(self) -> str: if getattr(jax_backend, "JAX_AVAILABLE", False): return "jax" elif getattr(numba_backend, "NUMBA_AVAILABLE", False): return "numba" else: return "numpy"
[docs] def plot_analysis( self, figsize: Tuple[int, int] = (12, 8), save_path: Optional[str] = None ) -> None: """Plot the DFA analysis results.""" if not self.results: raise ValueError("No results available. Run estimate() first.") fig, axes = plt.subplots(2, 2, figsize=figsize) fig.suptitle("DFA Analysis Results", fontsize=16) # Plot 1: Log-log relationship ax1 = axes[0, 0] x = self.results["log_scales"] y = self.results["log_fluctuations"] ax1.scatter(x, y, s=60, alpha=0.7, label="Data points") slope = self.results["slope"] intercept = self.results["intercept"] x_fit = np.linspace(min(x), max(x), 100) y_fit = slope * x_fit + intercept ax1.plot(x_fit, y_fit, "r--", label=f"Linear fit (slope={slope:.3f})") ax1.set_xlabel("log(Scale)") ax1.set_ylabel("log(Fluctuation)") ax1.set_title("DFA Scaling") ax1.legend() ax1.grid(True, alpha=0.3) # Plot 2: Fluctuation vs Scale (log-log) ax2 = axes[0, 1] scales = self.results["scales"] fluctuations = self.results["fluctuation_values"] ax2.scatter(scales, fluctuations, s=60, alpha=0.7) ax2.set_xscale("log") ax2.set_yscale("log") ax2.set_xlabel("Scale") ax2.set_ylabel("Fluctuation") ax2.set_title("Fluctuation vs Scale (log-log)") ax2.grid(True, which="both", ls=":", alpha=0.3) # Plot 3: Hurst parameter estimate ax3 = axes[1, 0] hurst = self.results["hurst_parameter"] ax3.bar(["Hurst Parameter"], [hurst], alpha=0.7, color="skyblue") ax3.axhline( y=0.5, color="red", linestyle="--", alpha=0.7, label="H=0.5 (no memory)" ) ax3.set_ylabel("Hurst Parameter") ax3.set_title(f"Hurst Parameter Estimate: {hurst:.3f}") ax3.legend() ax3.grid(True, alpha=0.3) # Plot 4: R-squared ax4 = axes[1, 1] r_squared = self.results["r_squared"] ax4.bar(["R²"], [r_squared], alpha=0.7, color="lightgreen") ax4.set_ylabel("R²") ax4.set_title(f"Goodness of Fit: R² = {r_squared:.3f}") ax4.set_ylim(0, 1) ax4.grid(True, alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.show()
[docs] def get_method_recommendation(self, n: int) -> Dict[str, Any]: """Get method recommendation for a given data size.""" # Kept for compatibility / completeness if n < 100: return {"recommended_method": "numpy", "reasoning": "Small size"} elif n < 1000: return { "recommended_method": "numba", "reasoning": "Medium size benefits from JIT", } else: return { "recommended_method": "jax", "reasoning": "Large size benefits from GPU", }