"""Objects representing distributions that can be sampled from."""
import collections
import functools
import itertools
import random
import math
import typing
import warnings
import numpy
import decorator
from scenic.core.lazy_eval import (LazilyEvaluable,
requiredProperties, needsLazyEvaluation, valueInContext, makeDelayedFunctionCall)
from scenic.core.utils import DefaultIdentityDict, argsToString, cached, sqrt2, get_type_hints
from scenic.core.errors import RuntimeParseError
## Misc
def dependencies(thing):
    """Dependencies which must be sampled before this value.

    Returns the ``_dependencies`` attribute of *thing*, or an empty tuple if
    *thing* has no such attribute.
    """
    return getattr(thing, '_dependencies', ())
def needsSampling(thing):
    """Whether this value requires sampling.

    True for any `Distribution`, and for any other object whose dependencies
    (see `dependencies`) are nonempty.
    """
    # Coerce to bool so callers get a real truth value, not the dependency tuple.
    return isinstance(thing, Distribution) or bool(dependencies(thing))
def supportInterval(thing):
    """Lower and upper bounds on this value, if known.

    Returns:
        A pair ``(low, high)``. Objects providing a ``supportInterval`` method are
        asked directly; plain ints and floats are their own bounds; for anything
        else both bounds are None (unknown).
    """
    if hasattr(thing, 'supportInterval'):
        return thing.supportInterval()
    if isinstance(thing, (int, float)):
        return thing, thing
    return None, None
def supmin(*vals):
    """Minimum of the given bounds, or None if any of them is None (unknown)."""
    if None in vals:
        return None
    return min(vals)
def supmax(*vals):
    """Maximum of the given bounds, or None if any of them is None (unknown)."""
    if None in vals:
        return None
    return max(vals)
def unionOfSupports(supports):
    """Combine several support intervals into one interval covering all of them.

    Args:
        supports: iterable of ``(low, high)`` pairs.

    Returns:
        The pair ``(min of lows, max of highs)``; an endpoint is None if the
        corresponding endpoint of any input interval is None.
    """
    lowers, uppers = zip(*supports)
    lower = supmin(*lowers)
    upper = supmax(*uppers)
    return lower, upper
def underlyingFunction(thing):
    """Original function underlying a distribution wrapper.

    Follows one level of ``__wrapped__`` (if present), then one level of
    ``__func__`` (if present); otherwise returns *thing* itself.
    """
    unwrapped = getattr(thing, '__wrapped__', thing)
    return getattr(unwrapped, '__func__', unwrapped)
def canUnpackDistributions(func):
    """Whether the function supports iterable unpacking of distributions.

    Reads the ``_canUnpackDistributions`` attribute of *func*, defaulting to
    False when it is absent.
    """
    return getattr(func, '_canUnpackDistributions', False)
def unpacksDistributions(func):
    """Decorator indicating the function supports iterable unpacking of distributions.

    Sets ``_canUnpackDistributions`` on *func* and returns the same object.
    """
    func._canUnpackDistributions = True
    return func
class RejectionException(Exception):
    """Exception used to signal that the sample currently being generated must be rejected."""
class RandomControlFlowError(RuntimeParseError):
    """Exception indicating illegal conditional control flow depending on a random value.

    This includes trying to iterate over a random value, take the length of a random
    sequence whose length can't be determined statically, etc.
    """
## Abstract distributions
class Samplable(LazilyEvaluable):
    """Abstract class for values which can be sampled, possibly depending on other values.

    Samplables may specify a proxy object which must have the same distribution as the
    original after conditioning on the scenario's requirements. This allows transparent
    conditioning without modifying Samplable fields of immutable objects.

    Args:
        dependencies: sequence of values that this value may depend on (formally, objects
            for which sampled values must be provided to `sampleGiven`). It is legal to
            include values which are not instances of `Samplable`, e.g. integers.

    Attributes:
        _conditioned: proxy object as described above; set using `conditionTo`.
        _dependencies: tuple of other samplables which must be sampled before this one;
            set by the initializer and subsequently immutable.
    """

    def __init__(self, dependencies):
        deps = []
        props = set()
        # Only values that are random or still lazily evaluated are kept as dependencies.
        for dep in dependencies:
            if needsSampling(dep) or needsLazyEvaluation(dep):
                deps.append(dep)
                props.update(requiredProperties(dep))
        super().__init__(props)
        self._dependencies = tuple(deps)    # fixed order for reproducibility
        self._conditioned = self    # version (partially) conditioned on requirements

    @staticmethod
    def sampleAll(quantities):
        """Sample all the given Samplables, which may have dependencies in common.

        Reproducibility note: the order in which the quantities are given can affect the
        order in which calls to random are made, affecting the final result.

        Returns:
            A `DefaultIdentityDict` mapping each quantity (and everything sampled along
            the way) to its value; non-Samplable quantities map to themselves.
        """
        subsamples = DefaultIdentityDict()
        for q in quantities:
            if q not in subsamples:
                subsamples[q] = q.sample(subsamples) if isinstance(q, Samplable) else q
        return subsamples

    def sample(self, subsamples=None):
        """Sample this value, optionally given some values already sampled.

        Args:
            subsamples: dictionary of already-sampled values, which is extended in
                place with the values sampled for this object's dependencies. A fresh
                `DefaultIdentityDict` is used if omitted.
        """
        if subsamples is None:
            subsamples = DefaultIdentityDict()
        # Sample through the conditioned proxy, so its dependencies are the ones used.
        for child in self._conditioned._dependencies:
            if child not in subsamples:
                subsamples[child] = child.sample(subsamples)
        return self._conditioned.sampleGiven(subsamples)

    def sampleGiven(self, value):
        """Sample this value, given values for all its dependencies.

        Implemented by subclasses.

        Args:
            value (DefaultIdentityDict): dictionary mapping objects to their sampled
                values. Guaranteed to provide values for all objects given in the set of
                dependencies when this `Samplable` was created.
        """
        raise NotImplementedError

    def conditionTo(self, value):
        """Condition this value to another value with the same conditional distribution."""
        assert isinstance(value, Samplable)
        self._conditioned = value

    def evaluateIn(self, context):
        """See `LazilyEvaluable.evaluateIn`."""
        value = super().evaluateIn(context)
        # Check that all dependencies have been evaluated
        assert all(not needsLazyEvaluation(dep) for dep in value._dependencies)
        return value
