From c4ae82ba5c18c1e7063995787a60032eb3395333 Mon Sep 17 00:00:00 2001 From: Nicholas Mancuso Date: Wed, 28 Apr 2021 12:20:31 -0700 Subject: [PATCH 1/7] implemented left, right, and doubly truncated gamma distributions --- numpyro/distributions/__init__.py | 4 + numpyro/distributions/continuous.py | 10 +- numpyro/distributions/truncated.py | 320 +++++++++++++++++++++++++++- numpyro/distributions/util.py | 14 ++ test/test_distributions.py | 3 + 5 files changed, 349 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 1c31f67fc..d3548a974 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -73,6 +73,10 @@ TruncatedNormal, TruncatedPolyaGamma, TwoSidedTruncatedDistribution, + TruncatedGamma, + LeftTruncatedGamma, + RightTruncatedGamma, + TwoSidedTruncatedGamma, ) from . import constraints, transforms diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 469c2748e..273af0a77 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -31,7 +31,7 @@ import jax.numpy as jnp import jax.random as random from jax.scipy.linalg import cho_solve, solve_triangular -from jax.scipy.special import betainc, expit, gammaln, logit, multigammaln, ndtr, ndtri +from jax.scipy.special import betainc, expit, gammainc, gammaln, logit, multigammaln, ndtr, ndtri from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution, TransformedDistribution @@ -276,6 +276,14 @@ def mean(self): def variance(self): return self.concentration / jnp.power(self.rate, 2) + def cdf(self, value): + return gammainc(self.concentration, value * self.rate) + + def icdf(self, q): + # https://github.com/pyro-ppl/numpyro/issues/969 + from numpyro.distributions.util import gammaincinv + return gammaincinv(self.concentration, q) * self.rate + class Chi2(Gamma): arg_constraints = {"df": constraints.positive} diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 1717b155a..063b397a6 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -4,12 +4,13 @@ from jax import lax import jax.numpy as jnp import jax.random as random -from jax.scipy.special import logsumexp +from jax.scipy.special import logsumexp, gammainc from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.continuous import ( Cauchy, + Gamma, Laplace, Logistic, Normal, @@ -402,3 +403,320 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, params): return cls(batch_shape=aux_data) + + +def TruncatedGamma(base_gamma, low=None, high=None, validate_args=None): + """ + A function to generate a truncated gamma distribution. + + :param base_gamma: The base Gamma distribution to be truncated. + :param low: the value which is used to truncate the base distribution from below. + Setting this parameter to None to not truncate from below. + :param high: the value which is used to truncate the base distribution from above. + Setting this parameter to None to not truncate from above. + """ + if high is None: + if low is None: + return base_gamma + else: + return LeftTruncatedGamma(base_gamma, low=low, validate_args=validate_args) + elif low is None: + return RightTruncatedGamma(base_gamma, high=high, validate_args=validate_args) + else: + return TwoSidedTruncatedGamma( + base_gamma, low=low, high=high, validate_args=validate_args + ) + + +class LeftTruncatedGamma(Distribution): + arg_constraints = {"low": constraints.positive} + reparametrized_params = ["low"] + + def __init__(self, base_gamma, low, validate_args=None): + assert isinstance(base_gamma, Gamma) + batch_shape = lax.broadcast_shapes(base_gamma.batch_shape, jnp.shape(low)) + self.base_gamma = tree_map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma + ) + (self.low,) = promote_shapes(low, shape=batch_shape) + self._support = constraints.greater_than(low) + super().__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + lscale = self.base_gamma.cdf(self.low) + q = (1 - u) * lscale + u + return self.base_gamma.icdf(q) + + @validate_sample + def log_prob(self, value): + lprob = self.base_gamma.log_prob(value) + lscale = self.base_gamma.cdf(self.low) + return lprob - jnp.log(1.0 - lscale) + + def _scale_moment(self, t): + assert t > -self.concentration + s_lscale = gammainc(self.base_gamma.concentration + t, self.low * self.rate) + lscale = self.base_gamma.cdf(self.low) + return (1.0 - s_lscale) / (1.0 - lscale) + + @property + def mean(self): + base_mean = self.base_gamma.mean + rescale = self._scale_moment(1.0) + return rescale * base_mean + + @property + def variance(self): + # compute E[X]^2 + fst_m_sq = jnp.pow(self.mean, 2.0) + + # compute E[X^2] + base_sec_mt = ( + (self.concentration + 1) * self.concentration * jnp.pow(self.rate, -2.0) + ) + rescale = self._scale_moment(2.0) + sec_mt = base_sec_mt * rescale + + # V[X] = E[X^2] - E[X]^2 + return sec_mt - fst_m_sq + + def tree_flatten(self): + base_flatten, base_aux = self.base_gamma.tree_flatten() + if isinstance(self._support.lower_bound, (int, float)): + return base_flatten, ( + type(self.base_gamma), + base_aux, + self._support.lower_bound, + ) + else: + return (base_flatten, self.low), (type(self.base_gamma), base_aux) + + @classmethod + def tree_unflatten(cls, aux_data, params): + if len(aux_data) == 2: + base_flatten, low = params + base_cls, base_aux = aux_data + else: + base_flatten = params + base_cls, base_aux, low = aux_data + base_gamma = Gamma.tree_unflatten(base_aux, base_flatten) + return cls(base_gamma, low=low) + + @validate_sample + def cdf(self, value): + gcdf = self.base_gamma.cdf(value) + lscale = self.base_gamma.cdf(self.low) + return (gcdf - lscale) / (1.0 - lscale) + + def icdf(self, q): + lscale = self.base_gamma.cdf(self.low) + q = q * (1.0 - lscale) + lscale + return self.base_gamma.icdf(q) + + +class RightTruncatedGamma(Distribution): + arg_constraints = {"high": constraints.positive} + reparametrized_params = ["high"] + + def __init__(self, base_gamma, high, validate_args=None): + assert isinstance(base_gamma, Gamma) + batch_shape = lax.broadcast_shapes(base_gamma.batch_shape, jnp.shape(high)) + self.base_gamma = tree_map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma + ) + (self.high,) = promote_shapes(high, shape=batch_shape) + self._support = constraints.interval(0.0, high) + super().__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + hscale = self.base_gamma.cdf(self.high) + q = u * hscale + return self.base_gamma.icdf(q) + + @validate_sample + def log_prob(self, value): + lprob = self.base_gamma.log_prob(value) + hscale = self.base_gamma.cdf(self.high) + return lprob - jnp.log(hscale) + + def _scale_moment(self, t): + assert t > -self.concentration + s_hscale = gammainc(self.base_gamma.concentration + t, self.high * self.rate) + hscale = self.base_gamma.cdf(self.high) + return s_hscale / hscale + + @property + def mean(self): + base_mean = self.base_gamma.mean + rescale = self._scale_moment(1.0) + return rescale * base_mean + + @property + def variance(self): + # compute E[X]^2 + fst_m_sq = jnp.pow(self.mean, 2.0) + + # compute E[X^2] + base_sec_mt = ( + (self.concentration + 1) * self.concentration * jnp.pow(self.rate, -2.0) + ) + rescale = self._scale_moment(2.0) + sec_mt = base_sec_mt * rescale + + # V[X] = E[X^2] - E[X]^2 + return sec_mt - fst_m_sq + + def tree_flatten(self): + base_flatten, base_aux = self.base_gamma.tree_flatten() + if isinstance(self._support.upper_bound, (int, float)): + return base_flatten, ( + type(self.base_gamma), + base_aux, + self._support.upper_bound, + ) + else: + return (base_flatten, self.high), (type(self.base_gamma), base_aux) + + @classmethod + def tree_unflatten(cls, aux_data, params): + if len(aux_data) == 2: + base_flatten, high = params + base_cls, base_aux = aux_data + else: + base_flatten = params + base_cls, base_aux, high = aux_data + base_gamma = Gamma.tree_unflatten(base_aux, base_flatten) + return cls(base_gamma, high=high) + + @validate_sample + def cdf(self, value): + gcdf = self.base_gamma.cdf(value) + hscale = self.base_gamma.cdf(self.high) + return gcdf / hscale + + def icdf(self, q): + hscale = self.base_gamma.cdf(self.high) + q = q * hscale + return self.base_gamma.icdf(q) + + +class TwoSidedTruncatedGamma(Distribution): + arg_constraints = { + "low": constraints.positive, + "high": constraints.dependent, + } + reparametrized_params = ["low", "high"] + + def __init__(self, base_gamma, low, high, validate_args=None): + assert isinstance(base_gamma, Gamma) + batch_shape = lax.broadcast_shapes( + base_gamma.batch_shape, jnp.shape(low), jnp.shape(high) + ) + self.base_gamma = tree_map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma + ) + (self.low,) = promote_shapes(low, shape=batch_shape) + (self.high,) = promote_shapes(high, shape=batch_shape) + self._support = constraints.interval(low, high) + super().__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + q = (1 - u) * lscale + u * hscale + return self.base_gamma.icdf(q) + + @validate_sample + def log_prob(self, value): + lprob = self.base_gamma.log_prob(value) + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + return lprob - jnp.log(hscale - lscale) + + def _scale_moment(self, t): + assert t > -self.concentration + s_lscale = gammainc(self.base_gamma.concentration + t, self.low * self.rate) + s_hscale = gammainc(self.base_gamma.concentration + t, self.high * self.rate) + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + return (s_hscale - s_lscale) / (hscale - lscale) + + @property + def mean(self): + base_mean = self.base_gamma.mean + rescale = self._scale_moment(1.0) + return rescale * base_mean + + @property + def variance(self): + # compute E[X]^2 + fst_m_sq = jnp.pow(self.mean, 2.0) + + # compute E[X^2] + base_sec_mt = ( + (self.concentration + 1) * self.concentration * jnp.pow(self.rate, -2.0) + ) + rescale = self._scale_moment(2.0) + sec_mt = base_sec_mt * rescale + + # V[X] = E[X^2] - E[X]^2 + return sec_mt - fst_m_sq + + def tree_flatten(self): + base_flatten, base_aux = self.base_gamma.tree_flatten() + if isinstance(self._support.lower_bound, (int, float)) and isinstance( + self._support.upper_bound, (int, float) + ): + return base_flatten, ( + type(self.base_gamma), + base_aux, + self._support.lower_bound, + self._support.upper_bound, + ) + else: + return (base_flatten, self.low, self.high), ( + type(self.base_gamma), + base_aux, + ) + + @classmethod + def tree_unflatten(cls, aux_data, params): + if len(aux_data) == 2: + base_flatten, low, high = params + base_cls, base_aux = aux_data + else: + base_flatten = params + base_cls, base_aux, low, high = aux_data + base_gamma = Gamma.tree_unflatten(base_aux, base_flatten) + return cls(base_gamma, low=low, high=high) + + @validate_sample + def cdf(self, value): + gcdf = self.base_gamma.cdf(value) + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + return (gcdf - lscale) / (hscale - lscale) + + def icdf(self, q): + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + q = q * (hscale - lscale) + lscale + return self.base_gamma.icdf(q) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 3917b407f..c68907c18 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -560,6 +560,20 @@ def is_prng_key(key): return False +def gammaincinv(a, p): + # until jax/lax has direct implementation we'll need to rely on tfp + # https://github.com/pyro-ppl/numpyro/issues/969 + try: + import tensorflow_probability.math as tfpm + except ImportError as e: + raise ImportError( + "To use gammaincinv, please install TensorFlow Probability. It can be" + " installed with `pip install tensorflow_probability`" + ) from e + + return tfpm.igammacinv(a, p) + + # The is sourced from: torch.distributions.util.py # # Copyright (c) 2016- Facebook, Inc (Adam Paszke) diff --git a/test/test_distributions.py b/test/test_distributions.py index ff5578b9e..ca6a9eca6 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -183,6 +183,7 @@ def get_sp_dist(jax_dist): T(dist.Laplace, 0.0, 1.0), T(dist.Laplace, 0.5, jnp.array([1.0, 2.5])), T(dist.Laplace, jnp.array([1.0, -0.5]), jnp.array([2.3, 3.0])), + T(dist.LeftTruncatedGamma, dist.Gamma(2.0, 2.0), 1.0), T(dist.LKJ, 2, 0.5, "onion"), T(dist.LKJ, 5, jnp.array([0.5, 1.0, 2.0]), "cvine"), T(dist.LKJCholesky, 2, 0.5, "onion"), @@ -257,6 +258,7 @@ def get_sp_dist(jax_dist): T(dist.Pareto, 1.0, 2.0), T(dist.Pareto, jnp.array([1.0, 0.5]), jnp.array([0.3, 2.0])), T(dist.Pareto, jnp.array([[1.0], [3.0]]), jnp.array([1.0, 0.5])), + T(dist.RightTruncatedGamma, dist.Gamma(2.0, 2.0), 10.0), T(dist.SoftLaplace, 1.0, 1.0), T(dist.SoftLaplace, jnp.array([-1.0, 50.0]), jnp.array([4.0, 100.0])), T(dist.StudentT, 1.0, 1.0, 0.5), @@ -290,6 +292,7 @@ def get_sp_dist(jax_dist): jnp.array([-2.0, 2.0]), ), T(dist.TwoSidedTruncatedDistribution, dist.Laplace(0.0, 1.0), -2.0, 3.0), + T(dist.TwoSidedTruncatedGamma, dist.Gamma(2.0, 2.0), 0.5, 10.0), T(dist.Uniform, 0.0, 2.0), T(dist.Uniform, 1.0, jnp.array([2.0, 3.0])), T(dist.Uniform, jnp.array([0.0, 0.0]), jnp.array([[2.0], [3.0]])), From f0c80b748cf6ad33aa14aa9f792a401948d09376 Mon Sep 17 00:00:00 2001 From: Nicholas Mancuso Date: Wed, 28 Apr 2021 15:13:57 -0700 Subject: [PATCH 2/7] implemented left, right, and doubly truncated gamma distributions --- numpyro/distributions/continuous.py | 2 +- numpyro/distributions/truncated.py | 53 ++++++++++++++++------------- numpyro/distributions/util.py | 4 +-- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 273af0a77..e559da495 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -282,7 +282,7 @@ def cdf(self, value): def icdf(self, q): # https://github.com/pyro-ppl/numpyro/issues/969 from numpyro.distributions.util import gammaincinv - return gammaincinv(self.concentration, q) * self.rate + return gammaincinv(self.concentration, q) / self.rate class Chi2(Gamma): diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 063b397a6..4fd6e007d 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -449,9 +449,7 @@ def support(self): def sample(self, key, sample_shape=()): assert is_prng_key(key) u = random.uniform(key, sample_shape + self.batch_shape) - lscale = self.base_gamma.cdf(self.low) - q = (1 - u) * lscale + u - return self.base_gamma.icdf(q) + return self.icdf(u) @validate_sample def log_prob(self, value): @@ -460,8 +458,10 @@ def log_prob(self, value): return lprob - jnp.log(1.0 - lscale) def _scale_moment(self, t): - assert t > -self.concentration - s_lscale = gammainc(self.base_gamma.concentration + t, self.low * self.rate) + assert t > -self.base_gamma.concentration + s_lscale = gammainc( + self.base_gamma.concentration + t, self.low * self.base_gamma.rate + ) lscale = self.base_gamma.cdf(self.low) return (1.0 - s_lscale) / (1.0 - lscale) @@ -474,11 +474,13 @@ def mean(self): @property def variance(self): # compute E[X]^2 - fst_m_sq = jnp.pow(self.mean, 2.0) + fst_m_sq = jnp.power(self.mean, 2.0) # compute E[X^2] base_sec_mt = ( - (self.concentration + 1) * self.concentration * jnp.pow(self.rate, -2.0) + (self.base_gamma.concentration + 1) + * self.base_gamma.concentration + * jnp.power(self.base_gamma.rate, -2.0) ) rescale = self._scale_moment(2.0) sec_mt = base_sec_mt * rescale @@ -541,9 +543,7 @@ def support(self): def sample(self, key, sample_shape=()): assert is_prng_key(key) u = random.uniform(key, sample_shape + self.batch_shape) - hscale = self.base_gamma.cdf(self.high) - q = u * hscale - return self.base_gamma.icdf(q) + return self.icdf(u) @validate_sample def log_prob(self, value): @@ -552,8 +552,10 @@ def log_prob(self, value): return lprob - jnp.log(hscale) def _scale_moment(self, t): - assert t > -self.concentration - s_hscale = gammainc(self.base_gamma.concentration + t, self.high * self.rate) + assert t > -self.base_gamma.concentration + s_hscale = gammainc( + self.base_gamma.concentration + t, self.high * self.base_gamma.rate + ) hscale = self.base_gamma.cdf(self.high) return s_hscale / hscale @@ -566,11 +568,13 @@ def mean(self): @property def variance(self): # compute E[X]^2 - fst_m_sq = jnp.pow(self.mean, 2.0) + fst_m_sq = jnp.power(self.mean, 2.0) # compute E[X^2] base_sec_mt = ( - (self.concentration + 1) * self.concentration * jnp.pow(self.rate, -2.0) + (self.base_gamma.concentration + 1) + * self.base_gamma.concentration + * jnp.power(self.base_gamma.rate, -2.0) ) rescale = self._scale_moment(2.0) sec_mt = base_sec_mt * rescale @@ -639,10 +643,7 @@ def support(self): def sample(self, key, sample_shape=()): assert is_prng_key(key) u = random.uniform(key, sample_shape + self.batch_shape) - lscale = self.base_gamma.cdf(self.low) - hscale = self.base_gamma.cdf(self.high) - q = (1 - u) * lscale + u * hscale - return self.base_gamma.icdf(q) + return self.icdf(u) @validate_sample def log_prob(self, value): @@ -652,9 +653,13 @@ def log_prob(self, value): return lprob - jnp.log(hscale - lscale) def _scale_moment(self, t): - assert t > -self.concentration - s_lscale = gammainc(self.base_gamma.concentration + t, self.low * self.rate) - s_hscale = gammainc(self.base_gamma.concentration + t, self.high * self.rate) + assert t > -self.base_gamma.concentration + s_lscale = gammainc( + self.base_gamma.concentration + t, self.low * self.base_gamma.rate + ) + s_hscale = gammainc( + self.base_gamma.concentration + t, self.high * self.base_gamma.rate + ) lscale = self.base_gamma.cdf(self.low) hscale = self.base_gamma.cdf(self.high) return (s_hscale - s_lscale) / (hscale - lscale) @@ -668,11 +673,13 @@ def mean(self): @property def variance(self): # compute E[X]^2 - fst_m_sq = jnp.pow(self.mean, 2.0) + fst_m_sq = jnp.power(self.mean, 2.0) # compute E[X^2] base_sec_mt = ( - (self.concentration + 1) * self.concentration * jnp.pow(self.rate, -2.0) + (self.base_gamma.concentration + 1) + * self.base_gamma.concentration + * jnp.power(self.base_gamma.rate, -2.0) ) rescale = self._scale_moment(2.0) sec_mt = base_sec_mt * rescale diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c68907c18..f789f334a 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -564,14 +564,14 @@ def gammaincinv(a, p): # until jax/lax has direct implementation we'll need to rely on tfp # https://github.com/pyro-ppl/numpyro/issues/969 try: - import tensorflow_probability.math as tfpm + import tensorflow_probability as tfpm except ImportError as e: raise ImportError( "To use gammaincinv, please install TensorFlow Probability. It can be" " installed with `pip install tensorflow_probability`" ) from e - return tfpm.igammacinv(a, p) + return tfpm.math.igammainv(a, p) # The is sourced from: torch.distributions.util.py From 13ad41db2ff696088506f38e7886355d372cc12c Mon Sep 17 00:00:00 2001 From: Nicholas Mancuso Date: Wed, 12 May 2021 15:56:02 -0700 Subject: [PATCH 3/7] change to jax substrate on tfp for incomplete gamma calls. updated tests --- numpyro/distributions/util.py | 2 +- test/test_distributions.py | 40 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index f789f334a..d3798c192 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -571,7 +571,7 @@ def gammaincinv(a, p): " installed with `pip install tensorflow_probability`" ) from e - return tfpm.math.igammainv(a, p) + return tfpm.substrates.jax.math.igammainv(a, p) # The is sourced from: torch.distributions.util.py diff --git a/test/test_distributions.py b/test/test_distributions.py index ca6a9eca6..8ed936f77 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -733,6 +733,41 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): ) assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5) return + elif isinstance(jax_dist, + ( + dist.LeftTruncatedGamma, + dist.RightTruncatedGamma, + dist.TwoSidedTruncatedGamma, + ), + ): + # params = [base_gamma[concentration, rate], low, high] + if isinstance(jax_dist, dist.LeftTruncatedGamma): + conc, rate, low = ( + params[0].concentration, + params[0].rate, + params[1], + ) + high = np.inf + elif isinstance(jax_dist, dist.RightTruncatedGamma): + conc, rate, high = ( + params[0].concentration, + params[0].rate, + params[1], + ) + low = -np.inf + else: + conc, rate, low, high = ( + params[0].concentration, + params[0].rate, + params[1], + params[2], + ) + sp_dist = get_sp_dist(dist.Gamma)(conc, rate) + expected = sp_dist.logpdf(samples) - jnp.log( + sp_dist.cdf(high) - sp_dist.cdf(low) + ) + assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5) + return pytest.skip("no corresponding scipy distn.") if _is_batched_multivariate(jax_dist): pytest.skip("batching not allowed in multivariate distns.") @@ -1150,6 +1185,11 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): and dist_args[i] == "base_dist" ): continue + if ( + jax_dist is dist.TwoSidedTruncatedGamma + and dist_args[i] == "base_gamma" + ): + continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue if params[i] is None: From 265ca67b7cee54e83d8cf3037496a0a19fd74a54 Mon Sep 17 00:00:00 2001 From: Nicholas Mancuso Date: Tue, 12 Oct 2021 10:13:29 -0700 Subject: [PATCH 4/7] flake8 update --- numpyro/distributions/__init__.py | 4 ++++ test/test_distributions.py | 40 +++++++++++++++---------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 3f777ae48..a66468cc9 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -145,6 +145,7 @@ "MultinomialLogits", "MultinomialProbs", "MultivariateNormal", + "LeftTruncatedGamma", "LowRankMultivariateNormal", "Normal", "NegativeBinomialProbs", @@ -156,6 +157,7 @@ "ProjectedNormal", "PRNGIdentity", "RightTruncatedDistribution", + "RightTruncatedGamma", "SineBivariateVonMises", "SineSkewed", "SoftLaplace", @@ -163,9 +165,11 @@ "TransformedDistribution", "TruncatedCauchy", "TruncatedDistribution", + "TruncatedGamma", "TruncatedNormal", "TruncatedPolyaGamma", "TwoSidedTruncatedDistribution", + "TwoSidedTruncatedGamma", "Uniform", "Unit", "VonMises", diff --git a/test/test_distributions.py b/test/test_distributions.py index 51da49ee7..879aea644 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -906,34 +906,35 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): ) assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5) return - elif isinstance(jax_dist, - ( - dist.LeftTruncatedGamma, - dist.RightTruncatedGamma, - dist.TwoSidedTruncatedGamma, - ), + elif isinstance( + jax_dist, + ( + dist.LeftTruncatedGamma, + dist.RightTruncatedGamma, + dist.TwoSidedTruncatedGamma, + ), ): - # params = [base_gamma[concentration, rate], low, high] + # params = [base_gamma[concentration, rate], low, high] if isinstance(jax_dist, dist.LeftTruncatedGamma): conc, rate, low = ( - params[0].concentration, - params[0].rate, - params[1], + params[0].concentration, + params[0].rate, + params[1], ) high = np.inf elif isinstance(jax_dist, dist.RightTruncatedGamma): conc, rate, high = ( - params[0].concentration, - params[0].rate, - params[1], + params[0].concentration, + params[0].rate, + params[1], ) low = -np.inf else: conc, rate, low, high = ( - params[0].concentration, - params[0].rate, - params[1], - params[2], + params[0].concentration, + params[0].rate, + params[1], + params[2], ) sp_dist = get_sp_dist(dist.Gamma)(conc, rate) expected = sp_dist.logpdf(samples) - jnp.log( @@ -1396,10 +1397,7 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): and dist_args[i] == "base_dist" ): continue - if ( - jax_dist is dist.TwoSidedTruncatedGamma - and dist_args[i] == "base_gamma" - ): + if jax_dist is dist.TwoSidedTruncatedGamma and dist_args[i] == "base_gamma": continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue From 4778b7c0e6038606cfa75a7064a66b2f6e7bad5a Mon Sep 17 00:00:00 2001 From: Nicholas Mancuso Date: Tue, 12 Oct 2021 10:25:03 -0700 Subject: [PATCH 5/7] flake8 update --- numpyro/distributions/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index bc2381337..57ca9757d 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -567,9 +567,9 @@ def gammaincinv(a, p): import tensorflow_probability as tfpm except ImportError as e: raise ImportError( - "To use gammaincinv, please install TensorFlow Probability. It can be" - " installed with `pip install tensorflow_probability`" - ) from e + "To use gammaincinv, please install TensorFlow Probability. It can be" + " installed with `pip install tensorflow_probability`" + ) from e return tfpm.substrates.jax.math.igammainv(a, p) From 83b411f8e3ef6009146c5e41c4a69a7828f9ce02 Mon Sep 17 00:00:00 2001 From: Nicholas Mancuso Date: Tue, 12 Oct 2021 10:50:25 -0700 Subject: [PATCH 6/7] flake8 update --- numpyro/distributions/continuous.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 5c11e1485..1194c9eb9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -31,7 +31,16 @@ import jax.numpy as jnp import jax.random as random from jax.scipy.linalg import cho_solve, solve_triangular -from jax.scipy.special import betainc, expit, gammainc, gammaln, logit, multigammaln, ndtr, ndtri +from jax.scipy.special import ( + betainc, + expit, + gammainc, + gammaln, + logit, + multigammaln, + ndtr, + ndtri, +) from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution, TransformedDistribution @@ -288,6 +297,7 @@ def cdf(self, value): def icdf(self, q): # https://github.com/pyro-ppl/numpyro/issues/969 from numpyro.distributions.util import gammaincinv + return gammaincinv(self.concentration, q) / self.rate From b69bc972927e59996b98c8bd13f460f6c65d8606 Mon Sep 17 00:00:00 2001 From: Nicholas Mancuso Date: Tue, 12 Oct 2021 10:54:19 -0700 Subject: [PATCH 7/7] isort update --- numpyro/distributions/__init__.py | 6 +++--- numpyro/distributions/truncated.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index a66468cc9..07d108539 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -81,15 +81,15 @@ from numpyro.distributions.transforms import biject_to from numpyro.distributions.truncated import ( LeftTruncatedDistribution, + LeftTruncatedGamma, RightTruncatedDistribution, + RightTruncatedGamma, TruncatedCauchy, TruncatedDistribution, + TruncatedGamma, TruncatedNormal, TruncatedPolyaGamma, TwoSidedTruncatedDistribution, - TruncatedGamma, - LeftTruncatedGamma, - RightTruncatedGamma, TwoSidedTruncatedGamma, ) diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 1dbf36068..e0331f073 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -4,7 +4,7 @@ from jax import lax import jax.numpy as jnp import jax.random as random -from jax.scipy.special import logsumexp, gammainc +from jax.scipy.special import gammainc, logsumexp from jax.tree_util import tree_map from numpyro.distributions import constraints