Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ unconstrain_fn
--------------
.. autofunction:: numpyro.infer.util.unconstrain_fn

get_log_density_fn
------------------
.. autofunction:: numpyro.infer.util.get_log_density_fn

LogDensityInfo
--------------
.. autoclass:: numpyro.infer.util.LogDensityInfo
:members:
:show-inheritance:

potential_energy
----------------
.. autofunction:: numpyro.infer.util.potential_energy
Expand Down
3,023 changes: 1,832 additions & 1,191 deletions notebooks/source/other_samplers.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,12 @@
"""A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays."""


PositionDict: TypeAlias = dict[str, jax.Array]
"""An unconstrained position dict keyed by sample-site name.

Used as the canonical input/output type for log-density and postprocess
callables exposed to external samplers (see
:class:`~numpyro.infer.LogDensityInfo`)."""


NumLikeT = TypeVar("NumLikeT", bound=NumLike)
11 changes: 10 additions & 1 deletion numpyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,28 @@
from numpyro.infer.mixed_hmc import MixedHMC
from numpyro.infer.sa import SA
from numpyro.infer.svi import SVI
from numpyro.infer.util import Predictive, log_likelihood
from numpyro.infer.util import (
LogDensityInfo,
Predictive,
get_log_density_fn,
initialize_model,
log_likelihood,
)

from . import autoguide, calibration, reparam

__all__ = [
"AIES",
"autoguide",
"calibration",
"get_log_density_fn",
"init_to_feasible",
"init_to_mean",
"init_to_median",
"init_to_sample",
"init_to_uniform",
"init_to_value",
"initialize_model",
"log_likelihood",
"psis_diagnostic",
"reparam",
Expand All @@ -51,6 +59,7 @@
"HMC",
"HMCECS",
"HMCGibbs",
"LogDensityInfo",
"MCMC",
"MixedHMC",
"NUTS",
Expand Down
178 changes: 151 additions & 27 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Sequence
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional
from typing import Any, Callable, NamedTuple, Optional
import warnings

import numpy as np
Expand All @@ -18,7 +18,7 @@

import numpyro
from numpyro import distributions as dist
from numpyro._typing import TraceT
from numpyro._typing import PositionDict, TraceT
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.util import is_identically_one, sum_rightmost
Expand All @@ -36,11 +36,13 @@

__all__ = [
"find_valid_initial_params",
"get_log_density_fn",
"get_potential_fn",
"log_density",
"log_likelihood",
"potential_energy",
"initialize_model",
"LogDensityInfo",
"Predictive",
]

Expand All @@ -50,6 +52,27 @@
ParamInfo = namedtuple("ParamInfo", ["z", "potential_energy", "z_grad"])


class LogDensityInfo(NamedTuple):
"""Return value of :func:`get_log_density_fn`.

The callable fields (``logdensity_fn``, ``postprocess_fn``) have their model
arguments pre-bound; the caller does not need to repeat ``model_args`` /
``model_kwargs`` or remember to negate.

:ivar logdensity_fn: negated potential energy (a log joint density).
:ivar init_position: unconstrained initial values from
:func:`find_valid_initial_params`.
:ivar postprocess_fn: single-position transform back to constrained space,
with ``deterministic`` sites included.
:ivar model_info: underlying :class:`ModelInfo` for power users.
"""

logdensity_fn: Callable[[PositionDict], jax.Array]
init_position: PositionDict
postprocess_fn: Callable[[PositionDict], PositionDict]
model_info: ModelInfo


class _substitute_default_key(Messenger):
def process_message(self, msg):
if msg["type"] == "prng_key" and msg["value"] is None:
Expand Down Expand Up @@ -191,44 +214,74 @@ def transform_fn(transforms, params, invert=False):
return {k: transforms[k](v) if k in transforms else v for k, v in params.items()}


def constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False):
def constrain_fn(
model,
model_args,
model_kwargs,
params,
return_deterministic=False,
batch_ndims=0,
):
"""
(EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given
unconstrained parameters `params`. The `transforms` is used to transform these
unconstrained parameters to base values of the corresponding priors in `model`.
If a prior is a transformed distribution, the corresponding base value lies in
the support of base distribution. Otherwise, the base value lies in the support
of the distribution.
unconstrained parameters `params`. Each unconstrained value is pushed through
the inverse bijection of the corresponding prior's support to recover the
constrained value. If a prior is a transformed distribution, the corresponding
base value lies in the support of the base distribution. Otherwise, the base
value lies in the support of the distribution.

``batch_ndims`` declares how many leading sample dimensions each leaf of
``params`` carries, so the transforms are ``jax.vmap``-ed the correct number
of times. The common layouts are: ``batch_ndims=0`` (a single unconstrained
position, the default), ``batch_ndims=1`` (a single chain of samples), and
``batch_ndims=2`` (``num_chains x num_samples``, matching
:meth:`MCMC.get_samples(group_by_chain=True)
<numpyro.infer.MCMC.get_samples>`). This is useful to map a batch of
unconstrained samples produced by an external sampler back to the
constrained space.

:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of unconstrained values keyed by site
names.
names. Leading dimensions are batch dimensions (see ``batch_ndims``).
:param bool return_deterministic: whether to return the value of `deterministic`
sites from the model. Defaults to `False`.
:param int batch_ndims: number of leading batch dimensions on each leaf of
``params``. Defaults to ``0`` (a single position).
:return: `dict` of transformed params.
"""
if batch_ndims < 0:
raise ValueError(
f"batch_ndims must be a non-negative integer, got {batch_ndims}."
)

