torch_mist.distributions.caching

Module Contents

Classes

CachedTransformModule

CachedConditionalTransformModule

Functions

add_cache(transform)

class torch_mist.distributions.caching.CachedTransformModule(transform: torch.distributions.Transform)

Bases: pyro.distributions.TransformModule

_call(x)
_inverse(y)
log_abs_det_jacobian(x, y)
__repr__()
class torch_mist.distributions.caching.CachedConditionalTransformModule(conditional_transform: pyro.distributions.ConditionalTransform)

Bases: pyro.distributions.ConditionalTransformModule

condition(context)
__repr__()
torch_mist.distributions.caching.add_cache(transform: torch.distributions.Transform | pyro.distributions.ConditionalTransform)