cautrigger.model.CauTrigger3L#

class cautrigger.model.CauTrigger3L(adata, n_latent=10, n_causal=2, n_state=2, **model_kwargs)#

Third-layer hierarchical causal decomposition model. CauTrigger3L wraps a DualVAE3L module and supports third-stage causal decomposition (e.g. x3 → xc2 → xc1 → y). It is intended for modelling complex cascades in which effects propagate through multiple intermediate regulatory layers, such as multi-omic cascades.

Parameters:
  • adata (AnnData) – Annotated data matrix. Expects upstream features in adata.X and downstream representations in adata.obsm['X_down1'] and adata.obsm['X_down2'].

  • n_latent (int (default: 10)) – Dimensionality of the latent space.

  • n_causal (int (default: 2)) – Number of causal latent factors.

  • n_state (int (default: 2)) – Number of discrete states.

  • **model_kwargs – Additional keyword arguments passed to DualVAE3L.
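The adata description above names two obsm keys that must be present. A minimal pre-flight check, sketched under the assumption that you have the array shapes at hand before or after building the AnnData object, could look like this. The helper name check_layout is hypothetical and is not part of cautrigger.

```python
# Keys named in the adata parameter description above.
REQUIRED_OBSM_KEYS = ("X_down1", "X_down2")


def check_layout(n_obs, obsm_shapes):
    """Return a list of problems; an empty list means the layout looks right.

    n_obs       -- number of observations (rows of adata.X)
    obsm_shapes -- mapping from obsm key to the shape tuple of that array
    """
    problems = []
    for key in REQUIRED_OBSM_KEYS:
        if key not in obsm_shapes:
            problems.append(f"missing obsm[{key!r}]")
        elif obsm_shapes[key][0] != n_obs:
            rows = obsm_shapes[key][0]
            problems.append(f"obsm[{key!r}] has {rows} rows, expected {n_obs}")
    return problems


# Hypothetical shapes: 100 cells, 50 features in the first downstream layer.
print(check_layout(100, {"X_down1": (100, 50), "X_down2": (100, 1)}))
```

With an existing AnnData object, the same check can be fed with check_layout(adata.n_obs, {k: v.shape for k, v in adata.obsm.items()}) before passing adata to CauTrigger3L.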

Methods table#

compute_information_flow([adata, dims, ...])

Compute information flow for latent dimensions.

get_3to1_ig([adata, key, celltype, baseline])

Compute IG values from TFs to TGs.

get_3to2_ig([adata, key, celltype, baseline])

Compute IG values from TFs to REs.

get_3to2_shap([adata, n_bg_samples, key, ...])

Compute SHAP values from TFs to REs using either DeepExplainer or GradientExplainer.

get_model_output([adata, batch_size])

Obtain model predictions and latent representations for a given dataset.

get_up_feature_weights([method, ...])

Compute weights for the upstream features.

get_up_significance(*args, **kwargs)

plot_train_losses([fig_size])

Plot training loss curves for all recorded losses during training.

pretrain_attention([prior_probs, ...])

Pretrain attention network.

train([max_epochs, lr, use_gpu, train_size, ...])

Trains the model using fractal variational autoencoder.

Methods#

CauTrigger3L.compute_information_flow(adata=None, dims=None, zero_floor=False, plot_info_flow=True, skip_single_info=True, save_fig=False, save_dir=None)#

Compute information flow for latent dimensions.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData object with input data

  • dims (Optional[List[int]] (default: None)) – Dimensions to compute information flow for

  • zero_floor (bool (default: False)) – Whether to subtract minimum value

  • plot_info_flow (Optional[bool] (default: True)) – Whether to plot information flow

  • skip_single_info (Optional[bool] (default: True)) – Whether to skip single dimension plots

  • save_fig (Optional[bool] (default: False)) – Whether to save figures

  • save_dir (Optional[str] (default: None)) – Directory to save figures

Returns:

info_flow

Information flow for each dimension

info_flow_cat

Categorical information flow (causal vs spurious)
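The zero_floor option is described only as subtracting the minimum value. One plausible reading, shown here as a sketch and not as the library's actual implementation, is to shift the per-dimension values so that the smallest becomes zero. The function name apply_zero_floor is hypothetical.

```python
def apply_zero_floor(info_flow):
    """Shift values so the smallest one becomes 0.

    This is one reading of zero_floor=True; the library may floor
    per group or per plot instead.
    """
    if not info_flow:
        return []
    floor = min(info_flow)
    return [value - floor for value in info_flow]


# Hypothetical information-flow values for three latent dimensions.
shifted = apply_zero_floor([0.5, 0.25, 1.0])
print(shifted)
```

Relative differences between dimensions are unchanged by this shift; only the baseline moves.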

CauTrigger3L.get_3to1_ig(adata=None, key='prob', celltype=None, baseline=None)#

Compute integrated gradients (IG) attributions from transcription factors (TFs) to target genes (TGs).
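For readers unfamiliar with IG, the following is a small, dependency-free sketch of integrated gradients on a toy function. It is not the code CauTrigger3L runs: the toy model, the all-zero default baseline, the step count, and the finite-difference gradient are all assumptions made for illustration.

```python
def numerical_gradient(f, point, eps=1e-6):
    """Central-difference gradient of f at point (a list of floats)."""
    grad = []
    for i in range(len(point)):
        up, down = list(point), list(point)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def integrated_gradients(f, x, baseline=None, steps=50):
    """Approximate IG with a midpoint Riemann sum along the straight path."""
    if baseline is None:
        baseline = [0.0] * len(x)  # assumed default for this sketch
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = numerical_gradient(f, point)
        total = [t + g for t, g in zip(total, grad)]
    # Scale the averaged gradient by the input's distance from the baseline.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


def toy_model(v):
    """Stand-in for a trained network: one interaction term, one linear term."""
    return v[0] * v[1] + 2.0 * v[2]


attributions = integrated_gradients(toy_model, [1.0, 2.0, 3.0])
print(attributions)
```

A useful sanity check is the completeness property: the attributions sum to the difference between the model output at the input and at the baseline.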

CauTrigger3L.get_3to2_ig(adata=None, key='prob', celltype=None, baseline=None)#

Compute integrated gradients (IG) attributions from transcription factors (TFs) to regulatory elements (REs).

CauTrigger3L.get_3to2_shap(adata=None, n_bg_samples=5, key='prob', celltype=None, explainer_type='gradient', data_dir=None)#

Compute SHAP values from TFs to REs using either DeepExplainer or GradientExplainer.

CauTrigger3L.get_model_output(adata=None, batch_size=None)#

Obtain model predictions and latent representations for a given dataset.

This method runs the trained model in evaluation mode and returns:

  • Concatenated latent embeddings from two latent spaces,

  • Logits and predicted probabilities from the downstream classifier,

  • Binary class predictions (thresholded at 0.5).

If no adata is provided, the method uses the internal self.adata.

Parameters:
  • adata (AnnData, optional) – Annotated data matrix to generate outputs for. If None, defaults to self.adata. Default is None.

  • batch_size (int, optional) – Number of samples per batch during inference. If None, uses self.batch_size. Default is None.

Returns:

output (dict) – Dictionary containing the following keys:

  • 'latent': numpy.ndarray of shape (n_samples, n_latent1 + n_latent2), concatenated latent vectors from both latent modules.

  • 'logits': numpy.ndarray of shape (n_samples,) or (n_samples, n_classes), raw classifier logits.

  • 'probs': numpy.ndarray of same shape as 'logits', predicted probabilities after sigmoid/softmax activation.

  • 'preds': numpy.ndarray of shape (n_samples,), binary predictions (1 if probability > 0.5, else 0).
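The relationship between the 'logits', 'probs' and 'preds' entries in the binary case can be sketched in plain Python. This mirrors the description above (sigmoid followed by a strict 0.5 threshold) but is an illustration only; the helper names are hypothetical and the library operates on arrays, not lists.

```python
import math


def sigmoid(logit):
    """Numerically stable logistic function."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def logits_to_output(logits, threshold=0.5):
    """Build the 'logits', 'probs' and 'preds' entries for binary logits."""
    probs = [sigmoid(v) for v in logits]
    # Strict inequality: a probability of exactly 0.5 maps to class 0.
    preds = [1 if p > threshold else 0 for p in probs]
    return {"logits": list(logits), "probs": probs, "preds": preds}


out = logits_to_output([-2.0, 0.0, 3.0])
print(out["preds"])
```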

CauTrigger3L.get_up_feature_weights(method='SHAP', n_bg_samples=100, grad_source='prob', normalize=True, sort_by_weight=False, class_idx=None)#

Compute weights for the upstream features.

Parameters:
  • method (default: 'SHAP') – Attribution method used to compute the weights

  • n_bg_samples (default: 100) – Number of background samples

  • grad_source (default: 'prob') – Model output used as the gradient source

  • normalize (default: True) – Whether to normalize the weights

  • sort_by_weight (default: False) – Whether to sort features by weight

  • class_idx (default: None) – Class index to compute weights for
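The normalize and sort_by_weight flags in the get_up_feature_weights signature suggest a post-processing step on the raw weights. The sketch below shows one common convention (absolute values scaled to sum to 1, then sorted in descending order). This convention, the function name, and the feature names are assumptions; the library may normalize differently.

```python
def normalize_and_sort(weights, normalize=True, sort_by_weight=False):
    """Post-process raw feature weights given as {feature_name: weight}.

    Returns a list of (feature_name, weight) pairs.
    """
    items = [(name, abs(w)) for name, w in weights.items()]
    if normalize:
        total = sum(w for _, w in items)
        if total > 0:
            # Assumed convention: weights sum to 1 after normalization.
            items = [(name, w / total) for name, w in items]
    if sort_by_weight:
        items = sorted(items, key=lambda pair: pair[1], reverse=True)
    return items


# Hypothetical raw weights for two upstream features.
ranked = normalize_and_sort({"tf_a": 1.0, "tf_b": 3.0}, sort_by_weight=True)
print(ranked)
```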

CauTrigger3L.get_up_significance(*args, **kwargs)#

CauTrigger3L.plot_train_losses(fig_size=(8, 8))#

Plot training loss curves for all recorded losses during training.

This method visualizes the evolution of each loss component over epochs using subplots. It requires that the model has been trained and that training history is available in the self.history attribute.

Parameters:

fig_size (tuple of int, optional) – Figure size (width, height) in inches. Default is (8, 8).

CauTrigger3L.pretrain_attention(prior_probs=None, max_epochs=50, pretrain_lr=0.001, batch_size=128, use_gpu=None, train_size=1.0, validation_size=None)#

Pretrain attention network.

Parameters:
  • prior_probs (Optional[ndarray] (default: None)) – Prior probabilities for attention weights

  • max_epochs (Optional[int] (default: 50)) – Maximum number of pretraining epochs

  • pretrain_lr (float (default: 0.001)) – Learning rate for pretraining

  • batch_size (int (default: 128)) – Number of samples per batch

  • use_gpu (Union[str, int, bool, None] (default: None)) – Whether to use GPU for pretraining

  • train_size (float (default: 1.0)) – Proportion of data to use for pretraining

  • validation_size (Optional[float] (default: None)) – Proportion of data to use for validation

CauTrigger3L.train(max_epochs=400, lr=0.0005, use_gpu=None, train_size=1.0, validation_size=None, batch_size=128, early_stopping=False, weight_decay=1e-06, n_x=5, n_alpha=25, n_beta=100, recons_weight=1.0, kl_weight=0.02, up_weight=1.0, down_weight=1.0, feat_l1_weight=0.05, dpd_weight=3.0, fide_kl_weight=0.05, causal_weight=1.0, down_fold=1.0, causal_fold=1.0, spurious_fold=1.0, stage_training=True, weight_scheme=None, im_factor=None, drop_last=False, **kwargs)#

Trains the model using fractal variational autoencoder.

Parameters:
  • max_epochs (Optional[int] (default: 400)) – Maximum number of training epochs

  • lr (float (default: 0.0005)) – Learning rate for optimizer

  • use_gpu (Union[str, int, bool, None] (default: None)) – Whether to use GPU for training

  • train_size (float (default: 1.0)) – Proportion of data to use for training

  • validation_size (Optional[float] (default: None)) – Proportion of data to use for validation

  • batch_size (int (default: 128)) – Number of samples per batch

  • early_stopping (bool (default: False)) – Whether to use early stopping

  • weight_decay (float (default: 1e-06)) – Weight decay for optimizer

  • n_x (int (default: 5)) – Number of samples for causal effect computation

  • n_alpha (int (default: 25)) – Monte Carlo samples per causal factor

  • n_beta (int (default: 100)) – Monte Carlo samples per non-causal factor

  • recons_weight (float (default: 1.0)) – Weight for reconstruction loss

  • kl_weight (float (default: 0.02)) – Weight for KL divergence loss

  • up_weight (float (default: 1.0)) – Weight for upstream reconstruction

  • down_weight (float (default: 1.0)) – Weight for downstream reconstruction

  • feat_l1_weight (float (default: 0.05)) – Weight for feature L1 loss

  • dpd_weight (float (default: 3.0)) – Weight for DPD loss

  • fide_kl_weight (float (default: 0.05)) – Weight for fidelity KL loss

  • causal_weight (float (default: 1.0)) – Weight for causal loss

  • down_fold (float (default: 1.0)) – Downstream loss scaling factor

  • causal_fold (float (default: 1.0)) – Causal loss scaling factor

  • spurious_fold (float (default: 1.0)) – Spurious loss scaling factor

  • stage_training (bool (default: True)) – Whether to use staged training

  • weight_scheme (Optional[str] (default: None)) – Weight update scheme

  • im_factor (Optional[float] (default: None)) – Imbalance factor for loss computation

  • drop_last (Union[bool, int] (default: False)) – Number of samples to drop from last batch

  • **kwargs – Additional arguments
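With this many keyword arguments, it can help to keep training settings in one place. The sketch below copies a subset of the defaults from the train signature above and merges overrides into them. The helper make_train_kwargs is hypothetical, and it is deliberately stricter than train itself (which also accepts **kwargs) so that misspelled names are caught early.

```python
# Subset of defaults copied from the documented train() signature.
TRAIN_DEFAULTS = {
    "max_epochs": 400,
    "lr": 0.0005,
    "batch_size": 128,
    "early_stopping": False,
    "weight_decay": 1e-06,
    "recons_weight": 1.0,
    "kl_weight": 0.02,
    "feat_l1_weight": 0.05,
    "dpd_weight": 3.0,
    "fide_kl_weight": 0.05,
    "causal_weight": 1.0,
    "stage_training": True,
}


def make_train_kwargs(**overrides):
    """Merge overrides into the defaults, rejecting names not listed above."""
    unknown = sorted(set(overrides) - set(TRAIN_DEFAULTS))
    if unknown:
        raise KeyError(f"not a default listed in this sketch: {unknown}")
    return {**TRAIN_DEFAULTS, **overrides}


# A shorter run with early stopping enabled; other values stay at defaults.
quick_run = make_train_kwargs(max_epochs=50, early_stopping=True)
print(quick_run["max_epochs"], quick_run["lr"])
```

Given a constructed model instance, the resulting dictionary would be passed as model.train(**quick_run).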