class ConstantSamplable(Samplable):
    """A samplable which always evaluates to a constant value.

    Only for internal use.

    Args:
        value: the constant; must need neither sampling nor lazy evaluation.
    """

    def __init__(self, value):
        assert not needsSampling(value)
        assert not needsLazyEvaluation(value)
        self.value = value
        super().__init__(())    # a constant has no dependencies

    def sampleGiven(self, value):
        """Return the stored constant, ignoring the sampled dependency values."""
        return self.value
class Distribution(Samplable):
    """Abstract class for distributions.

    .. note::

        When called during dynamic simulations (vs. scenario compilation), constructors
        for distributions return *actual sampled values*, not `Distribution` objects.

    Args:
        dependencies: values which this distribution may depend on (see `Samplable`).
        valueType: **_valueType** to use (see below), or `None` for the default.

    Attributes:
        _valueType: type of the values sampled from this distribution, or `object` if the
            type is not known.
    """

    #: Default valueType for distributions of this class, when not otherwise specified.
    _defaultValueType = object

    def __new__(cls, *args, **kwargs):
        dist = super().__new__(cls)
        # at runtime, return a sample from the distribution immediately
        import scenic.syntax.veneer as veneer
        if veneer.simulationInProgress():
            # __init__ must be called by hand: since a sample rather than an instance
            # of cls is returned, Python will not call it for us.
            dist.__init__(*args, **kwargs)
            return dist.sample()
        else:
            return dist

    def __init__(self, *dependencies, valueType=None):
        super().__init__(dependencies)
        if valueType is None:
            valueType = self._defaultValueType
        self._valueType = valueType

    def clone(self):
        """Construct an independent copy of this Distribution.

        Optionally implemented by subclasses.
        """
        raise NotImplementedError('clone() not supported by this distribution')

    @property
    @cached
    def isPrimitive(self):
        """Whether this is a primitive Distribution.

        A distribution counts as primitive exactly when its `clone` method does not
        raise `NotImplementedError`.
        """
        try:
            self.clone()
            return True
        except NotImplementedError:
            return False

    def bucket(self, buckets=None):
        """Construct a bucketed approximation of this Distribution.

        Optionally implemented by subclasses.

        This function factors a given Distribution into a discrete distribution over
        buckets together with a distribution for each bucket. The argument *buckets*
        controls how many buckets the domain of the original Distribution is split into.
        Since the result is an independent distribution, the original must support
        `clone`.
        """
        raise NotImplementedError('bucket() not supported by this distribution')

    def supportInterval(self):
        """Compute lower and upper bounds on the value of this Distribution.

        By default returns :scenic:`(None, None)` indicating that no lower or upper bounds
        are known. Subclasses may override this method to provide more accurate results.
        """
        return None, None

    def __getattr__(self, name):
        # Only reached when normal lookup fails: any unknown non-dunder attribute
        # becomes a distribution over that attribute of the sampled value.
        if name.startswith('__') and name.endswith('__'):    # ignore special attributes
            return object.__getattribute__(self, name)
        return AttributeDistribution(name, self)

    def __call__(self, *args):
        return OperatorDistribution('__call__', self, args)

    def __iter__(self):
        raise RandomControlFlowError('cannot iterate through a random value')

    def _comparisonError(self, other):
        raise RandomControlFlowError('random values cannot be compared '
                                     '(and control flow cannot depend on them)')

    __lt__ = _comparisonError
    __le__ = _comparisonError
    __gt__ = _comparisonError
    __ge__ = _comparisonError
    __eq__ = _comparisonError
    __ne__ = _comparisonError

    def __hash__(self):    # need to explicitly define since we overrode __eq__
        return id(self)

    def __len__(self):
        raise RandomControlFlowError('cannot take the len of a random value')

    def __bool__(self):
        raise RandomControlFlowError('control flow cannot depend on a random value')
## Derived distributions
class TupleDistribution(Distribution, collections.abc.Sequence):
    """Distributions over tuples (or namedtuples, or lists).

    Args:
        coordinates: the components of the sequence, any of which may be random.
        builder: callable applied to an iterable of sampled components to build the
            sampled value (default `tuple`).
    """

    def __init__(self, *coordinates, builder=tuple):
        super().__init__(*coordinates)
        self.coordinates = coordinates
        self.builder = builder

    # The length and the individual components are known statically, so the sequence
    # protocol is supported even though the components themselves may be random.
    def __len__(self):
        return len(self.coordinates)

    def __getitem__(self, index):
        return self.coordinates[index]

    def __iter__(self):
        yield from self.coordinates

    def sampleGiven(self, value):
        """Build the sampled sequence from the sampled value of each component."""
        return self.builder(value[coordinate] for coordinate in self.coordinates)

    def evaluateInner(self, context):
        coordinates = (valueInContext(coord, context) for coord in self.coordinates)
        return TupleDistribution(*coordinates, builder=self.builder)

    def __repr__(self):
        coords = ', '.join(repr(c) for c in self.coordinates)
        if self.builder is tuple:
            return f'({coords},)'
        elif self.builder is list:
            return f'[{coords}]'
        else:
            return f'TupleDistribution({coords}, builder={self.builder!r})'
