#!/usr/bin/env python3
"""
Unified DMA (Detrended Moving Average) Estimator for Long-Range Dependence Analysis.
This module implements the DMA estimator with automatic optimization framework
selection (JAX, Numba, NumPy) for the best performance on the available hardware.
"""
import os
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# Import optimization frameworks
try:
import jax
import jax.numpy as jnp
from jax import jit, vmap
JAX_AVAILABLE = True
except ImportError:
JAX_AVAILABLE = False
try:
import numba
from numba import jit as numba_jit
from numba import prange
NUMBA_AVAILABLE = True
except ImportError:
NUMBA_AVAILABLE = False
from lrdbenchmark.analysis.base_estimator import BaseEstimator
def _ensure_non_interactive_backend() -> None:
    """Switch Matplotlib to the ``Agg`` backend when a GUI backend is active.

    Nothing happens when the ``LRDBENCHMARK_FORCE_INTERACTIVE`` environment
    variable is set to ``1``, ``true`` or ``yes`` (case-insensitive).
    Otherwise, if the current backend name contains one of the GUI toolkit
    markers ``gtk``, ``qt``, ``wx`` or ``tk``, the backend is switched to
    ``Agg``. Any exception raised by the switch is ignored, which leaves the
    current backend in place.
    """
    opt_out = os.environ.get("LRDBENCHMARK_FORCE_INTERACTIVE", "").lower()
    if opt_out in {"1", "true", "yes"}:
        return

    current = plt.get_backend().lower()
    gui_toolkits = ("gtk", "qt", "wx", "tk")
    uses_gui = any(toolkit in current for toolkit in gui_toolkits)
    if not uses_gui:
        return

    try:
        plt.switch_backend("Agg")
    except Exception:
        # Best effort only: a failed switch must not break module import.
        pass
# Executed once at import time, before this module creates any figure.
_ensure_non_interactive_backend()
[docs]
class DMAEstimator(BaseEstimator):
"""
Unified DMA (Detrended Moving Average) Estimator for Long-Range Dependence Analysis.
DMA analyzes the root-mean-square fluctuation of detrended time series data
using moving average detrending to estimate the Hurst parameter.
Features:
- Automatic optimization framework selection (JAX, Numba, NumPy)
- GPU acceleration with JAX when available
- JIT compilation with Numba for CPU optimization
- Graceful fallbacks when optimization frameworks fail
Parameters
----------
min_scale : int, optional (default=4)
Minimum scale (window size) for analysis; must be at least 3
max_scale : int, optional (default=None)
Maximum scale for analysis. If None, uses max(min_scale + 1, data length // 4)
num_scales : int, optional (default=10)
Number of scales to test
use_optimization : str, optional (default='auto')
Optimization framework to use: 'auto', 'jax', 'numba', 'numpy'
"""
[docs]
def __init__(
    self,
    min_scale: Optional[int] = None,
    max_scale: Optional[int] = None,
    num_scales: Optional[int] = None,
    use_optimization: str = "auto",
    *,
    min_window_size: Optional[int] = None,
    max_window_size: Optional[int] = None,
    num_windows: Optional[int] = None,
    window_sizes: Optional[Sequence[int]] = None,
    overlap: bool = True,
):
    """Configure the estimator and pick an optimization framework.

    Parameters
    ----------
    min_scale : int, optional
        Smallest window size. Falls back to 4 when neither this nor
        ``min_window_size`` is supplied.
    max_scale : int, optional
        Largest window size. ``None`` defers the choice to estimation time.
    num_scales : int, optional
        Number of window sizes to generate. Falls back to 10.
    use_optimization : str
        Requested framework: ``'auto'``, ``'jax'``, ``'numba'`` or ``'numpy'``.
    min_window_size, max_window_size, num_windows : int, optional
        Legacy aliases; when given they override ``min_scale``,
        ``max_scale`` and ``num_scales`` respectively.
    window_sizes : sequence of int, optional
        Explicit window sizes. When non-empty, its first/last element and
        its length overwrite the min/max/count parameters.
    overlap : bool
        Stored as the ``"overlap"`` parameter and passed to the fluctuation
        calculation.

    Raises
    ------
    ValueError
        If ``window_sizes`` contains non-positive values or the resulting
        parameters fail ``_validate_parameters``.
    """
    # NOTE(review): the base initializer is invoked here and again below with
    # the resolved parameters; confirm whether this first call is required.
    super().__init__()

    # Legacy keyword-only aliases take precedence whenever they are supplied.
    min_scale = min_scale if min_window_size is None else min_window_size
    max_scale = max_scale if max_window_size is None else max_window_size
    num_scales = num_scales if num_windows is None else num_windows

    lower = 4 if min_scale is None else int(min_scale)
    count = 10 if num_scales is None else int(num_scales)

    explicit = (
        None if window_sizes is None else self._sanitize_window_sizes(window_sizes)
    )
    upper = None if max_scale is None else int(max_scale)

    # Every setting is stored under both its current and its legacy name.
    settings = {
        "min_scale": lower,
        "max_scale": upper,
        "num_scales": count,
        "min_window_size": lower,
        "max_window_size": upper,
        "num_windows": count,
        "window_sizes": None if explicit is None else explicit.tolist(),
        "overlap": bool(overlap),
    }

    # Explicit window sizes dictate the range and count parameters.
    if explicit is not None and len(explicit) > 0:
        smallest = int(explicit[0])
        largest = int(explicit[-1])
        how_many = len(explicit)
        settings.update(
            min_scale=smallest,
            min_window_size=smallest,
            max_scale=largest,
            max_window_size=largest,
            num_scales=how_many,
            num_windows=how_many,
        )

    super().__init__(**settings)
    self.parameters = settings

    self.optimization_framework = self._select_optimization_framework(
        use_optimization
    )
    self.results = {}
    self._validate_parameters()
