cautrigger.model.CauTrigger1L

class cautrigger.model.CauTrigger1L(adata, n_latent=10, n_causal=2, n_state=2, **model_kwargs)

First-layer causal generative model for perturbation-response modelling. CauTrigger1L wraps a single-layer DualVAE1L module and provides training and interpretation utilities for modelling how feature signals causally affect state. It is designed for scenarios where all features sit in a single layer (first-layer decomposition).

Parameters:
  • adata (AnnData) – Annotated data matrix containing upstream features in adata.X and any downstream or auxiliary arrays in adata.obsm as required by the module.

  • n_latent (int (default: 10)) – Dimensionality of the latent space.

  • n_causal (int (default: 2)) – Number of causal latent factors.

  • n_state (int (default: 2)) – Number of discrete states the model may represent.

  • **model_kwargs – Additional keyword arguments forwarded to the underlying DualVAE1L module.

Methods table

compute_information_flow([adata, dims, ...])

Compute information flow for latent dimensions.

get_model_output([adata, batch_size])

Obtain model predictions and latent representations for a given dataset.

get_up_feature_weights([adata, method, ...])

Compute and return feature importance weights for the upstream feature mapper.

get_up_significance([adata, method, ...])

Compute significance of upstream feature weights: a binomial sign-consistency test for Grad; a binomial sign-consistency test or a permutation test for SHAP.

plot_train_losses([fig_size])

Plot training loss curves for all recorded losses during training.

pretrain_attention([prior_probs, ...])

Pretrain attention network.

train([max_epochs, lr, use_gpu, train_size, ...])

Train the model using a fractal variational autoencoder.

Methods

CauTrigger1L.compute_information_flow(adata=None, dims=None, zero_floor=False, plot_info_flow=True, skip_single_info=True, save_fig=False, save_dir=None)

Compute information flow for latent dimensions.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData object with input data

  • dims (Optional[List[int]] (default: None)) – Dimensions to compute information flow for

  • zero_floor (bool (default: False)) – Whether to subtract minimum value

  • plot_info_flow (Optional[bool] (default: True)) – Whether to plot information flow

  • skip_single_info (Optional[bool] (default: True)) – Whether to skip single dimension plots

  • save_fig (Optional[bool] (default: False)) – Whether to save figures

  • save_dir (Optional[str] (default: None)) – Directory to save figures

Returns:

info_flow

Information flow for each dimension

info_flow_cat

Categorical information flow (causal vs spurious)

CauTrigger1L.get_model_output(adata=None, batch_size=None)

Obtain model predictions and latent representations for a given dataset.

This method runs the trained model in evaluation mode and returns:

  • concatenated latent embeddings from two latent spaces,

  • logits and predicted probabilities from the downstream classifier,

  • binary class predictions (thresholded at 0.5).

If no adata is provided, the method uses the internal self.adata.

Parameters:
  • adata (AnnData, optional) – Annotated data matrix to generate outputs for. If None, defaults to self.adata. Default is None.

  • batch_size (int, optional) – Number of samples per batch during inference. If None, uses self.batch_size. Default is None.

Returns:

output (dict)

Dictionary containing the following keys:

  • 'latent': numpy.ndarray of shape (n_samples, n_latent1 + n_latent2), concatenated latent vectors from both latent modules.

  • 'logits': numpy.ndarray of shape (n_samples,) or (n_samples, n_classes), raw classifier logits.

  • 'probs': numpy.ndarray of same shape as 'logits', predicted probabilities after sigmoid/softmax activation.

  • 'preds': numpy.ndarray of shape (n_samples,), binary predictions (1 if probability > 0.5, else 0).
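The relationship between the 'logits', 'probs' and 'preds' entries can be illustrated with a minimal, library-independent sketch. It assumes the binary case with a sigmoid activation; the helper names `sigmoid` and `binary_outputs` are hypothetical and are not part of CauTrigger1L.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def binary_outputs(logits):
    """Map raw logits to probabilities and 0/1 predictions.

    Hypothetical helper: mirrors the documented rule
    "1 if probability > 0.5, else 0" for the binary case only.
    """
    probs = [sigmoid(v) for v in logits]
    preds = [1 if p > 0.5 else 0 for p in probs]
    return {"logits": list(logits), "probs": probs, "preds": preds}


out = binary_outputs([-2.0, 0.0, 3.0])
print(out["preds"])  # [0, 0, 1]
```

Note that a logit of exactly 0 gives a probability of 0.5, which is not strictly greater than the threshold and therefore maps to class 0.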

CauTrigger1L.get_up_feature_weights(adata=None, method='SHAP', n_bg_samples=100, grad_source='prob', normalize=True, sort_by_weight=True, class_idx=None, background_data=None, return_background=False)

Compute and return feature importance weights for the upstream feature mapper.

This method supports multiple strategies to estimate feature contributions:

  • “Model”: uses internal attention or learned weights from the model.

  • “SHAP”: computes SHAP values using DeepExplainer.

  • “Grad”: computes input gradients w.r.t. a specified output (e.g., probability or logit).

  • “Ensemble”: averages normalized absolute weights from all three methods above.

The resulting weights are aggregated across samples (by mean), optionally normalized, and returned both as a sorted DataFrame aligned with self.adata.var and as a full sample-by-feature weight matrix.

Parameters:
  • adata (AnnData, optional) – AnnData object containing the data to compute feature weights on.

  • method (str, optional) – Method to compute feature weights. One of {“Model”, “SHAP”, “Grad”, “Ensemble”}. Default is “SHAP”.

  • n_bg_samples (int, optional) – Number of background samples used for SHAP explanation. Only relevant if method="SHAP". Default is 100.

  • grad_source (str, optional) – Target output for gradient computation when method="Grad". One of “prob” (gradients w.r.t. predicted probabilities), “logit” (gradients w.r.t. logits), or “loss” (gradients w.r.t. the DPD loss). Default is “prob”.

  • normalize (bool, optional) – Whether to normalize the final feature weights to sum to 1. Default is True.

  • sort_by_weight (bool, optional) – Whether to sort the returned DataFrame by weight in descending order. Default is True.

  • class_idx (int, optional) – Class index for which to compute SHAP values (only used when labels are present and method="SHAP"). If None, SHAP values are averaged over all classes or computed on the full dataset. Default is None.

Returns:

weights_df (DataFrame)

DataFrame with the same index as self.adata.var, containing a new column 'weight' with the computed feature importance scores. Sorted by weight if sort_by_weight=True.

weights_full (ndarray)

Full sample-by-feature matrix of absolute weights before aggregation. Shape: (n_samples, n_features).
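The aggregation described above (mean over samples, optional normalization to sum to 1, optional descending sort) and the "Ensemble" averaging can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: `aggregate_weights` and `ensemble` are hypothetical helpers, and the toy matrix stands in for weights_full.

```python
def aggregate_weights(weights_full, feature_names, normalize=True, sort_by_weight=True):
    """Collapse a sample-by-feature matrix into one weight per feature."""
    n_samples = len(weights_full)
    # Mean absolute weight per feature (column-wise mean over samples).
    means = [
        sum(abs(row[j]) for row in weights_full) / n_samples
        for j in range(len(feature_names))
    ]
    if normalize:
        total = sum(means)
        if total > 0:
            means = [m / total for m in means]  # weights now sum to 1
    table = list(zip(feature_names, means))
    if sort_by_weight:
        table.sort(key=lambda item: item[1], reverse=True)
    return table


def ensemble(weight_vectors):
    """Average several per-feature weight vectors after normalizing each."""
    normalized = []
    for w in weight_vectors:
        abs_w = [abs(v) for v in w]
        total = sum(abs_w) or 1.0  # avoid division by zero
        normalized.append([v / total for v in abs_w])
    return [sum(col) / len(normalized) for col in zip(*normalized)]


weights_full = [
    [0.2, -0.6, 0.2],
    [0.4, -0.2, 0.0],
]
table = aggregate_weights(weights_full, ["geneA", "geneB", "geneC"])
for name, weight in table:
    print(f"{name}: {weight:.3f}")

combined = ensemble([[1.0, -3.0], [2.0, 2.0]])
```

Normalizing each method's vector before averaging keeps one method with a larger raw scale from dominating the ensemble.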

CauTrigger1L.get_up_significance(adata=None, method='SHAP', test_mode='permutation', perm_mode='global', n_perm=100, n_bg_samples=100, grad_source='prob', normalize=False, class_idx=None, target_genes=None, fdr_correct=True, use_signed=True, show_progress=True, random_state=42)

Compute the significance of upstream feature weights. The test depends on the attribution method:

  • Grad → binomial sign-consistency test

  • SHAP → binomial sign-consistency test or permutation test

Parameters:
  • adata (AnnData, optional) – Input AnnData object (default: self.adata)

  • method (str) – “SHAP” or “Grad”

  • test_mode (str) – For SHAP: “sign_test” or “permutation” (ignored for Grad)

  • perm_mode (str) – “global” or “per_feature” shuffle strategy (for permutation test)

  • n_perm (int) – Number of permutations (for permutation test)

  • n_bg_samples (int) – Number of background samples for SHAP

  • grad_source (str) – Source for gradient-based attribution (“prob”, “logit”, or “loss”)

  • normalize (bool) – Whether to normalize weights across features

  • class_idx (int, optional) – Target class index for class-specific analysis

  • target_genes (list of str, optional) – Genes/features to test; if None, use all

  • fdr_correct (bool) – Apply Benjamini–Hochberg correction

  • use_signed (bool) – Whether to use signed weights for p-value calculation

  • show_progress (bool) – Display progress bar

  • random_state (int, optional) – Random seed

Returns:

df_result (DataFrame)

DataFrame with columns ‘weight’, ‘weight_signed’, ‘pvalue’ and, optionally, ‘qvalue’.

perm_matrix (np.ndarray or None)

Permutation matrix for SHAP (None for Grad or sign_test).
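The statistical ideas behind these options can be sketched without the library. The block below is a hedged illustration, not the method's actual code: it assumes the sign-consistency test is a two-sided binomial test (p = 0.5) on the signs of per-sample attributions, uses the common add-one estimator for the permutation p-value, and applies the standard Benjamini–Hochberg step-up procedure. All function names are hypothetical.

```python
import math


def sign_test_pvalue(values):
    """Two-sided binomial test that positive and negative signs are equally likely."""
    nonzero = [v for v in values if v != 0]
    n = len(nonzero)
    if n == 0:
        return 1.0
    k = sum(1 for v in nonzero if v > 0)
    m = max(k, n - k)
    # Upper tail P(X >= m) for X ~ Binomial(n, 0.5), doubled for two sides.
    tail = sum(math.comb(n, i) for i in range(m, n + 1)) / 2 ** n
    return min(1.0, 2.0 * tail)


def permutation_pvalue(observed, null_values):
    """Add-one permutation p-value based on absolute magnitude."""
    exceed = sum(1 for v in null_values if abs(v) >= abs(observed))
    return (1 + exceed) / (1 + len(null_values))


def benjamini_hochberg(pvalues):
    """Return BH-adjusted q-values in the original order."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])
    qvalues = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down, enforcing monotone q-values.
    for rank in range(m, 0, -1):
        idx = order[rank - 1]
        running_min = min(running_min, pvalues[idx] * m / rank)
        qvalues[idx] = running_min
    return qvalues


# Toy per-sample attributions: geneA is sign-consistent, geneB is not.
attributions = {
    "geneA": [0.3, 0.1, 0.2, 0.4, 0.1, 0.2, 0.5, 0.3, 0.2, 0.1],
    "geneB": [0.1, -0.1, 0.2, -0.2, 0.1, -0.1, 0.3, -0.3, 0.1, -0.1],
}
pvals = [sign_test_pvalue(v) for v in attributions.values()]
qvals = benjamini_hochberg(pvals)
for gene, p, q in zip(attributions, pvals, qvals):
    print(f"{gene}: p={p:.4f}, q={q:.4f}")
```

With only n_perm permutations, the add-one estimator cannot return a p-value smaller than 1 / (n_perm + 1), so the default n_perm=100 bounds permutation p-values below at roughly 0.0099.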

CauTrigger1L.plot_train_losses(fig_size=(8, 8))

Plot training loss curves for all recorded losses during training.

This method visualizes the evolution of each loss component over epochs using subplots. It requires that the model has been trained and that training history is available in the self.history attribute.

Parameters:

fig_size (tuple of int, optional) – Figure size (width, height) in inches. Default is (8, 8).

CauTrigger1L.pretrain_attention(prior_probs=None, max_epochs=50, pretrain_lr=0.001, batch_size=128, use_gpu=None, train_size=1.0, validation_size=None)

Pretrain attention network.

Parameters:
  • prior_probs (Optional[ndarray] (default: None)) – Prior probabilities for attention weights

  • max_epochs (Optional[int] (default: 50)) – Maximum number of pretraining epochs

  • pretrain_lr (float (default: 0.001)) – Learning rate for pretraining

  • batch_size (int (default: 128)) – Number of samples per batch

  • use_gpu (Union[str, int, bool, None] (default: None)) – Whether to use GPU for pretraining

  • train_size (float (default: 1.0)) – Proportion of data to use for pretraining

  • validation_size (Optional[float] (default: None)) – Proportion of data to use for validation

CauTrigger1L.train(max_epochs=400, lr=0.0005, use_gpu=None, train_size=1.0, validation_size=None, batch_size=128, early_stopping=False, weight_decay=1e-06, n_x=5, n_alpha=25, n_beta=100, recons_weight=1.0, kl_weight=0.02, up_weight=1.0, down_weight=1.0, feat_l1_weight=0.05, dpd_weight=3.0, fide_kl_weight=0.05, causal_weight=1.0, down_fold=1.0, causal_fold=1.0, spurious_fold=1.0, stage_training=True, weight_scheme=None, im_factor=None, drop_last=10, **kwargs)

Train the model using a fractal variational autoencoder.

Parameters:
  • max_epochs (Optional[int] (default: 400)) – Maximum number of training epochs

  • lr (float (default: 0.0005)) – Learning rate for optimizer

  • use_gpu (Union[str, int, bool, None] (default: None)) – Whether to use GPU for training

  • train_size (float (default: 1.0)) – Proportion of data to use for training

  • validation_size (Optional[float] (default: None)) – Proportion of data to use for validation

  • batch_size (int (default: 128)) – Number of samples per batch

  • early_stopping (bool (default: False)) – Whether to use early stopping

  • weight_decay (float (default: 1e-06)) – Weight decay for optimizer

  • n_x (int (default: 5)) – Number of samples for causal effect computation

  • n_alpha (int (default: 25)) – Monte-carlo samples per causal factor

  • n_beta (int (default: 100)) – Monte-carlo samples per noncausal factor

  • recons_weight (float (default: 1.0)) – Weight for reconstruction loss

  • kl_weight (float (default: 0.02)) – Weight for KL divergence loss

  • up_weight (float (default: 1.0)) – Weight for upstream reconstruction

  • down_weight (float (default: 1.0)) – Weight for downstream reconstruction

  • feat_l1_weight (float (default: 0.05)) – Weight for feature L1 loss

  • dpd_weight (float (default: 3.0)) – Weight for DPD loss

  • fide_kl_weight (float (default: 0.05)) – Weight for fidelity KL loss

  • causal_weight (float (default: 1.0)) – Weight for causal loss

  • down_fold (float (default: 1.0)) – Downstream loss scaling factor

  • causal_fold (float (default: 1.0)) – Causal loss scaling factor

  • spurious_fold (float (default: 1.0)) – Spurious loss scaling factor

  • stage_training (bool (default: True)) – Whether to use staged training

  • weight_scheme (Optional[str] (default: None)) – Weight update scheme

  • im_factor (Optional[float] (default: None)) – Imbalance factor for loss computation

  • drop_last (int (default: 10)) – Number of samples to drop from last batch

  • **kwargs – Additional arguments