def toDistribution(val):
    """Wrap Python data types with Distributions, if necessary.

    For example, tuples containing Samplables need to be converted into TupleDistributions
    in order to keep track of dependencies properly.

    Returns:
        A `TupleDistribution` if *val* is a tuple or list with (recursively) at least
        one component needing sampling or lazy evaluation; otherwise *val* unchanged.
    """
    if not isinstance(val, (tuple, list)):
        return val
    coords = [toDistribution(c) for c in val]
    if not any(needsSampling(c) or needsLazyEvaluation(c) for c in coords):
        return val
    if isinstance(val, tuple) and hasattr(val, '_fields'):    # namedtuple
        # namedtuple classes take positional fields, so build from an iterable via _make
        builder = type(val)._make
    else:
        builder = type(val)
    return TupleDistribution(*coords, builder=builder)
class FunctionDistribution(Distribution):
    """Distribution resulting from passing distributions to a function

    Args:
        func: the function to call on the sampled arguments.
        args: positional arguments (possibly random).
        kwargs: dict of keyword arguments (possibly random).
        support: optional callable mapping the support intervals of the arguments to
            a support interval for the result (see `supportInterval`).
        valueType: type of the result; if None, the return annotation of *func* is
            used when available.
    """

    def __init__(self, func, args, kwargs, support=None, valueType=None):
        args = tuple(toDistribution(arg) for arg in args)
        kwargs = {name: toDistribution(arg) for name, arg in kwargs.items()}
        if valueType is None:
            valueType = get_type_hints(func).get('return')
        super().__init__(*args, *kwargs.values(), valueType=valueType)
        self.function = func
        self.arguments = args
        self.kwargs = kwargs
        self.support = support

    def sampleGiven(self, value):
        """Call the function on the sampled arguments, expanding starred ones."""
        args = []
        for arg in self.arguments:
            if isinstance(arg, StarredDistribution):
                val = value[arg]
                try:
                    iter(val)
                except TypeError:    # TODO improve backtrace
                    raise TypeError(f"'{type(val).__name__}' object on line {arg.lineno} "
                                    "is not iterable") from None
                args.extend(val)
            else:
                args.append(value[arg])
        kwargs = {name: value[arg] for name, arg in self.kwargs.items()}
        return self.function(*args, **kwargs)

    def evaluateInner(self, context):
        function = valueInContext(self.function, context)
        arguments = tuple(valueInContext(arg, context) for arg in self.arguments)
        kwargs = {name: valueInContext(arg, context) for name, arg in self.kwargs.items()}
        # Keep the support function and value type, which describe the function and
        # would otherwise be lost on evaluation. They are only carried over when the
        # function itself is unchanged, since they may not apply to a different one.
        unchanged = function is self.function
        return FunctionDistribution(function, arguments, kwargs,
                                    support=self.support if unchanged else None,
                                    valueType=self._valueType if unchanged else None)

    def supportInterval(self):
        if self.support is None:
            return None, None
        subsupports = (supportInterval(arg) for arg in self.arguments)
        kwss = {name: supportInterval(arg) for name, arg in self.kwargs.items()}
        return self.support(*subsupports, **kwss)

    def __repr__(self):
        # Not every callable has a __name__ (e.g. functools.partial objects).
        name = getattr(self.function, '__name__', None)
        if name is None:
            name = repr(self.function)
        return f'{name}({argsToString(self.arguments, self.kwargs)})'
def distributionFunction(wrapped=None, *, support=None, valueType=None):
    """Decorator for wrapping a function so that it can take distributions as arguments.

    This decorator is mainly for internal use, and is not necessary when defining a
    function in a Scenic file. It is, however, needed when calling external functions
    which contain control flow or other operations that Scenic distribution objects
    (representing random values) do not support.

    Args:
        wrapped: the function to wrap; None when the decorator is applied with
            keyword arguments, in which case a decorator is returned.
        support: passed on to `FunctionDistribution` as its *support* argument.
        valueType: passed on to `FunctionDistribution` as its *valueType* argument.
    """
    if wrapped is None:
        # Used with arguments, as in @distributionFunction(support=...): return a
        # decorator which will be applied to the actual function.
        return lambda wrapped: distributionFunction(wrapped,
                                                    support=support, valueType=valueType)

    def helper(wrapped, *args, **kwargs):
        args = tuple(toDistribution(arg) for arg in args)
        kwargs = {name: toDistribution(arg) for name, arg in kwargs.items()}
        if any(needsSampling(arg) for arg in itertools.chain(args, kwargs.values())):
            return FunctionDistribution(wrapped, args, kwargs, support, valueType)
        elif any(needsLazyEvaluation(arg)
                 for arg in itertools.chain(args, kwargs.values())):
            # recursively call this helper (not the original function), since the
            # delayed arguments may evaluate to distributions, in which case we'll
            # have to make a FunctionDistribution
            return makeDelayedFunctionCall(helper, (wrapped,) + args, kwargs)
        else:
            # all arguments are ordinary values: just call the function
            return wrapped(*args, **kwargs)

    try:
        newFunc = decorator.decorate(wrapped, helper, kwsyntax=True)
    except ValueError:
        # We couldn't preserve the wrapped function's metadata using decorator.decorate
        # (e.g. it's a built-in function like print on which inspect.signature fails),
        # so fall back on functools.wraps.
        @functools.wraps(wrapped)
        def newFunc(*args, **kwargs):
            return helper(wrapped, *args, **kwargs)
    return unpacksDistributions(newFunc)
