Source code for lrdbenchmark.analysis.spectral.gph_estimator

#!/usr/bin/env python3
"""
Unified Geweke-Porter-Hudak (GPH) Hurst parameter estimator.

This module implements the GPH estimator with automatic optimization framework
selection (JAX, Numba, NumPy) for the best performance on the available hardware.
"""

import warnings
from typing import Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from scipy import signal, stats

from lrdbenchmark.analysis.backend_utils import (
    JAX_AVAILABLE,
    NUMBA_AVAILABLE,
    select_backend,
)

# Import optimization frameworks
if JAX_AVAILABLE:
    import jax
    import jax.numpy as jnp
    from jax import jit, vmap
if NUMBA_AVAILABLE:
    import numba
    from numba import jit as numba_jit
    from numba import prange

# Import base estimator
try:
    from lrdbenchmark.analysis.base_estimator import BaseEstimator
except ImportError:
    from lrdbenchmark.analysis.base_estimator import BaseEstimator


class GPHEstimator(BaseEstimator):
    """
    Unified Geweke-Porter-Hudak (GPH) Hurst parameter estimator.

    This estimator uses log-periodogram regression with the regressor
    log(4*sin^2(ω/2)) to estimate the fractional differencing parameter d,
    then converts to Hurst parameter as H = d + 0.5.

    Features:
    - Automatic optimization framework selection (JAX, Numba, NumPy)
    - GPU acceleration with JAX when available
    - JIT compilation with Numba for CPU optimization
    - Graceful fallbacks when optimization frameworks fail

    Parameters
    ----------
    min_freq_ratio : float, optional (default=0.01)
        Minimum frequency ratio (relative to Nyquist) for fitting.
    max_freq_ratio : float, optional (default=0.1)
        Maximum frequency ratio (relative to Nyquist) for fitting.
    apply_bias_correction : bool, optional (default=True)
        Whether to apply bias correction for finite sample effects.
    use_optimization : str, optional (default='auto')
        Optimization framework to use: 'auto', 'jax', 'numba', 'numpy'

    Notes
    -----
    The NumPy and JAX code paths of this class pick the regression
    bandwidth adaptively (``m = int(n**alpha)``) and do not read
    ``min_freq_ratio`` / ``max_freq_ratio``; the two ratios are only
    forwarded to the Numba backend.
    """
def __init__(
    self,
    min_freq_ratio: float = 0.01,
    max_freq_ratio: float = 0.1,
    apply_bias_correction: bool = True,
    use_optimization: str = "auto",
):
    """
    Set up the estimator configuration and resolve the compute backend.

    Parameters
    ----------
    min_freq_ratio : float, optional (default=0.01)
        Stored under ``parameters["min_freq_ratio"]``.
    max_freq_ratio : float, optional (default=0.1)
        Stored under ``parameters["max_freq_ratio"]``.
    apply_bias_correction : bool, optional (default=True)
        Stored under ``parameters["apply_bias_correction"]``.
    use_optimization : str, optional (default='auto')
        Backend request handed to ``select_backend``; the value it returns
        is kept in ``optimization_framework``.
    """
    super().__init__()

    # Configuration consulted by the estimation routines.
    self.parameters = dict(
        min_freq_ratio=min_freq_ratio,
        max_freq_ratio=max_freq_ratio,
        apply_bias_correction=apply_bias_correction,
    )

    # Backend resolved once, at construction time.
    self.optimization_framework = select_backend(use_optimization)

    # Starts empty; an estimation run fills it in.
    self.results = {}

    self._validate_parameters()
