diff --git a/gwinferno/distributions.py b/gwinferno/distributions.py index 40db691..c449adb 100644 --- a/gwinferno/distributions.py +++ b/gwinferno/distributions.py @@ -124,19 +124,22 @@ def truncnorm_pdf(xx, mu, sig, low, high, log=False): $$ p(x) \propto \mathcal{N}(x | \mu, \sigma)\Theta(x-x_\mathrm{min})\Theta(x_\mathrm{max}-x) $$ `log=True` makes this a log-normal distribution! + + If 'low == -jnp.inf', then return a right-truncated norm + If 'high == jnp.inf', then return a left-truncated norm """ if log: prob = jnp.exp(-jnp.power(jnp.log(xx) - mu, 2) / (2 * sig**2)) continuous_norm = 1 / (xx * sig * (2 * jnp.pi) ** 0.5) - left_tail_cdf = 0.5 * (1 + erf((jnp.log(low) - mu) / (sig * (2**0.5)))) - right_tail_cdf = 0.5 * (1 + erf((jnp.log(high) - mu) / (sig * (2**0.5)))) + left_tail_cdf = 0 if low == -jnp.inf else 0.5 * (1 + erf((jnp.log(low) - mu) / (sig * (2**0.5)))) + right_tail_cdf = 1 if high == jnp.inf else 0.5 * (1 + erf((jnp.log(high) - mu) / (sig * (2**0.5)))) denom = right_tail_cdf - left_tail_cdf else: prob = jnp.exp(-jnp.power(xx - mu, 2) / (2 * sig**2)) continuous_norm = 1 / (sig * (2 * jnp.pi) ** 0.5) - left_tail_cdf = 0.5 * (1 + erf((low - mu) / (sig * (2**0.5)))) - right_tail_cdf = 0.5 * (1 + erf((high - mu) / (sig * (2**0.5)))) + left_tail_cdf = 0 if low == -jnp.inf else 0.5 * (1 + erf((low - mu) / (sig * (2**0.5)))) + right_tail_cdf = 1 if high == jnp.inf else 0.5 * (1 + erf((high - mu) / (sig * (2**0.5)))) denom = right_tail_cdf - left_tail_cdf norm = continuous_norm / denom