Skip to content

Commit

Permalink
Port GARCH11 to v4 (#6119)
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao authored Sep 14, 2022
1 parent c53cd2f commit 91cbebd
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 90 deletions.
169 changes: 127 additions & 42 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from aeppl.abstract import _get_measurable_outputs
from aeppl.logprob import _logprob
from aesara import scan
from aesara.graph import FunctionGraph, rewrite_graph
from aesara.graph.basic import Node, clone_replace
from aesara.raise_op import Assert
Expand Down Expand Up @@ -230,7 +229,7 @@ def random_walk_moment(op, rv, init_dist, innovation_dist, steps):

@_logprob.register(RandomWalkRV)
def random_walk_logp(op, values, *inputs, **kwargs):
# ALthough Aeppl can derive the logprob of random walks, it does not collapse
# Although Aeppl can derive the logprob of random walks, it does not collapse
# what PyMC considers the core dimension of steps. We do it manually here.
(value,) = values
# Recreate RV and obtain inner graph
Expand Down Expand Up @@ -309,7 +308,6 @@ def get_dists(cls, *, mu, sigma, init_dist, **kwargs):
class AutoRegressiveRV(SymbolicRandomVariable):
"""A placeholder used to specify a log-likelihood for an AR sub-graph."""

_print_name = ("AR", "\\operatorname{AR}")
default_output = 1
ar_order: int
constant_term: bool
Expand Down Expand Up @@ -616,17 +614,29 @@ def ar_moment(op, rv, rhos, sigma, init_dist, steps, noise_rng):
return at.full_like(rv, moment(init_dist)[..., -1, None])


class GARCH11(distribution.Continuous):
class GARCH11RV(SymbolicRandomVariable):
"""A placeholder used to specify a GARCH11 graph."""

default_output = 1
_print_name = ("GARCH11", "\\operatorname{GARCH11}")

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


class GARCH11(Distribution):
r"""
GARCH(1,1) with Normal innovations. The model is specified by
.. math::
y_t = \sigma_t * z_t
y_t \sim N(0, \sigma_t^2)
.. math::
\sigma_t^2 = \omega + \alpha_1 * y_{t-1}^2 + \beta_1 * \sigma_{t-1}^2
with z_t iid and Normal with mean zero and unit standard deviation.
where \sigma_t^2 (the error variance) follows a ARMA(1, 1) model.
Parameters
----------
Expand All @@ -640,54 +650,129 @@ class GARCH11(distribution.Continuous):
initial_vol >= 0, initial volatility, sigma_0
"""

def __new__(cls, *args, **kwargs):
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
rv_type = GARCH11RV

def __new__(cls, *args, steps=None, **kwargs):
steps = get_steps(
steps=steps,
shape=None, # Shape will be checked in `cls.dist`
dims=kwargs.get("dims", None),
observed=kwargs.get("observed", None),
step_shape_offset=1,
)
return super().__new__(cls, *args, steps=steps, **kwargs)

@classmethod
def dist(cls, *args, **kwargs):
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs):
steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=1)
if steps is None:
raise ValueError("Must specify steps or shape parameter")
steps = at.as_tensor_variable(intX(steps), ndim=0)

def __init__(self, omega, alpha_1, beta_1, initial_vol, *args, **kwargs):
super().__init__(*args, **kwargs)
omega = at.as_tensor_variable(omega)
alpha_1 = at.as_tensor_variable(alpha_1)
beta_1 = at.as_tensor_variable(beta_1)
initial_vol = at.as_tensor_variable(initial_vol)

self.omega = omega = at.as_tensor_variable(omega)
self.alpha_1 = alpha_1 = at.as_tensor_variable(alpha_1)
self.beta_1 = beta_1 = at.as_tensor_variable(beta_1)
self.initial_vol = at.as_tensor_variable(initial_vol)
self.mean = at.as_tensor_variable(0.0)
init_dist = Normal.dist(0, initial_vol)
# Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
init_dist = ignore_logprob(init_dist)

return super().dist([omega, alpha_1, beta_1, initial_vol, init_dist, steps], **kwargs)

