torch_mist.estimators.hybrid.factories
Module Contents
Functions
- torch_mist.estimators.hybrid.factories.hybrid_pq(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, quantize_y: torch_mist.quantization.QuantizationFunction | str | None = None, hidden_dims: List[int] | None = None, q_QY_given_X: pyro.distributions.ConditionalDistribution | None = None, temperature: float = 0.1, n_bins: int | None = 32, quantization_params: Dict[str, Any] | None = None, **kwargs) -> torch_mist.estimators.hybrid.PQHybridMIEstimator
- torch_mist.estimators.hybrid.factories.resampled_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) -> torch_mist.estimators.hybrid.ResampledHybridMIEstimator
- torch_mist.estimators.hybrid.factories.reweighed_hybrid(discriminative_estimator: torch_mist.estimators.discriminative.DiscriminativeMIEstimator | str, generative_estimator: torch_mist.estimators.generative.GenerativeMIEstimator | str, x_dim: int | None = None, y_dim: int | None = None, hidden_dims: List[int] | None = None, discriminative_params: Dict[str, Any] | None = None, generative_params: Dict[str, Any] | None = None, **kwargs) -> torch_mist.estimators.hybrid.ReweighedHybridMIEstimator
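The signatures above show that each factory accepts either an estimator instance or a string name. The following is a minimal sketch of how resampled_hybrid might be called with string names. The estimator names "smile" and "doe", the dimensions, and the hidden layer sizes are illustrative assumptions that this page does not confirm, and make_hybrid is a hypothetical helper that is not part of torch_mist.

```python
from typing import Any, Dict, Optional

# Keyword arguments that mirror the documented resampled_hybrid signature.
# The estimator names and all numeric values below are assumptions chosen
# for illustration only.
hybrid_kwargs: Dict[str, Any] = {
    "discriminative_estimator": "smile",  # assumed string name
    "generative_estimator": "doe",        # assumed string name
    "x_dim": 5,
    "y_dim": 5,
    "hidden_dims": [64, 64],
}


def make_hybrid(kwargs: Dict[str, Any]) -> Optional[Any]:
    """Hypothetical helper: build the estimator, or return None
    when torch_mist cannot be imported."""
    try:
        from torch_mist.estimators.hybrid.factories import resampled_hybrid
    except ImportError:
        return None
    return resampled_hybrid(**kwargs)
```

Calling make_hybrid(hybrid_kwargs) would return a ResampledHybridMIEstimator if torch_mist is installed and the assumed names are valid. Since reweighed_hybrid is documented with the same parameters, the same dictionary could presumably be passed to it as well.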