def substitute_fn(site):
if site["name"] in params:
if site["type"] == "sample":
with helpful_support_errors(site):
return biject_to(site["fn"].support)(params[site["name"]])
elif site["type"] == "param":
constraint = site["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(site):
return biject_to(constraint)(params[site["name"]])
else:
return params[site["name"]]
def single(position):
def substitute_fn(site):
if site["name"] in position:
if site["type"] == "sample":
with helpful_support_errors(site):
return biject_to(site["fn"].support)(position[site["name"]])
elif site["type"] == "param":
constraint = site["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(site):
return biject_to(constraint)(position[site["name"]])
else:
return position[site["name"]]

substituted_model = substitute(model, substitute_fn=substitute_fn)
model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
return {
k: v["value"]
for k, v in model_trace.items()
if (k in params) or (return_deterministic and (v["type"] == "deterministic"))
}
substituted_model = substitute(model, substitute_fn=substitute_fn)
model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
return {
k: v["value"]
for k, v in model_trace.items()
if (k in position)
or (return_deterministic and (v["type"] == "deterministic"))
}

fn = single
for _ in range(batch_ndims):
fn = jax.vmap(fn)
return fn(params)


def get_transforms(model, model_args, model_kwargs, params):
Expand Down Expand Up @@ -801,6 +854,77 @@ def initialize_model(
)


def get_log_density_fn(
rng_key: jax.Array,
model: Callable[..., Any],
*,
model_args: tuple[Any, ...] = (),
model_kwargs: Optional[dict[str, Any]] = None,
init_strategy: Callable[..., Any] = init_to_uniform,
forward_mode_differentiation: bool = False,
validate_grad: bool = True,
) -> LogDensityInfo:
"""
(EXPERIMENTAL INTERFACE) Convenience entry point that wraps
:func:`initialize_model` for use with external samplers (e.g. ``blackjax``,
``flowMC``).

The returned :class:`LogDensityInfo` carries a ``logdensity_fn`` that is
already the *negated* potential energy (i.e. a log joint density that
external samplers can maximize), an unconstrained ``init_position``, and a
single-position ``postprocess_fn`` callable with ``model_args`` /
``model_kwargs`` already bound; the caller does not have to repeat them or
remember to negate by hand.

:param jax.Array rng_key: PRNG key used to sample the initial position
from the prior.
:param Callable model: a Python callable containing NumPyro primitives.
:param tuple model_args: positional arguments passed to ``model``.
:param dict model_kwargs: keyword arguments passed to ``model``.
:param Callable init_strategy: a per-site initialization function. See
:ref:`init_strategy` section for available functions.
:param bool forward_mode_differentiation: whether to use forward-mode
differentiation when validating initial parameters.
:param bool validate_grad: whether to validate gradients of the initial
params.
:returns: a :class:`LogDensityInfo`.
:rtype: LogDensityInfo

**Example**::

info = get_log_density_fn(rng_key, model, model_args=(x, y))
kernel = blackjax.nuts(info.logdensity_fn, step_size, inverse_mass_matrix)
state = kernel.init(info.init_position)
"""
# Defensive copy: callers may mutate their kwargs dict after this call;
# the returned closures must not observe those mutations.
model_kwargs = dict(model_kwargs) if model_kwargs else {}
model_info = initialize_model(
rng_key,
model,
init_strategy=init_strategy,
dynamic_args=False,
model_args=model_args,
model_kwargs=model_kwargs,
forward_mode_differentiation=forward_mode_differentiation,
validate_grad=validate_grad,
)
# With `dynamic_args=False`, `potential_fn` is a single-position callable
# `(position) -> potential_energy`. Expose its negation as a log joint
# density that external samplers (e.g. ``blackjax``, ``flowMC``) maximize.
potential_fn = model_info.potential_fn

def logdensity_fn(position: PositionDict) -> jax.Array:
return -potential_fn(position)

return LogDensityInfo(
Comment thread
juanitorduz marked this conversation as resolved.
logdensity_fn=logdensity_fn,
init_position=model_info.param_info.z,
postprocess_fn=model_info.postprocess_fn,
model_info=model_info,
)


def _predictive(
rng_key,
model,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ module = [
"numpyro.diagnostics.*",
"numpyro.handlers.*",
"numpyro.infer.elbo.*",
"numpyro.infer.util.*",
"numpyro.optim.*",
"numpyro.primitives.*",
"numpyro.patch.*",
Expand Down
Loading
Loading