[docs] def _validate_parameters(self) -> None: """Validate estimator parameters.""" pass
def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]:
    """
    Estimate the Hurst parameter with the GPH method.

    The call is routed to the backend stored in
    ``self.optimization_framework``. If the JAX or Numba routine raises,
    a warning is issued and the NumPy routine is used instead.

    Parameters
    ----------
    data : array-like
        Time series data.

    Returns
    -------
    dict
        Dictionary containing estimation results including:
        - hurst_parameter: Estimated Hurst parameter
        - d_parameter: Estimated fractional differencing parameter
        - intercept: Intercept of the linear fit
        - r_squared: R-squared value of the fit
        - m: Number of frequency points used in fitting
        - log_regressor: Log regressor values
        - log_periodogram: Log periodogram values
    """
    series = np.asarray(data)
    if len(series) < 100:
        warnings.warn("Data length is small, results may be unreliable")

    framework = self.optimization_framework
    if framework not in ("jax", "numba"):
        # Anything else is served by the plain NumPy routine.
        return self._estimate_numpy(series)

    try:
        accelerated = (
            self._estimate_jax if framework == "jax" else self._estimate_numba
        )
        return accelerated(series)
    except Exception as e:
        # Best-effort acceleration: report the failure, then fall back.
        if framework == "jax":
            warnings.warn(f"JAX implementation failed: {e}, falling back to NumPy")
        else:
            warnings.warn(
                f"Numba implementation failed: {e}, falling back to NumPy"
            )
        return self._estimate_numpy(series)
