cautrigger.causaleffect.joint_uncond_v2

cautrigger.causaleffect.joint_uncond_v2(params, model, data, index, alpha_vi=False, beta_vi=True, eps=1e-08, device=None)

Estimate the negative unconditional mutual information -I(α; Ŷ) via Monte Carlo sampling.

This function computes a sample-based estimate of the causal effect of the causal latent factors (denoted α, dimension K) on the model’s predicted output Ŷ. It marginalizes over both causal and non-causal (spurious) latent variables using variational posterior statistics or standard normal priors.

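The mutual information can be written as:

I(α; Ŷ) = E_{α}[KL(p(Ŷ|α) || p(Ŷ))] = H(Ŷ) - H(Ŷ|α),

where p(Ŷ|α) = E_{β}[p(Ŷ|α, β)] and p(Ŷ) = E_{α}[p(Ŷ|α)]. These expectations are approximated by Monte Carlo averages over N_alpha samples of α and N_beta samples of β, and this function returns -I(α; Ŷ).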
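The Monte Carlo estimate above can be sketched as follows. This is a minimal illustration, not the package's implementation: it assumes the classifier's output distributions p(Ŷ|α, β) for every sampled (α, β) pair have already been collected into an array of shape (N_alpha, N_beta, M), uses NumPy in place of torch tensors, and the helper name mutual_info_from_probs is hypothetical. The eps clamp only mirrors the purpose of the eps argument; how the original clamps is not documented here.

```python
import numpy as np


def mutual_info_from_probs(probs: np.ndarray, eps: float = 1e-8) -> float:
    """Hypothetical helper: Monte Carlo estimate of I(alpha; Y_hat).

    probs has shape (N_alpha, N_beta, M); probs[i, j] is the predicted class
    distribution p(Y_hat | alpha_i, beta_j).
    """
    # Marginalize the non-causal latents: p(Y_hat | alpha_i) ~ mean over beta.
    p_y_given_alpha = probs.mean(axis=1)  # (N_alpha, M)
    # Marginalize the causal latents: p(Y_hat) ~ mean over alpha.
    p_y = p_y_given_alpha.mean(axis=0)  # (M,)
    # Clamp before taking logs to avoid log(0).
    log_ratio = np.log(np.clip(p_y_given_alpha, eps, None)) - np.log(np.clip(p_y, eps, None))
    # KL(p(Y_hat | alpha_i) || p(Y_hat)) for each alpha sample, then average over alpha.
    kl_per_alpha = (p_y_given_alpha * log_ratio).sum(axis=1)
    return float(kl_per_alpha.mean())


# Usage with random stand-in probabilities (not model output):
rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(3), size=(100, 25))  # (N_alpha, N_beta, M)
neg_causal_effect = -mutual_info_from_probs(probs)  # negated, matching this function's return
```

Returning the negative value is convenient when the quantity is used as a loss term: minimizing -I(α; Ŷ) maximizes the information the causal latents carry about the prediction.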

Parameters:
  • params (dict) – Dictionary containing simulation parameters:
      - ‘N_alpha’ (int): Number of Monte Carlo samples for the causal latents (α).
      - ‘N_beta’ (int): Number of Monte Carlo samples for the non-causal latents (β).
      - ‘K’ (int): Dimensionality of the causal latent space.
      - ‘L’ (int): Dimensionality of the non-causal (spurious) latent space.
      - ‘M’ (int): Number of output classes (dimension of the classifier logits/probabilities).

  • model (torch.nn.Module) – Trained model with components: feature_mapper_up, encoder1, encoder2, decoder_down (or decoder_down1/2), feature_mapper_down, and dpd_model.

  • data (torch.Tensor) – Input data tensor of shape (n_samples, n_features).

  • index (int) – Index of the input sample in data to evaluate.

  • alpha_vi (bool, optional) – If True, sample α using the empirical mean and variance of its inferred posterior; otherwise, sample from the standard normal prior (μ=0, σ=1). Default is False.

  • beta_vi (bool, optional) – If True, sample β using the empirical mean and variance of its inferred posterior; otherwise, sample from the standard normal prior. Default is True.

  • eps (float, optional) – Small constant for numerical stability in log-probability clamping. Default is 1e-8.

  • device (torch.device or str, optional) – Device on which to run the computation (e.g., ‘cuda’ or ‘cpu’). If None, the model’s device is used.
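How alpha_vi and beta_vi change the way latents are drawn can be sketched as below. This is a hedged illustration under stated assumptions: it interprets "empirical mean/variance of its inferred posterior" as a diagonal Gaussian fitted to a batch of inferred latent means, uses NumPy instead of torch, and the helper sample_latents, the parameter values, the stand-in posterior array, and the [α, β] concatenation order are all hypothetical.

```python
import numpy as np


def sample_latents(n, dim, posterior_means=None, use_vi=False, rng=None):
    """Hypothetical helper: draw n latent vectors of size dim.

    use_vi=True  -> diagonal Gaussian with the empirical mean/std of
                    posterior_means, an (n_samples, dim) array of inferred
                    latent means (an assumed interpretation).
    use_vi=False -> standard normal prior N(0, I).
    """
    rng = np.random.default_rng() if rng is None else rng
    if use_vi:
        if posterior_means is None:
            raise ValueError("use_vi=True requires posterior_means")
        mu = posterior_means.mean(axis=0)
        sigma = posterior_means.std(axis=0)
    else:
        mu, sigma = np.zeros(dim), np.ones(dim)
    # Reparameterized draw: mu + sigma * z with z ~ N(0, I).
    return mu + sigma * rng.standard_normal((n, dim))


# Hypothetical parameter values; keys follow the params dict described above.
params = {"N_alpha": 100, "N_beta": 25, "K": 2, "L": 8, "M": 3}
rng = np.random.default_rng(0)
# Stand-in for inferred non-causal latent means (not real encoder output).
beta_post_means = rng.normal(1.0, 0.5, size=(500, params["L"]))

alpha = sample_latents(params["N_alpha"], params["K"], use_vi=False, rng=rng)  # alpha_vi=False
beta = sample_latents(params["N_beta"], params["L"], beta_post_means, use_vi=True, rng=rng)  # beta_vi=True

# Pair every alpha sample with every beta sample, one (alpha, beta) pair per row.
# The [alpha, beta] column order is an assumption of this sketch.
z = np.concatenate(
    [np.repeat(alpha, params["N_beta"], axis=0), np.tile(beta, (params["N_alpha"], 1))],
    axis=1,
)
```

Sampling α from the prior by default while taking β from its fitted posterior keeps the non-causal latents near the data the model was trained on, while α explores its full prior range.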

Returns:

  • neg_causal_effect (torch.Tensor) – Scalar tensor containing the estimated -I(α; Ŷ).

  • info (None) – Placeholder kept for compatibility; always None.