def monotonicDistributionFunction(method, valueType=None):
    """Like distributionFunction, but additionally specifies that the function is monotonic.

    The support of the result is obtained by applying *method* to the lower bounds of
    all arguments and to the upper bounds of all arguments; a bound of the result is
    None (unknown) whenever the corresponding bound of any argument is.
    """
    def bound(args, kwargs):
        # An unknown bound on any argument (positional or keyword) makes the
        # corresponding bound of the result unknown.
        if any(val is None for val in itertools.chain(args, kwargs.values())):
            return None
        return method(*args, **kwargs)

    def support(*subsupports, **kwss):
        # Built with comprehensions rather than zip(*...) so that a call with only
        # keyword arguments (no positional intervals) is handled.
        mins = [interval[0] for interval in subsupports]
        maxes = [interval[1] for interval in subsupports]
        kwmins = {name: interval[0] for name, interval in kwss.items()}
        kwmaxes = {name: interval[1] for name, interval in kwss.items()}
        return bound(mins, kwmins), bound(maxes, kwmaxes)

    return distributionFunction(method, support=support, valueType=valueType)
class StarredDistribution(Distribution):
    """A placeholder for the iterable unpacking operator * applied to a distribution.

    Args:
        value: the `Distribution` being unpacked.
        lineno: line number, stored for error reporting when unpacking fails.
    """

    def __init__(self, value, lineno):
        assert isinstance(value, Distribution)
        self.value = value
        self.lineno = lineno    # for error handling when unpacking fails
        super().__init__(value, valueType=value._valueType)

    def sampleGiven(self, value):
        """Return the sampled value of the wrapped distribution unchanged."""
        return value[self.value]

    def evaluateInner(self, context):
        return StarredDistribution(valueInContext(self.value, context), self.lineno)

    def __repr__(self):
        return f'*{self.value!r}'
class MethodDistribution(Distribution):
    """Distribution resulting from passing distributions to a method of a fixed object

    Args:
        method: the (unbound) method to call.
        obj: the object passed as the first argument to *method*.
        args: positional arguments (possibly random).
        kwargs: dict of keyword arguments (possibly random).
        valueType: type of the result; if None, the return annotation of *method* is
            used when available.
    """

    def __init__(self, method, obj, args, kwargs, valueType=None):
        args = tuple(toDistribution(arg) for arg in args)
        kwargs = {name: toDistribution(arg) for name, arg in kwargs.items()}
        if valueType is None:
            valueType = get_type_hints(method).get('return')
        super().__init__(*args, *kwargs.values(), valueType=valueType)
        self.method = method
        self.object = obj
        self.arguments = args
        self.kwargs = kwargs

    def sampleGiven(self, value):
        """Call the method on the sampled arguments, expanding starred ones."""
        args = []
        for arg in self.arguments:
            if isinstance(arg, StarredDistribution):
                # Read the starred argument's own sampled value (as
                # FunctionDistribution does) rather than reaching through to the
                # distribution it wraps, and report a non-iterable value clearly.
                val = value[arg]
                try:
                    iter(val)
                except TypeError:
                    raise TypeError(f"'{type(val).__name__}' object on line {arg.lineno} "
                                    "is not iterable") from None
                args.extend(val)
            else:
                args.append(value[arg])
        kwargs = {name: value[arg] for name, arg in self.kwargs.items()}
        return self.method(self.object, *args, **kwargs)

    def evaluateInner(self, context):
        obj = valueInContext(self.object, context)
        arguments = tuple(valueInContext(arg, context) for arg in self.arguments)
        kwargs = {name: valueInContext(arg, context) for name, arg in self.kwargs.items()}
        # The method is unchanged, so carry over the value type (which may have been
        # given explicitly rather than inferred from annotations).
        return MethodDistribution(self.method, obj, arguments, kwargs,
                                  valueType=self._valueType)

    def __repr__(self):
        args = argsToString(self.arguments, self.kwargs)
        return f'{self.object!r}.{self.method.__name__}({args})'
def distributionMethod(method):
    """Decorator for wrapping a method so that it can take distributions as arguments.

    The wrapped method returns a `MethodDistribution` if any argument needs sampling,
    a delayed call if any argument needs lazy evaluation, and otherwise simply calls
    the original method.
    """
    def helper(wrapped, self, *args, **kwargs):
        args = tuple(toDistribution(arg) for arg in args)
        kwargs = {name: toDistribution(arg) for name, arg in kwargs.items()}
        if any(needsSampling(arg) for arg in itertools.chain(args, kwargs.values())):
            return MethodDistribution(method, self, args, kwargs)
        elif any(needsLazyEvaluation(arg)
                 for arg in itertools.chain(args, kwargs.values())):
            # see analogous comment in distributionFunction
            return makeDelayedFunctionCall(helper, (method, self) + args, kwargs)
        else:
            return method(self, *args, **kwargs)

    try:
        newMethod = decorator.decorate(method, helper, kwsyntax=True)
    except ValueError:
        # See analogous comment in distributionFunction
        @functools.wraps(method)
        def newMethod(*args, **kwargs):
            return helper(method, *args, **kwargs)
    return unpacksDistributions(newMethod)