[docs] def _estimate_numpy(self, data: np.ndarray) -> Dict[str, Any]: """NumPy implementation of GPH estimation.""" n = len(data) # Coarse preliminary estimate to guide bandwidth coarse_m = int(n**0.8) freqs, psd = signal.periodogram(data, scaling="density") coarse_freqs = freqs[1 : coarse_m + 1] coarse_psd = psd[1 : coarse_m + 1] if len(coarse_freqs) > 3: omega_coarse = 2 * np.pi * coarse_freqs regressor_coarse = np.log(4 * np.sin(omega_coarse / 2) ** 2) log_psd_coarse = np.log(coarse_psd) slope_coarse, _, _, _, _ = stats.linregress( regressor_coarse, log_psd_coarse ) h_coarse = -slope_coarse + 0.5 else: h_coarse = 0.5 # Refined bandwidth based on coarse estimate alpha = 0.5 - 0.2 * abs(h_coarse - 0.5) # Heuristic adjustment m = int(n**alpha) # Select frequency range for fitting freqs_sel = freqs[1 : m + 1] psd_sel = psd[1 : m + 1] if len(freqs_sel) < 3: raise ValueError("Insufficient frequency points for fitting") # Filter out zero/negative PSD values valid_mask = psd_sel > 0 freqs_sel = freqs_sel[valid_mask] psd_sel = psd_sel[valid_mask] if len(freqs_sel) < 3: raise ValueError("Insufficient valid PSD points for fitting") # Convert to angular frequencies omega = 2 * np.pi * freqs_sel # GPH regressor: log(4*sin^2(ω/2)) regressor = np.log(4 * np.sin(omega / 2) ** 2) log_periodogram = np.log(psd_sel) # Linear regression slope, intercept, r_value, p_value, std_err = stats.linregress( regressor, log_periodogram ) d_parameter = -slope # d = -slope # Apply bias correction if requested if self.parameters["apply_bias_correction"]: m = len(freqs_sel) # Simple bias correction for finite sample effects bias_correction = 0.5 * np.log(m) / m d_parameter += bias_correction # Convert to Hurst parameter: H = d + 0.5 hurst = d_parameter + 0.5 # Ensure Hurst parameter is in valid range hurst = np.clip(hurst, 0.01, 0.99) self.results = { "hurst_parameter": float(hurst), "d_parameter": float(d_parameter), "intercept": float(intercept), "slope": float(slope), "r_squared": 
float(r_value**2), "p_value": float(p_value), "std_error": float(std_err), "m": int(len(freqs_sel)), "log_regressor": regressor, "log_periodogram": log_periodogram, "frequency": freqs_sel, "periodogram": psd_sel, "method": "numpy", "optimization_framework": self.optimization_framework, } return self.results
[docs] def _estimate_jax(self, data: np.ndarray) -> Dict[str, Any]: """JAX-optimized implementation of GPH estimation.""" # JAX implementation is hybrid: uses NumPy/SciPy for PSD and JAX for regression. n = len(data) freqs, psd = signal.periodogram(data, scaling="density") # Coarse preliminary estimate to guide bandwidth (NumPy portion) coarse_m = int(n**0.8) coarse_freqs = freqs[1 : coarse_m + 1] coarse_psd = psd[1 : coarse_m + 1] if len(coarse_freqs) > 3: omega_coarse = 2 * np.pi * coarse_freqs regressor_coarse = np.log(4 * np.sin(omega_coarse / 2) ** 2) log_psd_coarse = np.log(coarse_psd) slope_coarse, _, _, _, _ = stats.linregress( regressor_coarse, log_psd_coarse ) h_coarse = -slope_coarse + 0.5 else: h_coarse = 0.5 # Refined bandwidth based on coarse estimate alpha = 0.5 - 0.2 * abs(h_coarse - 0.5) # Heuristic adjustment m = int(n**alpha) # Select frequency range for fitting freqs_sel = freqs[1 : m + 1] psd_sel = psd[1 : m + 1] if len(freqs_sel) < 3: raise ValueError("Insufficient frequency points for fitting") # Filter out zero/negative PSD values valid_mask = psd_sel > 0 freqs_sel = freqs_sel[valid_mask] psd_sel = psd_sel[valid_mask] if len(freqs_sel) < 3: raise ValueError("Insufficient valid PSD points for fitting") # Convert to JAX arrays for computation freqs_jax = jnp.array(freqs_sel) psd_jax = jnp.array(psd_sel) # Convert to angular frequencies omega = 2 * jnp.pi * freqs_jax # GPH regressor: log(4*sin^2(ω/2)) regressor = jnp.log(4 * jnp.sin(omega / 2) ** 2) log_periodogram = jnp.log(psd_jax) # JAX linear regression (simplified) # For production use, consider using jax.scipy.stats.linregress x_mean = jnp.mean(regressor) y_mean = jnp.mean(log_periodogram) numerator = jnp.sum((regressor - x_mean) * (log_periodogram - y_mean)) denominator = jnp.sum((regressor - x_mean) ** 2) slope = numerator / denominator intercept = y_mean - slope * x_mean # Calculate R-squared y_pred = slope * regressor + intercept ss_res = jnp.sum((log_periodogram - y_pred) ** 2) 
ss_tot = jnp.sum((log_periodogram - y_mean) ** 2) r_squared = 1 - (ss_res / ss_tot) d_parameter = -float(slope) # d = -slope # Apply bias correction if requested if self.parameters["apply_bias_correction"]: m = len(freqs_sel) bias_correction = 0.5 * jnp.log(m) / m d_parameter += float(bias_correction) # Convert to Hurst parameter: H = d + 0.5 hurst = d_parameter + 0.5 # Ensure Hurst parameter is in valid range hurst = np.clip(hurst, 0.01, 0.99) self.results = { "hurst_parameter": float(hurst), "d_parameter": float(d_parameter), "intercept": float(intercept), "slope": float(slope), "r_squared": float(r_squared), "p_value": None, # Not computed in JAX version "std_error": None, # Not computed in JAX version "m": int(len(freqs_sel)), "log_regressor": np.array(regressor), "log_periodogram": np.array(log_periodogram), "frequency": np.array(freqs_sel), "periodogram": np.array(psd_sel), "method": "jax", "optimization_framework": self.optimization_framework, } return self.results
[docs] def _estimate_numba(self, data: np.ndarray) -> Dict[str, Any]: """Numba-optimized implementation of GPH estimation.""" try: from lrdbenchmark.analysis.high_performance.numba.gph_numba import ( GPHEstimatorNumba, ) estimator = GPHEstimatorNumba( min_freq_ratio=self.parameters["min_freq_ratio"], max_freq_ratio=self.parameters["max_freq_ratio"], apply_bias_correction=self.parameters["apply_bias_correction"], ) result = estimator.estimate(data) # Add method info result["method"] = "numba" result["optimization_framework"] = self.optimization_framework return result except ImportError: raise ImportError("Numba implementation not available") except Exception as e: raise RuntimeError(f"Numba estimation failed: {e}")
def get_optimization_info(self) -> Dict[str, Any]:
    """
    Get information about available optimizations and current selection.

    Returns
    -------
    dict
        - current_framework: value of ``self.optimization_framework``
        - jax_available: the module-level ``JAX_AVAILABLE`` flag
        - numba_available: the module-level ``NUMBA_AVAILABLE`` flag
        - recommended_framework: result of
          ``self._get_recommended_framework()`` when the instance provides
          such a method; otherwise ``"jax"`` if JAX is available, else
          ``"numba"`` if Numba is available, else ``"numpy"``.
    """
    # The recommender hook is looked up defensively so that a missing
    # method does not turn this informational call into an AttributeError.
    recommender = getattr(self, "_get_recommended_framework", None)
    if callable(recommender):
        recommended = recommender()
    elif JAX_AVAILABLE:
        recommended = "jax"
    elif NUMBA_AVAILABLE:
        recommended = "numba"
    else:
        recommended = "numpy"

    return {
        "current_framework": self.optimization_framework,
        "jax_available": JAX_AVAILABLE,
        "numba_available": NUMBA_AVAILABLE,
        "recommended_framework": recommended,
    }