def get_volatility(self, x):
x = x[:-1]
@classmethod
def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None):
if size is not None:
batch_size = size
else:
# In this case the size of the init_dist depends on the parameters shape
batch_size = at.broadcast_shape(omega, alpha_1, beta_1, initial_vol)
init_dist = change_dist_size(init_dist, batch_size)
# initial_vol = initial_vol * at.ones(batch_size)

def volatility_update(x, vol, w, a, b):
return at.sqrt(w + a * at.square(x) + b * at.square(vol))
# Create OpFromGraph representing random draws from GARCH11 process
# Variables with underscore suffix are dummy inputs into the OpFromGraph
init_ = init_dist.type()
initial_vol_ = initial_vol.type()
omega_ = omega.type()
alpha_1_ = alpha_1.type()
beta_1_ = beta_1.type()
steps_ = steps.type()

vol, _ = scan(
fn=volatility_update,
sequences=[x],
outputs_info=[self.initial_vol],
non_sequences=[self.omega, self.alpha_1, self.beta_1],
noise_rng = aesara.shared(np.random.default_rng())

def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
new_sigma = at.sqrt(
omega + alpha_1 * at.square(prev_y) + beta_1 * at.square(prev_sigma)
)
next_rng, new_y = Normal.dist(mu=0, sigma=new_sigma, rng=rng).owner.outputs
return (new_y, new_sigma), {rng: next_rng}

(y_t, _), innov_updates_ = aesara.scan(
fn=step,
outputs_info=[init_, initial_vol_ * at.ones(batch_size)],
non_sequences=[omega_, alpha_1_, beta_1_, noise_rng],
n_steps=steps_,
strict=True,
)
return at.concatenate([[self.initial_vol], vol])
(noise_next_rng,) = tuple(innov_updates_.values())

def logp(self, x):
"""
Calculate log-probability of GARCH(1, 1) distribution at specified value.
garch11_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle(
tuple(range(1, y_t.ndim)) + (0,)
)

Parameters
----------
x: numeric
Value for which log-probability is calculated.
garch11_op = GARCH11RV(
inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
outputs=[noise_next_rng, garch11_],
ndim_supp=1,
)

Returns
-------
TensorVariable
"""
vol = self.get_volatility(x)
return at.sum(Normal.dist(0.0, sigma=vol).logp(x))
garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
return garch11

def _distr_parameters_for_repr(self):
return ["omega", "alpha_1", "beta_1"]

@_change_dist_size.register(GARCH11RV)
def change_garch11_size(op, dist, new_size, expand=False):

if expand:
old_size = dist.shape[:-1]
new_size = tuple(new_size) + tuple(old_size)

return GARCH11.rv_op(
*dist.owner.inputs[:-1],
size=new_size,
)


@_logprob.register(GARCH11RV)
def garch11_logp(
op, values, omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng, **kwargs
):
(value,) = values
# Move the time axis to the first dimension
value_dimswapped = value.dimshuffle((value.ndim - 1,) + tuple(range(0, value.ndim - 1)))
initial_vol = initial_vol * at.ones_like(value_dimswapped[0])

def volatility_update(x, vol, w, a, b):
return at.sqrt(w + a * at.square(x) + b * at.square(vol))

vol, _ = aesara.scan(
fn=volatility_update,
sequences=[value_dimswapped[:-1]],
outputs_info=[initial_vol],
non_sequences=[omega, alpha_1, beta_1],
strict=True,
)
sigma_t = at.concatenate([[initial_vol], vol])
# Compute and collapse logp across time dimension
innov_logp = at.sum(logp(Normal.dist(0, sigma_t), value_dimswapped), axis=0)
return innov_logp


@_moment.register(GARCH11RV)
def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng):
# GARCH(1,1) mean is zero
return at.zeros_like(rv)


class EulerMaruyama(distribution.Continuous):
Expand Down
1 change: 0 additions & 1 deletion pymc/tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def test_all_distributions_have_moments():

# Distributions that have not been refactored for V4 yet
not_implemented = {
dist_module.timeseries.GARCH11,
dist_module.timeseries.MvGaussianRandomWalk,
dist_module.timeseries.MvStudentTRandomWalk,
dist_module.timeseries.EulerMaruyama,
Expand Down
Loading

0 comments on commit 91cbebd

Please sign in to comment.