class AttributeDistribution(Distribution):
    """Distribution resulting from accessing an attribute of a distribution

    Args:
        attribute: name of the attribute.
        obj: the (random) object whose attribute is accessed.
        valueType: type of the attribute; inferred with `inferType` if None.
    """

    def __init__(self, attribute, obj, valueType=None):
        if valueType is None:
            valueType = self.inferType(obj, attribute)
        super().__init__(obj, valueType=valueType)
        self.attribute = attribute
        self.object = obj

    @staticmethod
    def inferType(obj, attribute):
        """Attempt to infer the type of the given attribute.

        Returns the annotated type of the attribute, or None if it cannot be found.
        """
        # If the object's type is known, see if we have an attribute type annotation.
        ty = type_support.underlyingType(obj)
        try:
            hints = get_type_hints(ty)
            attrTy = hints.get(attribute)
            if attrTy:
                return attrTy
        except Exception:
            pass    # couldn't get type annotations

        # We can't tell what the attribute type is.
        return None

    def sampleGiven(self, value):
        obj = value[self.object]
        return getattr(obj, self.attribute)

    def evaluateInner(self, context):
        obj = valueInContext(self.object, context)
        return AttributeDistribution(self.attribute, obj)

    def supportInterval(self):
        obj = self.object
        if isinstance(obj, Options):
            # The attribute's support is the union of its supports over all options.
            attrs = (getattr(opt, self.attribute) for opt in obj.options)
            return unionOfSupports(supportInterval(attr) for attr in attrs)
        return None, None

    def __call__(self, *args):
        # Calling an attribute is a method call on the underlying object: try to get
        # the return type from the annotations of the method on the object's type.
        vty = self.object._valueType
        retTy = None
        if vty is not object:
            func = getattr(vty, self.attribute, None)
            if func:
                if isinstance(func, property):
                    func = func.fget
                retTy = get_type_hints(func).get('return')
        return OperatorDistribution('__call__', self, args, valueType=retTy)

    def __repr__(self):
        return f'{self.object!r}.{self.attribute}'
class OperatorDistribution(Distribution):
    """Distribution resulting from applying an operator to one or more distributions

    Args:
        operator: name of the special method implementing the operator
            (e.g. ``'__add__'``).
        obj: the object on which the special method is looked up.
        operands: the remaining operands.
        valueType: type of the result; inferred with `inferType` if None.
    """

    def __init__(self, operator, obj, operands, valueType=None):
        operands = tuple(toDistribution(arg) for arg in operands)
        if valueType is None:
            valueType = self.inferType(obj, operator, operands)
        super().__init__(obj, *operands, valueType=valueType)
        self.operator = operator
        self.object = obj
        self.operands = operands

    @staticmethod
    def inferType(obj, operator, operands):
        """Attempt to infer the result type of the given operator application."""
        # If the object's type is known, see if we have a return type annotation.
        ty = type_support.underlyingType(obj)
        op = getattr(ty, operator, None)
        if op:
            retTy = get_type_hints(op).get('return')
            if retTy:
                return retTy

        # The supported arithmetic operations on scalars all return scalars.
        def scalar(thing):
            ty = type_support.underlyingType(thing)
            return type_support.canCoerceType(ty, float)
        if scalar(obj) and all(scalar(operand) for operand in operands):
            return float

        # We can't tell what the result type is.
        return None

    def sampleGiven(self, value):
        first = value[self.object]
        rest = [value[child] for child in self.operands]
        op = getattr(first, self.operator)
        result = op(*rest)
        # handle horrible int/float mismatch
        # TODO what is the right way to fix this???
        if result is NotImplemented and isinstance(first, int):
            first = float(first)
            op = getattr(first, self.operator)
            result = op(*rest)
        return result

    def evaluateInner(self, context):
        obj = valueInContext(self.object, context)
        operands = tuple(valueInContext(arg, context) for arg in self.operands)
        return OperatorDistribution(self.operator, obj, operands)

    def supportInterval(self):
        """Compute bounds on the result for the arithmetic operators handled here.

        Returns :scenic:`(None, None)` for any other operator, and None for any bound
        which cannot be determined.
        """
        if self.operator in ('__add__', '__radd__', '__sub__', '__rsub__',
                             '__mul__', '__rmul__', '__truediv__', '__rtruediv__'):
            assert len(self.operands) == 1
            l1, r1 = supportInterval(self.object)
            l2, r2 = supportInterval(self.operands[0])
            if l1 is None or l2 is None or r1 is None or r2 is None:
                return None, None
            if self.operator == '__add__' or self.operator == '__radd__':
                l = l1 + l2
                r = r1 + r2
            elif self.operator == '__sub__':
                l = l1 - r2
                r = r1 - l2
            elif self.operator == '__rsub__':
                l = l2 - r1
                r = r2 - l1
            elif self.operator in ('__mul__', '__rmul__'):
                prods = (l1*l2, l1*r2, r1*l2, r1*r2)
                l = min(*prods)
                r = max(*prods)
            elif self.operator == '__truediv__':
                # only handled when the divisor is strictly positive
                if l2 > 0:
                    l = l1 / r2 if l1 >= 0 else l1 / l2
                    r = r1 / l2 if r1 >= 0 else r1 / r2
                else:
                    l, r = None, None    # TODO improve
            elif self.operator == '__rtruediv__':
                # reflected division: the operand is divided by the object
                if l1 > 0:
                    l = l2 / r1 if l2 >= 0 else l2 / l1
                    r = r2 / l1 if r2 >= 0 else r2 / r1
                else:
                    l, r = None, None
            else:
                raise AssertionError(f'unexpected operator {self.operator}')
            return l, r
        elif self.operator in ('__neg__', '__abs__'):
            assert len(self.operands) == 0
            l, r = supportInterval(self.object)
            # Either bound may be None (unknown), so it cannot be negated blindly.
            negL = None if l is None else -l
            negR = None if r is None else -r
            if self.operator == '__neg__':
                return negR, negL
            elif self.operator == '__abs__':
                if r is not None and r < 0:
                    # entirely negative: same as negation
                    return negR, negL
                elif l is None:
                    # no lower bound on the value, so no upper bound on its magnitude
                    return 0, None
                elif l < 0:
                    return 0, (None if r is None else max(-l, r))
                else:
                    return l, r
            else:
                raise AssertionError(f'unexpected operator {self.operator}')
        return None, None

    def __repr__(self):
        return f'{self.object!r}.{self.operator}({argsToString(self.operands)})'
