cautrigger.causaleffect.joint_uncond_v2
- cautrigger.causaleffect.joint_uncond_v2(params, model, data, index, alpha_vi=False, beta_vi=True, eps=1e-08, device=None)
Estimate the negative unconditional mutual information -I(α; Ŷ) via Monte Carlo sampling.
This function computes a sample-based estimate of the causal effect of the causal latent factors (denoted α, dimension K) on the model’s predicted output Ŷ. It marginalizes over both the causal latents α and the non-causal (spurious) latents β (dimension L), drawing each either from the empirical statistics of its variational posterior or from a standard normal prior, as selected by alpha_vi and beta_vi.
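The marginalization over α and β can be sketched in plain Python as a nested Monte Carlo loop. This is a minimal sketch under stated assumptions, not the library's implementation: `latent_moments`, `sample_latent`, `conditional_probs`, and `toy_classify` are hypothetical names; "empirical posterior statistics" is assumed to mean the per-dimension mean and standard deviation of the inferred posterior means across the dataset; and `toy_classify` stands in for the decode-and-classify path through the model's components. For each sampled α, the classifier output is averaged over the sampled β values, which yields an estimate of p(Ŷ|α).

```python
import math
import random
from typing import Callable, List, Sequence


def latent_moments(posterior_means: Sequence[Sequence[float]], use_vi: bool):
    """Hypothetical helper: per-dimension (mean, std) used to sample a latent block.

    With use_vi=True, use the empirical mean and std of the inferred posterior
    means across the dataset (an assumption); otherwise use the N(0, 1) prior.
    """
    dim = len(posterior_means[0])
    if not use_vi:
        return [(0.0, 1.0)] * dim
    n = len(posterior_means)
    moments = []
    for d in range(dim):
        col = [row[d] for row in posterior_means]
        mu = sum(col) / n
        var = sum((x - mu) ** 2 for x in col) / n
        moments.append((mu, math.sqrt(var)))
    return moments


def sample_latent(moments, rng: random.Random) -> List[float]:
    # One draw from a diagonal Gaussian with the given (mean, std) per dimension.
    return [rng.gauss(mu, sd) for mu, sd in moments]


def conditional_probs(classify: Callable[[List[float], List[float]], List[float]],
                      alpha_moments, beta_moments,
                      n_alpha: int, n_beta: int, seed: int = 0) -> List[List[float]]:
    """Return an n_alpha x M list whose row i approximates p(Y_hat | alpha_i).

    The classifier output is averaged over n_beta draws of beta for each
    sampled alpha, which marginalizes out the non-causal latents.
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n_alpha):
        alpha = sample_latent(alpha_moments, rng)
        acc = None
        for _ in range(n_beta):
            beta = sample_latent(beta_moments, rng)
            probs = classify(alpha, beta)
            acc = probs[:] if acc is None else [a + p for a, p in zip(acc, probs)]
        rows.append([a / n_beta for a in acc])
    return rows


def toy_classify(alpha: List[float], beta: List[float]) -> List[float]:
    # Hypothetical 2-class "model": the logit depends strongly on alpha, weakly on beta.
    z = 3.0 * alpha[0] + 0.1 * beta[0]
    p1 = 1.0 / (1.0 + math.exp(-z))
    return [1.0 - p1, p1]


prior = latent_moments([[0.0]], use_vi=False)  # 1-D standard normal prior
rows = conditional_probs(toy_classify, prior, prior, n_alpha=3, n_beta=10)
```

Because each row is an average of probability vectors, every row is itself a valid distribution over the M classes.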
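- The mutual information is
I(α; Ŷ) = E_α[KL(p(Ŷ|α) || p(Ŷ))] = H(Ŷ) - H(Ŷ|α),
where the expectations over α and β are replaced by Monte Carlo averages over N_alpha samples of α and N_beta samples of β. The function returns the negated estimate -I(α; Ŷ).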
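Given per-sample conditionals p(Ŷ|α_i), the entropy form of the identity above gives a direct estimator. The sketch below is plain Python with hypothetical names (`entropy`, `neg_mutual_information`); it assumes the marginal p(Ŷ) is the average of the sampled conditionals and that `eps` acts as a lower bound on probabilities before taking logarithms, which may differ from how the library clamps.

```python
import math
from typing import Sequence


def entropy(p: Sequence[float], eps: float = 1e-8) -> float:
    # Shannon entropy in nats; probabilities are clamped below at eps before the log.
    return -sum(pi * math.log(max(pi, eps)) for pi in p)


def neg_mutual_information(cond_probs: Sequence[Sequence[float]], eps: float = 1e-8) -> float:
    """Estimate -I(alpha; Y_hat) = -(H(Y_hat) - H(Y_hat | alpha)).

    cond_probs[i] approximates p(Y_hat | alpha_i) for the i-th Monte Carlo
    sample of alpha; the marginal p(Y_hat) is taken as their average.
    """
    n = len(cond_probs)
    m = len(cond_probs[0])
    marginal = [sum(row[k] for row in cond_probs) / n for k in range(m)]
    h_marginal = entropy(marginal, eps)                                # H(Y_hat)
    h_conditional = sum(entropy(row, eps) for row in cond_probs) / n  # H(Y_hat | alpha)
    return -(h_marginal - h_conditional)
```

When no probability falls below `eps`, averaging KL(p(Ŷ|α_i) || p(Ŷ)) over the samples gives the same value, since the KL and entropy forms are algebraically identical once p(Ŷ) is the sample average. Identical conditionals give 0, and perfectly separated one-hot conditionals over two classes give -log 2.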
- Parameters:
params (dict) – Dictionary containing simulation parameters:
  - ‘N_alpha’ (int): Number of Monte Carlo samples for the causal latents (α).
  - ‘N_beta’ (int): Number of Monte Carlo samples for the non-causal latents (β).
  - ‘K’ (int): Dimensionality of the causal latent space.
  - ‘L’ (int): Dimensionality of the non-causal (spurious) latent space.
  - ‘M’ (int): Number of output classes (dimension of the classifier logits/probabilities).
model (torch.nn.Module) – Trained model with components feature_mapper_up, encoder1, encoder2, decoder_down (or decoder_down1/2), feature_mapper_down, and dpd_model.
data (torch.Tensor) – Input data tensor of shape (n_samples, n_features).
index (int) – Index of the input sample in data to evaluate.
alpha_vi (bool, optional) – If True, sample α from a Gaussian with the empirical mean and variance of its inferred posterior; otherwise, use the standard normal prior (μ=0, σ=1). Default is False.
beta_vi (bool, optional) – If True, sample β from a Gaussian with the empirical mean and variance of its inferred posterior; otherwise, use the standard normal prior. Default is True.
eps (float, optional) – Small constant for numerical stability in log-probability clamping. Default is 1e-8.
device (torch.device or str, optional) – Device to perform computation on (e.g., ‘cuda’ or ‘cpu’). If None, the model’s device is used.
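A params dictionary with the documented keys might look like the following. The values are illustrative assumptions, and `check_params` is a hypothetical helper, not part of cautrigger; K, L, and M must match the trained model's latent sizes and number of output classes.

```python
from typing import Dict

# Keys documented for the params argument.
REQUIRED_KEYS = ("N_alpha", "N_beta", "K", "L", "M")


def check_params(params: Dict[str, int]) -> Dict[str, int]:
    """Hypothetical helper: verify that every documented key is present
    and holds a positive int, then return the dictionary unchanged."""
    missing = [k for k in REQUIRED_KEYS if k not in params]
    if missing:
        raise KeyError(f"params is missing keys: {missing}")
    for k in REQUIRED_KEYS:
        v = params[k]
        # bool is a subclass of int, so reject it explicitly.
        if not isinstance(v, int) or isinstance(v, bool) or v <= 0:
            raise ValueError(f"params[{k!r}] must be a positive int, got {v!r}")
    return params


# Illustrative values only: sample counts trade accuracy for compute time.
params = check_params({"N_alpha": 100, "N_beta": 100, "K": 4, "L": 6, "M": 2})
```

Larger N_alpha and N_beta reduce the variance of the Monte Carlo estimate at the cost of more forward passes through the model.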
- Returns:
  - neg_causal_effect (Tensor) – Scalar tensor containing the estimated -I(α; Ŷ).
  - info (None) – Placeholder kept for compatibility; always None.