cautrigger.causaleffect.joint_uncond_single_dim_v2
- cautrigger.causaleffect.joint_uncond_single_dim_v2(params, model, data, index, dim, alpha_vi=False, beta_vi=True, eps=1e-06, device=None)
Estimate -I(z_i; Ŷ) for a single latent dimension i via conditional Monte Carlo sampling.
This function computes the mutual information between one specific latent dimension z_i (where i = dim) and the model's output Ŷ, marginalizing over all other latent dimensions. It is useful for per-dimension causal attribution.
- The procedure:
1. Sample z_i independently N_alpha times.
2. For each z_i, sample the remaining (z_dim - 1) dimensions N_beta times.
3. Estimate p(Ŷ | z_i) and p(Ŷ) to compute I(z_i; Ŷ).
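The three steps above can be illustrated with a small, framework-free Monte Carlo estimator. This is only a sketch under stated assumptions, not the library's implementation: `toy_classifier` and `single_dim_mi` are hypothetical names, all latent dimensions are drawn from a standard normal (roughly the `alpha_vi=False`, `beta_vi=False` setting), and probabilities are clamped with `max(p, eps)` before the logarithm, which may differ from how the real function applies `eps`. The estimate uses the identity I(z_i; Ŷ) = E_{z_i}[Σ_y p(y | z_i) log p(y | z_i)] - Σ_y p(y) log p(y).

```python
import math
import random


def toy_classifier(z):
    """Hypothetical stand-in for decoder + classifier: returns p(y | z) for 2 classes.

    Only z[0] influences the output, so z[0] should carry all the information.
    """
    p1 = 1.0 / (1.0 + math.exp(-4.0 * z[0]))
    return [1.0 - p1, p1]


def single_dim_mi(classifier, z_dim, dim, n_alpha=200, n_beta=50, eps=1e-6, seed=0):
    """Monte Carlo estimate of I(z_dim; Y) in nats, with every latent ~ N(0, 1)."""
    rng = random.Random(seed)
    n_classes = len(classifier([0.0] * z_dim))
    neg_cond_entropy = 0.0        # estimate of E_{z_i}[sum_y p(y|z_i) log p(y|z_i)]
    marginal = [0.0] * n_classes  # estimate of p(y)

    for _ in range(n_alpha):
        z_i = rng.gauss(0.0, 1.0)  # step 1: draw the target dimension
        p_cond = [0.0] * n_classes
        for _ in range(n_beta):    # step 2: redraw the other dimensions
            z = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]
            z[dim] = z_i           # hold the target dimension fixed
            for y, p in enumerate(classifier(z)):
                p_cond[y] += p / n_beta  # average gives p(y | z_i)
        # step 3: accumulate the conditional term and the marginal p(y)
        neg_cond_entropy += sum(p * math.log(max(p, eps)) for p in p_cond) / n_alpha
        for y in range(n_classes):
            marginal[y] += p_cond[y] / n_alpha

    neg_marginal_entropy = sum(q * math.log(max(q, eps)) for q in marginal)
    return neg_cond_entropy - neg_marginal_entropy  # H(Y) - H(Y | z_i)


if __name__ == "__main__":
    for d in range(3):
        print(d, single_dim_mi(toy_classifier, z_dim=3, dim=d))
```

In this toy setup the estimate for dimension 0 should be clearly larger than for the other dimensions, whose estimates stay near zero apart from a small positive bias caused by the finite N_beta average.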
- Parameters:
params (dict) – Dictionary containing:
  - 'N_alpha' (int): Samples for the target latent dimension.
  - 'N_beta' (int): Samples for all other latent dimensions per fixed z_i.
  - 'K', 'L' (int): Causal and spurious latent dimensions (used to infer total z_dim = K + L).
  - 'M' (int): Number of output classes.
model (torch.nn.Module) – Same architecture assumptions as in joint_uncond_v2.
data (torch.Tensor) – Input data tensor.
index (int) – Index of the sample in data to analyze.
dim (int) – Zero-indexed latent dimension to evaluate (0 ≤ dim < K + L).
alpha_vi (bool, optional) – If True, sample the target dimension from its empirical posterior mean/std; else use standard normal. Default is False.
beta_vi (bool, optional) – If True, sample all latent dimensions (including non-target) from their joint empirical posterior; else use standard normal for background dimensions. Default is True.
eps (float, optional) – Clamping epsilon for log-probabilities; the value is slightly larger here for single-dimension stability. Default is 1e-6.
device (torch.device or str, optional) – Computation device.
- Returns:
-mutual_info_estimate (Tensor) – Scalar tensor representing I(z_dim; Ŷ) (note: not negated, unlike the other two functions).
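For per-dimension attribution, the function is typically called once per latent dimension. The helper below is a minimal sketch of that loop, not part of cautrigger: `rank_latent_dims` and `estimate_fn` are hypothetical names, and the estimator is passed in as a callable so the helper depends only on the call signature documented above. It assumes the returned value converts to a Python float and, following the Returns note, that larger values mean more information; if the value turns out to be negated, reverse the sort order.

```python
def rank_latent_dims(estimate_fn, params, model, data, index, **kwargs):
    """Score every latent dimension and return (dim, score) pairs, highest first.

    estimate_fn is expected to follow the documented call signature
    (params, model, data, index, dim, **optional keyword arguments).
    """
    z_dim = params["K"] + params["L"]  # total latent size, as described for params
    scores = []
    for dim in range(z_dim):
        value = estimate_fn(params, model, data, index, dim, **kwargs)
        scores.append((dim, float(value)))  # float() also accepts a scalar tensor
    return sorted(scores, key=lambda pair: pair[1], reverse=True)
```

With the package installed, the call would look like `rank_latent_dims(joint_uncond_single_dim_v2, params, model, data, index=0, device="cpu")`, where `params` holds the keys 'N_alpha', 'N_beta', 'K', 'L', and 'M' listed under Parameters. Each call runs N_alpha × N_beta forward passes, so the full ranking costs (K + L) times that.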