# Operators which can be applied to distributions.
# Note that we deliberately do not include comparisons and __bool__,
# since Scenic does not allow control flow to depend on random variables.
allowedOperators = (
    '__neg__',
    '__pos__',
    '__abs__',
    '__add__',
    '__radd__',
    '__sub__',
    '__rsub__',
    '__mul__',
    '__rmul__',
    '__truediv__',
    '__rtruediv__',
    '__floordiv__',
    '__rfloordiv__',
    '__mod__',
    '__rmod__',
    '__divmod__',
    '__rdivmod__',
    '__pow__',
    '__rpow__',
    '__round__',
    '__getitem__',
)

def makeOperatorHandler(op):
    # A factory is used so that each handler captures its own operator name.
    def handler(self, *args):
        return OperatorDistribution(op, self, args)
    return handler

# Install a handler on Distribution for every allowed operator.
for op in allowedOperators:
    setattr(Distribution, op, makeOperatorHandler(op))
import scenic.core.type_support as type_support
class MultiplexerDistribution(Distribution):
    """Distribution selecting among values based on another distribution.

    Args:
        index: distribution over indices into *options*.
        options: nonempty iterable of values (possibly random) to select among.
    """

    def __init__(self, index, options):
        self.index = index
        self.options = tuple(toDistribution(opt) for opt in options)
        assert len(self.options) > 0
        valueType = type_support.unifyingType(self.options)
        super().__init__(index, *self.options, valueType=valueType)

    def sampleGiven(self, value):
        """Return the sampled value of the option picked by the sampled index."""
        idx = value[self.index]
        assert 0 <= idx < len(self.options), (idx, len(self.options))
        return value[self.options[idx]]

    def supportInterval(self):
        # Any of the options may be selected, so take the union of their supports.
        return unionOfSupports(supportInterval(opt) for opt in self.options)
## Simple distributions
class Range(Distribution):
    """Uniform distribution over a range

    Args:
        low: lower endpoint (a scalar, possibly random).
        high: upper endpoint (a scalar, possibly random).
    """

    def __init__(self, low, high):
        low = type_support.toScalar(low, f'Range endpoint {low} is not a scalar')
        high = type_support.toScalar(high, f'Range endpoint {high} is not a scalar')
        super().__init__(low, high, valueType=float)
        self.low = low
        self.high = high

    def clone(self):
        return type(self)(self.low, self.high)

    def bucket(self, buckets=None):
        """Split this range into equal-width subranges, chosen between uniformly.

        Args:
            buckets: number of subranges (a positive int); defaults to 5.

        Raises:
            ValueError: if *buckets* is not a positive int.
            RuntimeError: if either endpoint of the range is not a float.
        """
        if buckets is None:
            buckets = 5
        if not isinstance(buckets, int) or buckets < 1:
            raise ValueError(f'Invalid buckets for Range.bucket: {buckets}')
        if not isinstance(self.low, float) or not isinstance(self.high, float):
            raise RuntimeError('Cannot bucket Range with non-constant endpoints')
        endpoints = numpy.linspace(self.low, self.high, buckets+1)
        ranges = [Range(left, right) for left, right in zip(endpoints[:-1], endpoints[1:])]
        return Options(ranges)

    def sampleGiven(self, value):
        return random.uniform(value[self.low], value[self.high])

    def evaluateInner(self, context):
        low = valueInContext(self.low, context)
        high = valueInContext(self.high, context)
        return Range(low, high)

    def supportInterval(self):
        # The endpoints may themselves be random, so combine their supports.
        return unionOfSupports((supportInterval(self.low), supportInterval(self.high)))

    def __repr__(self):
        return f'Range({self.low!r}, {self.high!r})'
