#!/usr/bin/env python3
"""
Unified Geweke-Porter-Hudak (GPH) Hurst parameter estimator.
This module implements the GPH estimator with automatic optimization framework
selection (JAX, Numba, NumPy) for the best performance on the available hardware.
"""
import warnings
from typing import Any, Dict, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal, stats
from lrdbenchmark.analysis.backend_utils import (
JAX_AVAILABLE,
NUMBA_AVAILABLE,
select_backend,
)
# Import optimization frameworks
if JAX_AVAILABLE:
import jax
import jax.numpy as jnp
from jax import jit, vmap
if NUMBA_AVAILABLE:
import numba
from numba import jit as numba_jit
from numba import prange
# Import base estimator
try:
from lrdbenchmark.analysis.base_estimator import BaseEstimator
except ImportError:
from lrdbenchmark.analysis.base_estimator import BaseEstimator
[docs]
class GPHEstimator(BaseEstimator):
"""
Unified Geweke-Porter-Hudak (GPH) Hurst parameter estimator.
This estimator uses log-periodogram regression with the regressor
log(4*sin^2(ω/2)) to estimate the fractional differencing parameter d,
then converts to Hurst parameter as H = d + 0.5.
Features:
- Automatic optimization framework selection (JAX, Numba, NumPy)
- GPU acceleration with JAX when available
- JIT compilation with Numba for CPU optimization
- Graceful fallbacks when optimization frameworks fail
Parameters
----------
min_freq_ratio : float, optional (default=0.01)
Minimum frequency ratio (relative to Nyquist) for fitting.
max_freq_ratio : float, optional (default=0.1)
Maximum frequency ratio (relative to Nyquist) for fitting.
apply_bias_correction : bool, optional (default=True)
Whether to apply bias correction for finite sample effects.
use_optimization : str, optional (default='auto')
Optimization framework to use: 'auto', 'jax', 'numba', 'numpy'
"""
[docs]
def __init__(
self,
min_freq_ratio: float = 0.01,
max_freq_ratio: float = 0.1,
apply_bias_correction: bool = True,
use_optimization: str = "auto",
):
super().__init__()
# Estimator parameters
self.parameters = {
"min_freq_ratio": min_freq_ratio,
"max_freq_ratio": max_freq_ratio,
"apply_bias_correction": apply_bias_correction,
}
# Optimization framework
self.optimization_framework = select_backend(use_optimization)
# Results storage
self.results = {}
# Validation
self._validate_parameters()
[docs]
def _validate_parameters(self) -> None:
"""Validate estimator parameters."""
pass
[docs]
def estimate(self, data: Union[np.ndarray, list]) -> Dict[str, Any]:
"""
Estimate Hurst parameter using GPH method with automatic optimization.
Parameters
----------
data : array-like
Time series data.
Returns
-------
dict
Dictionary containing estimation results including:
- hurst_parameter: Estimated Hurst parameter
- d_parameter: Estimated fractional differencing parameter
- intercept: Intercept of the linear fit
- r_squared: R-squared value of the fit
- m: Number of frequency points used in fitting
- log_regressor: Log regressor values
- log_periodogram: Log periodogram values
"""
data = np.asarray(data)
n = len(data)
if n < 100:
warnings.warn("Data length is small, results may be unreliable")
# Select optimal method based on data size and framework
backend = self.optimization_framework
if backend == "jax":
try:
return self._estimate_jax(data)
except Exception as e:
warnings.warn(f"JAX implementation failed: {e}, falling back to NumPy")
return self._estimate_numpy(data)
elif backend == "numba":
try:
return self._estimate_numba(data)
except Exception as e:
warnings.warn(
f"Numba implementation failed: {e}, falling back to NumPy"
)
return self._estimate_numpy(data)
else: # numpy
return self._estimate_numpy(data)
[docs]
def _estimate_numpy(self, data: np.ndarray) -> Dict[str, Any]:
"""NumPy implementation of GPH estimation."""
n = len(data)
# Coarse preliminary estimate to guide bandwidth
coarse_m = int(n**0.8)
freqs, psd = signal.periodogram(data, scaling="density")
coarse_freqs = freqs[1 : coarse_m + 1]
coarse_psd = psd[1 : coarse_m + 1]
if len(coarse_freqs) > 3:
omega_coarse = 2 * np.pi * coarse_freqs
regressor_coarse = np.log(4 * np.sin(omega_coarse / 2) ** 2)
log_psd_coarse = np.log(coarse_psd)
slope_coarse, _, _, _, _ = stats.linregress(
regressor_coarse, log_psd_coarse
)
h_coarse = -slope_coarse + 0.5
else:
h_coarse = 0.5
# Refined bandwidth based on coarse estimate
alpha = 0.5 - 0.2 * abs(h_coarse - 0.5) # Heuristic adjustment
m = int(n**alpha)
# Select frequency range for fitting
freqs_sel = freqs[1 : m + 1]
psd_sel = psd[1 : m + 1]
if len(freqs_sel) < 3:
raise ValueError("Insufficient frequency points for fitting")
# Filter out zero/negative PSD values
valid_mask = psd_sel > 0
freqs_sel = freqs_sel[valid_mask]
psd_sel = psd_sel[valid_mask]
if len(freqs_sel) < 3:
raise ValueError("Insufficient valid PSD points for fitting")
# Convert to angular frequencies
omega = 2 * np.pi * freqs_sel
# GPH regressor: log(4*sin^2(ω/2))
regressor = np.log(4 * np.sin(omega / 2) ** 2)
log_periodogram = np.log(psd_sel)
# Linear regression
slope, intercept, r_value, p_value, std_err = stats.linregress(
regressor, log_periodogram
)
d_parameter = -slope # d = -slope
# Apply bias correction if requested
if self.parameters["apply_bias_correction"]:
m = len(freqs_sel)
# Simple bias correction for finite sample effects
bias_correction = 0.5 * np.log(m) / m
d_parameter += bias_correction
# Convert to Hurst parameter: H = d + 0.5
hurst = d_parameter + 0.5
# Ensure Hurst parameter is in valid range
hurst = np.clip(hurst, 0.01, 0.99)
self.results = {
"hurst_parameter": float(hurst),
"d_parameter": float(d_parameter),
"intercept": float(intercept),
"slope": float(slope),
"r_squared": float(r_value**2),
"p_value": float(p_value),
"std_error": float(std_err),
"m": int(len(freqs_sel)),
"log_regressor": regressor,
"log_periodogram": log_periodogram,
"frequency": freqs_sel,
"periodogram": psd_sel,
"method": "numpy",
"optimization_framework": self.optimization_framework,
}
return self.results
[docs]
def _estimate_jax(self, data: np.ndarray) -> Dict[str, Any]:
"""JAX-optimized implementation of GPH estimation."""
# JAX implementation is hybrid: uses NumPy/SciPy for PSD and JAX for regression.
n = len(data)
freqs, psd = signal.periodogram(data, scaling="density")
# Coarse preliminary estimate to guide bandwidth (NumPy portion)
coarse_m = int(n**0.8)
coarse_freqs = freqs[1 : coarse_m + 1]
coarse_psd = psd[1 : coarse_m + 1]
if len(coarse_freqs) > 3:
omega_coarse = 2 * np.pi * coarse_freqs
regressor_coarse = np.log(4 * np.sin(omega_coarse / 2) ** 2)
log_psd_coarse = np.log(coarse_psd)
slope_coarse, _, _, _, _ = stats.linregress(
regressor_coarse, log_psd_coarse
)
h_coarse = -slope_coarse + 0.5
else:
h_coarse = 0.5
# Refined bandwidth based on coarse estimate
alpha = 0.5 - 0.2 * abs(h_coarse - 0.5) # Heuristic adjustment
m = int(n**alpha)
# Select frequency range for fitting
freqs_sel = freqs[1 : m + 1]
psd_sel = psd[1 : m + 1]
if len(freqs_sel) < 3:
raise ValueError("Insufficient frequency points for fitting")
# Filter out zero/negative PSD values
valid_mask = psd_sel > 0
freqs_sel = freqs_sel[valid_mask]
psd_sel = psd_sel[valid_mask]
if len(freqs_sel) < 3:
raise ValueError("Insufficient valid PSD points for fitting")
# Convert to JAX arrays for computation
freqs_jax = jnp.array(freqs_sel)
psd_jax = jnp.array(psd_sel)
# Convert to angular frequencies
omega = 2 * jnp.pi * freqs_jax
# GPH regressor: log(4*sin^2(ω/2))
regressor = jnp.log(4 * jnp.sin(omega / 2) ** 2)
log_periodogram = jnp.log(psd_jax)
# JAX linear regression (simplified)
# For production use, consider using jax.scipy.stats.linregress
x_mean = jnp.mean(regressor)
y_mean = jnp.mean(log_periodogram)
numerator = jnp.sum((regressor - x_mean) * (log_periodogram - y_mean))
denominator = jnp.sum((regressor - x_mean) ** 2)
slope = numerator / denominator
intercept = y_mean - slope * x_mean
# Calculate R-squared
y_pred = slope * regressor + intercept
ss_res = jnp.sum((log_periodogram - y_pred) ** 2)
ss_tot = jnp.sum((log_periodogram - y_mean) ** 2)
r_squared = 1 - (ss_res / ss_tot)
d_parameter = -float(slope) # d = -slope
# Apply bias correction if requested
if self.parameters["apply_bias_correction"]:
m = len(freqs_sel)
bias_correction = 0.5 * jnp.log(m) / m
d_parameter += float(bias_correction)
# Convert to Hurst parameter: H = d + 0.5
hurst = d_parameter + 0.5
# Ensure Hurst parameter is in valid range
hurst = np.clip(hurst, 0.01, 0.99)
self.results = {
"hurst_parameter": float(hurst),
"d_parameter": float(d_parameter),
"intercept": float(intercept),
"slope": float(slope),
"r_squared": float(r_squared),
"p_value": None, # Not computed in JAX version
"std_error": None, # Not computed in JAX version
"m": int(len(freqs_sel)),
"log_regressor": np.array(regressor),
"log_periodogram": np.array(log_periodogram),
"frequency": np.array(freqs_sel),
"periodogram": np.array(psd_sel),
"method": "jax",
"optimization_framework": self.optimization_framework,
}
return self.results
[docs]
def _estimate_numba(self, data: np.ndarray) -> Dict[str, Any]:
"""Numba-optimized implementation of GPH estimation."""
try:
from lrdbenchmark.analysis.high_performance.numba.gph_numba import (
GPHEstimatorNumba,
)
estimator = GPHEstimatorNumba(
min_freq_ratio=self.parameters["min_freq_ratio"],
max_freq_ratio=self.parameters["max_freq_ratio"],
apply_bias_correction=self.parameters["apply_bias_correction"],
)
result = estimator.estimate(data)
# Add method info
result["method"] = "numba"
result["optimization_framework"] = self.optimization_framework
return result
except ImportError:
raise ImportError("Numba implementation not available")
except Exception as e:
raise RuntimeError(f"Numba estimation failed: {e}")
[docs]
def get_optimization_info(self) -> Dict[str, Any]:
"""Get information about available optimizations and current selection."""
return {
"current_framework": self.optimization_framework,
"jax_available": JAX_AVAILABLE,
"numba_available": NUMBA_AVAILABLE,
"recommended_framework": self._get_recommended_framework(),
}
[docs]
def _get_recommended_framework(self) -> str:
"""Get the recommended optimization framework."""
if JAX_AVAILABLE:
return "jax" # Best for GPU acceleration
elif NUMBA_AVAILABLE:
return "numba" # Good for CPU optimization
else:
return "numpy" # Fallback
[docs]
def plot_scaling(self, save_path: Optional[str] = None) -> None:
"""Plot the scaling relationship and PSD."""
if not self.results:
raise ValueError("No results available. Run estimate() first.")
plt.figure(figsize=(15, 4))
# Log-log scaling relationship
plt.subplot(1, 3, 1)
x = self.results["log_regressor"]
y = self.results["log_periodogram"]
plt.scatter(x, y, s=40, alpha=0.7, label="Data points")
# Plot fitted line
slope = self.results["slope"]
intercept = self.results["intercept"]
x_fit = np.linspace(min(x), max(x), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, "r--", label="Linear fit")
plt.xlabel("log(4 sin²(ω/2))")
plt.ylabel("log(Periodogram)")
plt.title("GPH Regression")
plt.legend()
plt.grid(True, alpha=0.3)
# Log-log components
plt.subplot(1, 3, 2)
plt.scatter(np.exp(x), np.exp(y), s=30, alpha=0.7)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Regressor")
plt.ylabel("Periodogram")
plt.title("GPH Components (log-log)")
plt.grid(True, which="both", ls=":", alpha=0.3)
# Plain PSD view for context
plt.subplot(1, 3, 3)
plt.plot(self.results["frequency"], self.results["periodogram"], alpha=0.7)
plt.xlabel("Frequency")
plt.ylabel("Periodogram")
plt.title("Power Spectral Density")
plt.grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches="tight")
plt.show()
[docs]
def get_method_recommendation(self, n: int) -> Dict[str, Any]:
"""Get method recommendation for a given data size."""
if n < 100:
return {
"recommended_method": "numpy",
"reasoning": f"Data size n={n} is too small for optimization benefits",
"method_details": {
"description": "NumPy implementation",
"best_for": "Small datasets (n < 100)",
"complexity": "O(n log n)",
"memory": "O(n)",
"accuracy": "High",
},
}
elif n < 1000:
return {
"recommended_method": "numba",
"reasoning": f"Data size n={n} benefits from JIT compilation",
"method_details": {
"description": "Numba JIT-compiled implementation",
"best_for": "Medium datasets (100 ≤ n < 1000)",
"complexity": "O(n log n)",
"memory": "O(n)",
"accuracy": "High",
},
}
else:
return {
"recommended_method": "jax",
"reasoning": f"Data size n={n} benefits from GPU acceleration",
"method_details": {
"description": "JAX GPU-accelerated implementation",
"best_for": "Large datasets (n ≥ 1000)",
"complexity": "O(n log n)",
"memory": "O(n)",
"accuracy": "High",
},
}