def plot_scaling(self, save_path: Optional[str] = None) -> None:
    """
    Draw a three-panel figure from the stored estimation results.

    Panels, left to right: the GPH regression scatter with its fitted
    line, the same points on log-log axes, and the selected periodogram
    against frequency.

    Parameters
    ----------
    save_path : str, optional
        When given (and non-empty), the figure is written to this path at
        300 dpi before being shown.

    Raises
    ------
    ValueError
        If ``self.results`` is empty.
    """
    if not self.results:
        raise ValueError("No results available. Run estimate() first.")

    stored = self.results
    plt.figure(figsize=(15, 4))

    # Panel 1: regression points and fitted line
    fit_ax = plt.subplot(1, 3, 1)
    x = stored["log_regressor"]
    y = stored["log_periodogram"]
    fit_ax.scatter(x, y, s=40, alpha=0.7, label="Data points")

    slope = stored["slope"]
    intercept = stored["intercept"]
    x_fit = np.linspace(min(x), max(x), 100)
    fit_ax.plot(x_fit, slope * x_fit + intercept, "r--", label="Linear fit")
    fit_ax.set_xlabel("log(4 sin²(ω/2))")
    fit_ax.set_ylabel("log(Periodogram)")
    fit_ax.set_title("GPH Regression")
    fit_ax.legend()
    fit_ax.grid(True, alpha=0.3)

    # Panel 2: the same quantities, undone from log space, on log axes
    loglog_ax = plt.subplot(1, 3, 2)
    loglog_ax.scatter(np.exp(x), np.exp(y), s=30, alpha=0.7)
    loglog_ax.set_xscale("log")
    loglog_ax.set_yscale("log")
    loglog_ax.set_xlabel("Regressor")
    loglog_ax.set_ylabel("Periodogram")
    loglog_ax.set_title("GPH Components (log-log)")
    loglog_ax.grid(True, which="both", ls=":", alpha=0.3)

    # Panel 3: plain periodogram view for context
    psd_ax = plt.subplot(1, 3, 3)
    psd_ax.plot(stored["frequency"], stored["periodogram"], alpha=0.7)
    psd_ax.set_xlabel("Frequency")
    psd_ax.set_ylabel("Periodogram")
    psd_ax.set_title("Power Spectral Density")
    psd_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
def get_method_recommendation(self, n: int) -> Dict[str, Any]:
    """
    Suggest an implementation for a series of length ``n``.

    Parameters
    ----------
    n : int
        Data size.

    Returns
    -------
    dict
        ``recommended_method`` is ``"numpy"`` for ``n < 100``, ``"numba"``
        for ``100 <= n < 1000`` and ``"jax"`` otherwise; ``reasoning`` is
        a one-sentence justification and ``method_details`` describes the
        chosen implementation.
    """

    def _details(description: str, best_for: str) -> Dict[str, Any]:
        # Only the first two entries differ between the tiers.
        return {
            "description": description,
            "best_for": best_for,
            "complexity": "O(n log n)",
            "memory": "O(n)",
            "accuracy": "High",
        }

    if n < 100:
        return {
            "recommended_method": "numpy",
            "reasoning": f"Data size n={n} is too small for optimization benefits",
            "method_details": _details(
                "NumPy implementation", "Small datasets (n < 100)"
            ),
        }

    if n < 1000:
        return {
            "recommended_method": "numba",
            "reasoning": f"Data size n={n} benefits from JIT compilation",
            "method_details": _details(
                "Numba JIT-compiled implementation",
                "Medium datasets (100 ≤ n < 1000)",
            ),
        }

    return {
        "recommended_method": "jax",
        "reasoning": f"Data size n={n} benefits from GPU acceleration",
        "method_details": _details(
            "JAX GPU-accelerated implementation", "Large datasets (n ≥ 1000)"
        ),
    }