class Normal(Distribution):
    """Normal distribution

    Args:
        mean: mean of the distribution (a scalar, possibly random).
        stddev: standard deviation (a scalar, possibly random).
    """

    def __init__(self, mean, stddev):
        mean = type_support.toScalar(mean, f'Normal mean {mean} is not a scalar')
        stddev = type_support.toScalar(stddev, f'Normal stddev {stddev} is not a scalar')
        super().__init__(mean, stddev, valueType=float)
        self.mean = mean
        self.stddev = stddev

    @staticmethod
    def cdf(mean, stddev, x):
        """Cumulative distribution function of the given normal distribution at *x*."""
        return (1 + math.erf((x - mean) / (sqrt2 * stddev))) / 2

    @staticmethod
    def cdfinv(mean, stddev, x):
        """Inverse of `cdf`: the point at which the CDF takes the value *x*."""
        # Import the submodule explicitly: a bare "import scipy" is not guaranteed to
        # make scipy.special available. (Slow import, not often needed.)
        import scipy.special
        return mean + (sqrt2 * stddev * scipy.special.erfinv(2*x - 1))

    def clone(self):
        return type(self)(self.mean, self.stddev)

    def bucket(self, buckets=None):
        """Approximate this distribution by a weighted choice among truncated normals.

        Args:
            buckets: either the number of buckets (default 5), or an increasing
                sequence of bucket endpoints given as offsets from the mean.

        Raises:
            RuntimeError: if the standard deviation is not a float.
            ValueError: if *buckets* is an int below 1 or a non-increasing sequence.
        """
        if not isinstance(self.stddev, float):    # TODO relax restriction?
            raise RuntimeError('Cannot bucket Normal with non-constant standard deviation')
        if buckets is None:
            buckets = 5
        if isinstance(buckets, int):
            if buckets < 1:
                raise ValueError(f'Invalid buckets for Normal.bucket: {buckets}')
            elif buckets == 1:
                endpoints = []
            elif buckets == 2:
                endpoints = [0]
            else:
                # interior endpoints spaced one standard deviation apart, symmetric
                # about zero; the two outermost buckets extend to infinity
                left = self.stddev * (-(buckets-3)/2 - 0.5)
                right = self.stddev * ((buckets-3)/2 + 0.5)
                endpoints = numpy.linspace(left, right, buckets-1)
        else:
            endpoints = tuple(buckets)
            for i, v in enumerate(endpoints[:-1]):
                if v >= endpoints[i+1]:
                    raise ValueError('Non-increasing bucket endpoints for '
                                     f'Normal.bucket: {endpoints}')
        if len(endpoints) == 0:
            return Options([self.clone()])
        buckets = [(-math.inf, endpoints[0])]
        buckets.extend((v, endpoints[i+1]) for i, v in enumerate(endpoints[:-1]))
        buckets.append((endpoints[-1], math.inf))
        pieces = []
        probs = []
        for left, right in buckets:
            # Each piece is a zero-mean truncated normal shifted by the mean.
            pieces.append(self.mean + TruncatedNormal(0, self.stddev, left, right))
            prob = (Normal.cdf(0, self.stddev, right)
                    - Normal.cdf(0, self.stddev, left))
            probs.append(prob)
        assert math.isclose(math.fsum(probs), 1), probs
        return Options(dict(zip(pieces, probs)))

    def sampleGiven(self, value):
        return random.gauss(value[self.mean], value[self.stddev])

    def evaluateInner(self, context):
        mean = valueInContext(self.mean, context)
        stddev = valueInContext(self.stddev, context)
        return Normal(mean, stddev)

    def __repr__(self):
        return f'Normal({self.mean!r}, {self.stddev!r})'
class TruncatedNormal(Normal):
    """Truncated normal distribution.

    Args:
        mean: mean of the underlying normal distribution.
        stddev: standard deviation of the underlying normal distribution.
        low: lower endpoint of the truncation interval (a constant number).
        high: upper endpoint of the truncation interval (a constant number).

    Raises:
        ValueError: if an endpoint is not an int or float, or if *low* is not
            below *high*.
    """

    def __init__(self, mean, stddev, low, high):
        if (not isinstance(low, (float, int))
            or not isinstance(high, (float, int))):    # TODO relax restriction?
            raise ValueError('Endpoints of TruncatedNormal must be constant')
        if low >= high:
            raise ValueError('low endpoint of TruncatedNormal must be below high endpoint')
        super().__init__(mean, stddev)
        self.low = low
        self.high = high

    def clone(self):
        return type(self)(self.mean, self.stddev, self.low, self.high)

    def bucket(self, buckets=None):
        """Split this distribution into truncated normals over subintervals.

        Args:
            buckets: either the number of equal-width buckets (default 5), or an
                increasing sequence of endpoints starting at *low* and ending at *high*.

        Raises:
            RuntimeError: if the standard deviation is not a float.
            ValueError: if *buckets* is an int below 1, or a sequence that is too
                short, does not match the domain, or is not increasing.
        """
        if not isinstance(self.stddev, float):    # TODO relax restriction?
            raise RuntimeError('Cannot bucket TruncatedNormal with '
                               'non-constant standard deviation')
        if buckets is None:
            buckets = 5
        if isinstance(buckets, int):
            if buckets < 1:
                raise ValueError(f'Invalid buckets for TruncatedNormal.bucket: {buckets}')
            endpoints = numpy.linspace(self.low, self.high, buckets+1)
        else:
            endpoints = tuple(buckets)
            if len(endpoints) < 2:
                raise ValueError('Too few bucket endpoints for '
                                 f'TruncatedNormal.bucket: {endpoints}')
            if endpoints[0] != self.low or endpoints[-1] != self.high:
                raise ValueError(f'TruncatedNormal.bucket endpoints {endpoints} '
                                 'do not match domain')
            for i, v in enumerate(endpoints[:-1]):
                if v >= endpoints[i+1]:
                    raise ValueError('Non-increasing bucket endpoints for '
                                     f'TruncatedNormal.bucket: {endpoints}')
        pieces, probs = [], []
        for i, left in enumerate(endpoints[:-1]):
            right = endpoints[i+1]
            pieces.append(TruncatedNormal(self.mean, self.stddev, left, right))
            # Weights are the (unnormalized) mass of the underlying normal in each bucket.
            prob = (Normal.cdf(self.mean, self.stddev, right)
                    - Normal.cdf(self.mean, self.stddev, left))
            probs.append(prob)
        return Options(dict(zip(pieces, probs)))

    def sampleGiven(self, value):
        # Inverse-CDF sampling on the standardized truncation interval.
        # TODO switch to method less prone to underflow?
        mean, stddev = value[self.mean], value[self.stddev]
        alpha = (self.low - mean) / stddev
        beta = (self.high - mean) / stddev
        alpha_cdf = Normal.cdf(0, 1, alpha)
        beta_cdf = Normal.cdf(0, 1, beta)
        if beta_cdf - alpha_cdf < 1e-15:
            warnings.warn('low precision when sampling TruncatedNormal')
        unif = random.random()
        p = alpha_cdf + unif * (beta_cdf - alpha_cdf)
        sample = mean + (stddev * Normal.cdfinv(0, 1, p))
        # Rounding (or p reaching exactly 0 or 1) can push the result slightly
        # outside the truncation interval, or even to infinity; clamp it so that the
        # sample always respects supportInterval.
        return min(max(sample, self.low), self.high)

    def evaluateInner(self, context):
        mean = valueInContext(self.mean, context)
        stddev = valueInContext(self.stddev, context)
        return TruncatedNormal(mean, stddev, self.low, self.high)

    def supportInterval(self):
        return self.low, self.high

    def __repr__(self):
        return f'TruncatedNormal({self.mean!r}, {self.stddev!r}, {self.low!r}, {self.high!r})'