[docs]
def _select_optimization_framework(self, use_optimization: str) -> str:
"""Select the optimal optimization framework."""
if use_optimization == "auto":
if JAX_AVAILABLE:
return "jax" # Best for GPU acceleration
elif NUMBA_AVAILABLE:
return "numba" # Good for CPU optimization
else:
return "numpy" # Fallback
elif use_optimization == "jax" and JAX_AVAILABLE:
return "jax"
elif use_optimization == "numba" and NUMBA_AVAILABLE:
return "numba"
else:
return "numpy"
[docs]
def _validate_parameters(self) -> None:
"""Validate estimator parameters."""
if self.parameters["window_sizes"] is not None:
windows = np.asarray(self.parameters["window_sizes"], dtype=int)
if np.any(windows < 3):
raise ValueError("All window sizes must be at least 3")
if not np.all(np.diff(windows) > 0):
raise ValueError("Window sizes must be in ascending order")
if len(windows) < 3:
raise ValueError("Need at least 3 window sizes")
if self.parameters["min_scale"] < 3:
raise ValueError("min_window_size must be at least 3")
max_scale = self.parameters["max_scale"]
if max_scale is not None and max_scale <= self.parameters["min_scale"]:
raise ValueError("max_window_size must be greater than min_window_size")
if (
self.parameters["num_scales"] < 3
and self.parameters["window_sizes"] is None
):
raise ValueError("Need at least 3 window sizes")
def _sanitize_window_sizes(self, window_sizes: Sequence[int]) -> np.ndarray:
windows = np.array(window_sizes, dtype=int)
if np.any(windows <= 0):
raise ValueError("Window sizes must be positive integers")
return windows
def _resolve_scales(self, n: int) -> np.ndarray:
if self.parameters["window_sizes"] is not None:
windows = np.asarray(self.parameters["window_sizes"], dtype=int)
else:
max_scale = self.parameters["max_scale"]
if max_scale is None:
max_scale = max(self.parameters["min_scale"] + 1, n // 4)
windows = np.logspace(
np.log10(self.parameters["min_scale"]),
np.log10(max_scale),
self.parameters["num_scales"],
dtype=int,
)
windows = np.unique(windows)
valid = windows[(windows >= self.parameters["min_scale"]) & (windows <= n // 2)]
if len(valid) < 3:
raise ValueError("Need at least 3 window sizes")
return valid
def _confidence_interval(
self,
hurst: float,
std_err: float,
sample_size: int,
confidence_level: float = 0.95,
) -> List[float]:
if not np.isfinite(std_err) or std_err <= 0 or sample_size < 3:
return [float("nan"), float("nan")]
alpha = 1 - confidence_level
dof = max(sample_size - 2, 1)
critical = stats.t.ppf(1 - alpha / 2, dof)
margin = critical * std_err
return [float(hurst - margin), float(hurst + margin)]
[docs]
def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]:
"""
Estimate the Hurst parameter using DMA with automatic optimization.
Parameters
----------
data : array-like
Input time series data
Returns
-------
dict
Dictionary containing estimation results including:
- hurst_parameter: Estimated Hurst parameter
- r_squared: R-squared value of the fit
- scales: Scales used in the analysis
- fluctuation_values: Fluctuation values for each scale
- log_scales: Log of scales
- log_fluctuations: Log of fluctuation values
"""
data = np.asarray(data)
n = len(data)
if n < 10:
raise ValueError("Data length must be at least 10")
# Select optimal method based on data size and framework
if self.optimization_framework == "jax" and JAX_AVAILABLE:
try:
return self._estimate_jax(data)
except Exception as e:
if self._should_suppress_fallback_warning(e):
return self._estimate_numpy(data)
warnings.warn(f"JAX implementation failed: {e}, falling back to NumPy")
return self._estimate_numpy(data)
elif self.optimization_framework == "numba" and NUMBA_AVAILABLE:
try:
return self._estimate_numba(data)
except Exception as e:
if self._should_suppress_fallback_warning(e):
return self._estimate_numpy(data)
warnings.warn(
f"Numba implementation failed: {e}, falling back to NumPy"
)
return self._estimate_numpy(data)
else:
return self._estimate_numpy(data)
[docs]
def _estimate_numpy(self, data: np.ndarray) -> Dict[str, Any]:
"""NumPy implementation of DMA estimation."""
n = len(data)
# Set max scale if not provided
scales = self._resolve_scales(n)
if len(scales) < 3:
raise ValueError("Need at least 3 window sizes")
# Calculate fluctuation values for each scale
fluctuation_values = []
for scale in scales:
fluct_val = self._calculate_fluctuation_numpy(
data, scale, overlap=self.parameters["overlap"]
)
fluctuation_values.append(fluct_val)
fluctuation_values = np.array(fluctuation_values)
# Filter out invalid values
valid_mask = (fluctuation_values > 0) & ~np.isnan(fluctuation_values)
if np.sum(valid_mask) < 3:
raise ValueError("Need at least 3 window sizes")
valid_scales = scales[valid_mask]
valid_fluctuations = fluctuation_values[valid_mask]
# Log-log regression
log_scales = np.log(valid_scales)
log_fluctuations = np.log(valid_fluctuations)
# Linear regression
slope, intercept, r_value, p_value, std_err = stats.linregress(
log_scales, log_fluctuations
)
# Calculate R-squared
r_squared = r_value**2
# Hurst parameter is the slope
hurst_parameter = slope
confidence_interval = self._confidence_interval(
hurst_parameter, std_err, len(log_scales)
)
self.results = {
"hurst_parameter": float(hurst_parameter),
"r_squared": float(r_squared),
"slope": float(slope),
"intercept": float(intercept),
"p_value": float(p_value),
"std_error": float(std_err),
"scales": valid_scales.tolist(),
"window_sizes": valid_scales.tolist(),
"fluctuation_values": valid_fluctuations.tolist(),
"log_scales": log_scales.tolist(),
"log_fluctuations": log_fluctuations.tolist(),
"confidence_interval": confidence_interval,
"method": "numpy",
"optimization_framework": "numpy",
}
return self.results
[docs]
def _estimate_numba(self, data: np.ndarray) -> Dict[str, Any]:
"""Numba-optimized implementation of DMA estimation."""
result = self._estimate_numpy(data)
result["method"] = "numba"
result["optimization_framework"] = "numba"
return result
[docs]
def _estimate_jax(self, data: np.ndarray) -> Dict[str, Any]:
    """Run the DMA estimation with the JAX fluctuation kernel.

    The mean-removed cumulative sum of the series is computed once with
    NumPy and handed to ``_dma_fluctuation_jax`` for every resolved scale.
    The regression and confidence interval are computed on the host with
    SciPy. Falls back to ``_estimate_numpy`` when JAX is not available.

    Raises
    ------
    ValueError
        If fewer than 3 valid fluctuation values are obtained.
    """
    if not JAX_AVAILABLE:
        return self._estimate_numpy(data)

    series = np.asarray(data, dtype=float)
    # NOTE(review): float64 is requested explicitly; JAX presumably honours
    # it only when 64-bit mode is enabled - confirm the runtime configuration.
    profile = jnp.asarray(np.cumsum(series - np.mean(series)), dtype=jnp.float64)

    scales = self._resolve_scales(len(data))
    use_overlap = bool(self.parameters.get("overlap", True))

    fluctuations = np.asarray(
        [
            float(_dma_fluctuation_jax(profile, int(scale), use_overlap))
            for scale in scales
        ],
        dtype=float,
    )

    # Only strictly positive, non-NaN fluctuations can be log-transformed.
    usable = (fluctuations > 0) & ~np.isnan(fluctuations)
    if np.sum(usable) < 3:
        raise ValueError("Insufficient valid fluctuation values for analysis")
    kept_scales = scales[usable]
    kept_fluctuations = fluctuations[usable]

    log_scales = np.log(kept_scales)
    log_fluctuations = np.log(kept_fluctuations)
    fit = stats.linregress(log_scales, log_fluctuations)

    interval = self._confidence_interval(
        float(fit.slope),
        float(fit.stderr),
        len(log_scales),
    )

    self.results = {
        "hurst_parameter": float(fit.slope),
        "r_squared": float(fit.rvalue**2),
        "slope": float(fit.slope),
        "intercept": float(fit.intercept),
        "p_value": float(fit.pvalue),
        "std_error": float(fit.stderr),
        "scales": kept_scales.tolist(),
        "window_sizes": kept_scales.tolist(),
        "fluctuation_values": kept_fluctuations.tolist(),
        "log_scales": log_scales.tolist(),
        "log_fluctuations": log_fluctuations.tolist(),
        "confidence_interval": interval,
        "method": "jax",
        "optimization_framework": self.optimization_framework,
    }
    return self.results
[docs]
def _calculate_fluctuation_numpy(
self, data: np.ndarray, scale: int, overlap: bool
) -> float:
"""Calculate fluctuation value for a given scale using NumPy."""
n = len(data)
if scale >= n:
return np.nan
# Step 1: Calculate cumulative sum (integration) - this is the key fix!
y = np.cumsum(data - np.mean(data))
if overlap:
# Moving average detrending on cumulative sum
moving_avg = np.convolve(y, np.ones(scale) / scale, mode="valid")
detrended = y[scale - 1 :] - moving_avg
return float(np.sqrt(np.mean(detrended**2)))
# Non-overlapping case
n_segments = n // scale
if n_segments == 0:
return np.nan
trimmed = y[: n_segments * scale]
segments = trimmed.reshape(n_segments, scale)
# Detrend each segment
detrended_segments = []
for segment in segments:
x = np.arange(scale)
# Linear detrending
coeffs = np.polyfit(x, segment, 1)
trend = np.polyval(coeffs, x)
detrended = segment - trend
detrended_segments.append(detrended)
detrended_segments = np.array(detrended_segments)
fluctuation = np.sqrt(np.mean(detrended_segments**2))
return float(fluctuation)
[docs]
def get_optimization_info(self) -> Dict[str, Any]:
    """Report the selected framework and which optional backends imported.

    Returns
    -------
    dict
        ``current_framework``, ``jax_available``, ``numba_available`` and
        ``recommended_framework``.
    """
    info: Dict[str, Any] = {}
    info["current_framework"] = self.optimization_framework
    info["jax_available"] = JAX_AVAILABLE
    info["numba_available"] = NUMBA_AVAILABLE
    info["recommended_framework"] = self._get_recommended_framework()
    return info
[docs]
def get_confidence_intervals(
self, confidence_level: float = 0.95
) -> Dict[str, Tuple[float, float]]:
if not self.results:
raise ValueError("No estimation results available")
ci = self._confidence_interval(
self.results["hurst_parameter"],
self.results["std_error"],
len(self.results["scales"]),
confidence_level,
)
return {"hurst_parameter": tuple(ci)}
[docs]
def get_estimation_quality(self) -> Dict[str, Any]:
if not self.results:
raise ValueError("No estimation results available")
return {
"r_squared": self.results["r_squared"],
"p_value": self.results["p_value"],
"std_error": self.results["std_error"],
"n_windows": len(self.results["scales"]),
}
[docs]
def plot_scaling(self, **kwargs) -> None:
if not self.results:
raise ValueError("No estimation results available")
self.plot_analysis(**kwargs)
[docs]
def _calculate_fluctuation(
self, data: Union[np.ndarray, list], window_size: int
) -> float:
"""Backward-compatible helper for direct fluctuation calculation."""
return float(
self._calculate_fluctuation_numpy(
np.asarray(data), int(window_size), self.parameters.get("overlap", True)
)
)
[docs]
def _get_recommended_framework(self) -> str:
    """Return the preferred framework among those that imported successfully.

    JAX is preferred over Numba, and NumPy is the fallback when neither is
    available.
    """
    if JAX_AVAILABLE:
        return "jax"
    return "numba" if NUMBA_AVAILABLE else "numpy"
[docs]
@staticmethod
def _should_suppress_fallback_warning(error: Exception) -> bool:
"""Return True when a fallback is expected and shouldn't raise a warning."""
message = str(error).lower()
suppressed_fragments = (
"need at least 3 window sizes",
"insufficient valid",
"insufficient valid fluctuation values",
)
return any(fragment in message for fragment in suppressed_fragments)
[docs]
def plot_analysis(
    self, figsize: Tuple[int, int] = (12, 8), save_path: Optional[str] = None
) -> None:
    """Draw a 2x2 summary figure of the stored DMA results.

    The panels show the log-log points with the fitted line, the raw
    fluctuation-versus-scale points on log axes, the Hurst estimate against
    the H=0.5 reference, and the R-squared of the fit.

    Parameters
    ----------
    figsize : tuple of int
        Figure size passed to ``plt.subplots``.
    save_path : str, optional
        When given, the figure is saved there at 300 dpi.

    Raises
    ------
    ValueError
        If no estimation has been stored in ``self.results``.
    """
    stored = self.results
    if not stored:
        raise ValueError("No estimation results available")

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle("DMA Analysis Results", fontsize=16)
    (ax_fit, ax_raw), (ax_hurst, ax_quality) = axes

    # Panel 1: log-log points with the regression line.
    log_x = stored["log_scales"]
    log_y = stored["log_fluctuations"]
    ax_fit.scatter(log_x, log_y, s=60, alpha=0.7, label="Data points")
    slope = stored["slope"]
    intercept = stored["intercept"]
    line_x = np.linspace(min(log_x), max(log_x), 100)
    line_y = slope * line_x + intercept
    ax_fit.plot(line_x, line_y, "r--", label=f"Linear fit (slope={slope:.3f})")
    ax_fit.set_xlabel("log(Scale)")
    ax_fit.set_ylabel("log(Fluctuation)")
    ax_fit.set_title("DMA Scaling")
    ax_fit.legend()
    ax_fit.grid(True, alpha=0.3)

    # Panel 2: raw fluctuation against scale on logarithmic axes.
    ax_raw.scatter(stored["scales"], stored["fluctuation_values"], s=60, alpha=0.7)
    ax_raw.set_xscale("log")
    ax_raw.set_yscale("log")
    ax_raw.set_xlabel("Scale")
    ax_raw.set_ylabel("Fluctuation")
    ax_raw.set_title("Fluctuation vs Scale (log-log)")
    ax_raw.grid(True, which="both", ls=":", alpha=0.3)

    # Panel 3: Hurst estimate with the H=0.5 reference line.
    hurst = stored["hurst_parameter"]
    ax_hurst.bar(["Hurst Parameter"], [hurst], alpha=0.7, color="skyblue")
    ax_hurst.axhline(
        y=0.5, color="red", linestyle="--", alpha=0.7, label="H=0.5 (no memory)"
    )
    ax_hurst.set_ylabel("Hurst Parameter")
    ax_hurst.set_title(f"Hurst Parameter Estimate: {hurst:.3f}")
    ax_hurst.legend()
    ax_hurst.grid(True, alpha=0.3)

    # Panel 4: goodness of fit.
    r_squared = stored["r_squared"]
    ax_quality.bar(["R²"], [r_squared], alpha=0.7, color="lightgreen")
    ax_quality.set_ylabel("R²")
    ax_quality.set_title(f"Goodness of Fit: R² = {r_squared:.3f}")
    ax_quality.set_ylim(0, 1)
    ax_quality.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    # Show the figure only when an interactive backend is active; otherwise
    # close it so it does not stay open.
    backend_name = plt.get_backend().lower()
    gui_markers = ("qt", "gtk", "wx", "tk", "nbagg", "webagg")
    wants_display = plt.isinteractive() or any(
        marker in backend_name for marker in gui_markers
    )
    if wants_display:
        plt.show()
    else:
        plt.close(fig)
[docs]
def get_method_recommendation(self, n: int) -> Dict[str, Any]:
"""Get method recommendation for a given data size."""
if n < 100:
return {
"recommended_method": "numpy",
"reasoning": f"Data size n={n} is too small for optimization benefits",
"method_details": {
"description": "NumPy implementation",
"best_for": "Small datasets (n < 100)",
"complexity": "O(n²)",
"memory": "O(n)",
"accuracy": "Medium",
},
}
elif n < 1000:
return {
"recommended_method": "numba",
"reasoning": f"Data size n={n} benefits from JIT compilation",
"method_details": {
"description": "Numba JIT-compiled implementation",
"best_for": "Medium datasets (100 ≤ n < 1000)",
"complexity": "O(n²)",
"memory": "O(n)",
"accuracy": "High",
},
}
else:
return {
"recommended_method": "jax",
"reasoning": f"Data size n={n} benefits from GPU acceleration",
"method_details": {
"description": "JAX GPU-accelerated implementation",
"best_for": "Large datasets (n ≥ 1000)",
"complexity": "O(n²)",
"memory": "O(n)",
"accuracy": "High",
},
}
if JAX_AVAILABLE:
from functools import partial
@partial(jit, static_argnums=(1, 2))
def _dma_fluctuation_jax(y: jnp.ndarray, scale: int, overlap: bool) -> jnp.ndarray:
"""JAX implementation of DMA fluctuation for a fixed scale."""
if overlap:
kernel = jnp.ones(scale, dtype=y.dtype) / scale
moving_avg = jnp.convolve(y, kernel, mode="valid")
detrended = y[scale - 1 :] - moving_avg
return jnp.sqrt(jnp.mean(detrended**2))
n_segments = y.shape[0] // scale
if n_segments == 0:
return jnp.nan
trimmed = y[: n_segments * scale]
segments = trimmed.reshape((n_segments, scale))
x = jnp.arange(scale, dtype=y.dtype)
x_mean = jnp.mean(x)
denom = jnp.sum((x - x_mean) ** 2)
def segment_variance(segment: jnp.ndarray) -> jnp.ndarray:
seg_mean = jnp.mean(segment)
slope = jnp.sum((x - x_mean) * (segment - seg_mean)) / denom
intercept = seg_mean - slope * x_mean
detrended = segment - (slope * x + intercept)
return jnp.mean(detrended**2)
variances = vmap(segment_variance)(segments)
return jnp.sqrt(jnp.mean(variances))