class DiscreteRange(Distribution):
    """Distribution over a range of integers.

    Args:
        low: smallest value (a constant int).
        high: largest value, inclusive (a constant int).
        weights: optional iterable of relative weights, one for each integer from
            *low* to *high*; uniform if omitted.

    Raises:
        ValueError: if an endpoint is not an int, if *low* exceeds *high*, or if the
            number of weights does not match the number of values.
    """

    def __init__(self, low, high, weights=None):
        if not isinstance(low, int):
            raise ValueError(f'DiscreteRange endpoint {low} is not a constant integer')
        if not isinstance(high, int):
            raise ValueError(f'DiscreteRange endpoint {high} is not a constant integer')
        if not low <= high:
            raise ValueError(f'DiscreteRange lower bound {low} is above upper bound {high}')
        count = high - low + 1
        if weights is None:
            weights = (1,) * count
        else:
            weights = tuple(weights)
            # Raise rather than assert, so the check survives running under -O.
            if len(weights) != count:
                raise ValueError(f'DiscreteRange given {len(weights)} weights '
                                 f'for {count} values')
        super().__init__(valueType=int)
        self.low = low
        self.high = high
        self.weights = weights
        # precomputed for random.choices in sampleGiven
        self.cumulativeWeights = tuple(itertools.accumulate(weights))
        self.options = tuple(range(low, high+1))

    def clone(self):
        return type(self)(self.low, self.high, self.weights)

    def bucket(self, buckets=None):
        return self.clone()    # already bucketed

    def sampleGiven(self, value):
        return random.choices(self.options, cum_weights=self.cumulativeWeights)[0]

    def supportInterval(self):
        return self.low, self.high

    def __repr__(self):
        weights = self.weights
        # Omit the weights when they are all equal, i.e. the distribution is uniform.
        if all(weight == weights[0] for weight in weights):
            return f'DiscreteRange({self.low!r}, {self.high!r})'
        else:
            return f'DiscreteRange({self.low!r}, {self.high!r}, {self.weights})'
class Options(MultiplexerDistribution):
    """Distribution over a finite list of options.

    Specified by a dict giving probabilities; otherwise uniform over a given iterable.

    Args:
        opts: either a dict mapping options to constant nonnegative weights (options
            with weight zero are dropped), or an iterable of options.

    Raises:
        RuntimeParseError: if a weight is not a number or is negative.
        RejectionException: if no options remain.
    """

    def __init__(self, opts):
        if isinstance(opts, dict):
            options, weights = [], []
            for opt, prob in opts.items():
                if not isinstance(prob, (float, int)):
                    raise RuntimeParseError(f'discrete distribution weight {prob}'
                                            ' is not a constant number')
                if prob < 0:
                    raise RuntimeParseError(f'discrete distribution weight {prob} is negative')
                if prob == 0:
                    continue
                options.append(opt)
                weights.append(prob)
            self.optWeights = dict(zip(options, weights))
        else:
            weights = None
            options = tuple(opts)
            self.optWeights = None
        if len(options) == 0:
            raise RejectionException('tried to make discrete distribution over empty domain!')

        index = self.makeSelector(len(options)-1, weights)
        super().__init__(index, options)

    @staticmethod
    def makeSelector(n, weights):
        """Make the distribution over indices 0 through *n* used to pick an option."""
        return DiscreteRange(0, n, weights)

    def clone(self):
        return type(self)(self.optWeights if self.optWeights else self.options)

    def bucket(self, buckets=None):
        return self.clone()    # already bucketed

    def evaluateInner(self, context):
        if self.optWeights is None:
            return type(self)(valueInContext(opt, context) for opt in self.options)
        # Distinct options may evaluate to equal values; add up their weights instead
        # of letting a later one silently replace an earlier one in the dict.
        evaluated = {}
        for opt, wt in self.optWeights.items():
            val = valueInContext(opt, context)
            evaluated[val] = evaluated.get(val, 0) + wt
        return type(self)(evaluated)

    def __repr__(self):
        if self.optWeights is not None:
            return f'{type(self).__name__}({self.optWeights!r})'
        else:
            args = ', '.join(repr(opt) for opt in self.options)
            return f'{type(self).__name__}({args})'