cautrigger.model.CauTrigger2L#
- class cautrigger.model.CauTrigger2L(adata, n_latent=10, n_causal=2, n_state=2, **model_kwargs)#
Second-layer hierarchical causal decomposition model. CauTrigger2L wraps a DualVAE2L backbone and supports a second-stage causal decomposition (e.g., upstream -> downstream -> state).
- Parameters:
adata (AnnData) – Annotated data matrix. Expects upstream features in adata.X and the first downstream representation stored in adata.obsm['X_down'].
n_latent (int, default: 10) – Latent dimension.
n_causal (int, default: 2) – Number of causal factors.
n_state (int, default: 2) – Number of discrete states.
**model_kwargs – Forwarded to the DualVAE2L constructor.
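A minimal construction sketch, assuming the cautrigger package is installed and that adata is an AnnData object laid out as described above. The import path and keyword arguments come from the signature on this page; the helper name build_model and its input check are illustrative and not part of the library.

```python
# Constructor arguments copied from the documented signature defaults.
MODEL_KWARGS = {"n_latent": 10, "n_causal": 2, "n_state": 2}


def build_model(adata, **overrides):
    """Hypothetical helper: check the documented input layout, then build the model."""
    # The class expects the first downstream representation in adata.obsm['X_down'].
    if "X_down" not in adata.obsm:
        raise KeyError("adata.obsm['X_down'] is required by CauTrigger2L")
    # Imported lazily so the layout check runs even where the package is missing.
    from cautrigger.model import CauTrigger2L
    return CauTrigger2L(adata, **{**MODEL_KWARGS, **overrides})
```

Calling build_model(adata, n_latent=20) would override a single default while keeping the others.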
Methods table#
| Method | Description |
|---|---|
| compute_information_flow | Compute information flow for latent dimensions. |
| get_2to1_ig | Compute Integrated Gradients (IG) attributions from UP features to DOWN features. |
| get_input2z_ig | Compute Integrated Gradients for input to latent space. |
| get_model_output | Obtain model predictions and latent representations for a given dataset. |
| get_up_feature_weights | Compute and return feature importance weights for the upstream feature mapper. |
| get_up_significance | Compute significance of upstream feature weights (Grad: binomial sign-consistency test; SHAP: binomial sign-consistency test or permutation test). |
| plot_train_losses | Plot training loss curves for all recorded losses during training. |
| pretrain_attention | Pretrain attention network. |
| train | Trains the model using fractal variational autoencoder. |
Methods#
- CauTrigger2L.compute_information_flow(adata=None, dims=None, zero_floor=False, plot_info_flow=True, skip_single_info=True, save_fig=False, save_dir=None)#
Compute information flow for latent dimensions.
- Parameters:
adata (Optional[AnnData], default: None) – AnnData object with input data
dims (Optional[List[int]], default: None) – Dimensions to compute information flow for
zero_floor (bool, default: False) – Whether to subtract the minimum value
plot_info_flow (Optional[bool], default: True) – Whether to plot information flow
skip_single_info (Optional[bool], default: True) – Whether to skip single-dimension plots
save_fig (Optional[bool], default: False) – Whether to save figures
save_dir (Optional[str], default: None) – Directory to save figures
- Returns:
- info_flow
Information flow for each dimension
- info_flow_cat
Categorical information flow (causal vs spurious)
- CauTrigger2L.get_2to1_ig(adata=None, key='prob', celltype=None, baseline=None)#
Compute Integrated Gradients (IG) attributions from UP features to DOWN features. This method calculates feature-wise attribution scores using Integrated Gradients to show how each upstream (UP) feature influences each downstream (DOWN) feature in the model.
- Parameters:
adata (AnnData, optional) – Annotated data matrix to compute attributions for. If None, defaults to self.adata. Default is None.
key (str, optional) – Model output key to attribute to. Typically 'prob' (probability) or 'logit'. Default is 'prob'.
celltype (str or None, optional) – Specific cell type to subset for attribution. If None or 'all', uses all cells. If specified, must exist in adata.obs['celltype']. Default is None.
baseline (torch.Tensor or None, optional) – Baseline input for Integrated Gradients (same shape as input). If None, uses a zero tensor as baseline. Default is None.
- Returns:
ig_scores (ndarray) – Integrated Gradients attribution scores with shape (n_cells, n_up, n_down).
- CauTrigger2L.get_input2z_ig(adata=None, key='prob', baseline=None)#
Compute Integrated Gradients for input to latent space.
- CauTrigger2L.get_model_output(adata=None, batch_size=None)#
Obtain model predictions and latent representations for a given dataset.
This method runs the trained model in evaluation mode and returns:
- Concatenated latent embeddings from two latent spaces,
- Logits and predicted probabilities from the downstream classifier,
- Binary class predictions (thresholded at 0.5).
If no adata is provided, the method uses the internal self.adata.
- Parameters:
adata (AnnData, optional) – Annotated data matrix to generate outputs for. If None, defaults to self.adata. Default is None.
batch_size (int, optional) – Number of samples per batch during inference. If None, uses self.batch_size. Default is None.
- Returns:
output (dict) – Dictionary containing the following keys:
- 'latent': numpy.ndarray of shape (n_samples, n_latent1 + n_latent2), concatenated latent vectors from both latent modules.
- 'logits': numpy.ndarray of shape (n_samples,) or (n_samples, n_classes), raw classifier logits.
- 'probs': numpy.ndarray of same shape as 'logits', predicted probabilities after sigmoid/softmax activation.
- 'preds': numpy.ndarray of shape (n_samples,), binary predictions (1 if probability > 0.5, else 0).
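The relationship between 'logits', 'probs', and 'preds' can be sketched for the binary case as follows. This is a plain-Python illustration that assumes a sigmoid activation; the function name binarize is hypothetical, and the library's internal code is not shown on this page.

```python
import math


def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))


def binarize(logits, threshold=0.5):
    """Map raw logits to probabilities and then to 0/1 predictions."""
    probs = [sigmoid(v) for v in logits]
    # Strict inequality: a probability of exactly 0.5 maps to class 0.
    preds = [1 if p > threshold else 0 for p in probs]
    return probs, preds


logits = [-2.0, 0.0, 3.0]
probs, preds = binarize(logits)
print(preds)  # [0, 0, 1]
```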
- CauTrigger2L.get_up_feature_weights(adata=None, method='SHAP', n_bg_samples=100, grad_source='prob', normalize=True, sort_by_weight=True, class_idx=None, background_data=None, return_background=False)#
Compute and return feature importance weights for the upstream feature mapper.
This method supports multiple strategies to estimate feature contributions:
- “Model”: uses internal attention or learned weights from the model.
- “SHAP”: computes SHAP values using DeepExplainer.
- “Grad”: computes input gradients w.r.t. a specified output (e.g., probability or logit).
- “Ensemble”: averages normalized absolute weights from all three methods above.
The resulting weights are aggregated across samples (by mean), optionally normalized, and returned both as a sorted DataFrame aligned with self.adata.var and as a full sample-by-feature weight matrix.
- Parameters:
method (str, optional) – Method to compute feature weights. One of {“Model”, “SHAP”, “Grad”, “Ensemble”}. Default is “SHAP”.
n_bg_samples (int, optional) – Number of background samples used for SHAP explanation. Only relevant if method="SHAP". Default is 100.
grad_source (str, optional) – Target output for gradient computation when method="Grad". Options are “prob” (gradients w.r.t. predicted probabilities), “logit” (gradients w.r.t. logits), and “loss” (gradients w.r.t. the DPD loss). Default is “prob”.
normalize (bool, optional) – Whether to normalize the final feature weights to sum to 1. Default is True.
sort_by_weight (bool, optional) – Whether to sort the returned DataFrame by weight in descending order. Default is True.
class_idx (int, optional) – Class index for which to compute SHAP values (only used when labels are present and method="SHAP"). If None, SHAP values are averaged over all classes or computed on the full dataset. Default is None.
- Returns:
weights_df (DataFrame) – DataFrame with the same index as self.adata.var, containing a new column 'weight' with the computed feature importance scores. Sorted by weight if sort_by_weight=True.
weights_full (ndarray) – Full sample-by-feature matrix of absolute weights before aggregation. Shape: (n_samples, n_features).
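The aggregation described above (absolute per-sample weights, mean over samples, optional normalization to sum to 1, optional descending sort) might look roughly like the following. This is a standard-library sketch under stated assumptions: the function name aggregate_weights, the feature names, and the list-of-lists input are hypothetical, and the real method returns a DataFrame rather than a list of tuples.

```python
def aggregate_weights(weights, feature_names, normalize=True, sort_by_weight=True):
    """Collapse a sample-by-feature attribution matrix into one weight per feature."""
    n_samples = len(weights)
    # Mean of absolute attributions across samples, one value per feature.
    mean_w = [
        sum(abs(row[j]) for row in weights) / n_samples
        for j in range(len(feature_names))
    ]
    if normalize:
        total = sum(mean_w)
        if total > 0:  # avoid dividing by zero when every weight is 0
            mean_w = [w / total for w in mean_w]
    table = list(zip(feature_names, mean_w))
    if sort_by_weight:
        table.sort(key=lambda item: item[1], reverse=True)
    return table


per_sample = [[1.0, -2.0, 0.0], [2.0, 4.0, 1.0]]  # 2 samples x 3 features
ranked = aggregate_weights(per_sample, ["gene_a", "gene_b", "gene_c"])
```

Here ranked lists gene_b first, followed by gene_a and gene_c, with weights that sum to 1.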
- CauTrigger2L.get_up_significance(adata=None, method='SHAP', test_mode='permutation', perm_mode='global', n_perm=100, n_bg_samples=100, grad_source='prob', normalize=False, class_idx=None, target_genes=None, fdr_correct=True, use_signed=True, show_progress=True, random_state=42)#
Compute significance of upstream feature weights using:
- Grad → Binomial sign-consistency test
- SHAP → Binomial sign-consistency test or permutation test
- Parameters:
adata (AnnData, optional) – Input AnnData object (default: self.adata)
method (str) – “SHAP” or “Grad”
test_mode (str) – For SHAP: “sign_test” or “permutation” (ignored for Grad)
perm_mode (str) – “global” or “per_feature” shuffle strategy (for permutation test)
n_perm (int) – Number of permutations (for permutation test)
n_bg_samples (int) – Number of background samples for SHAP
grad_source (str) – Source for gradient-based attribution (“prob”, “logit”, or “loss”)
normalize (bool) – Whether to normalize weights across features
class_idx (int, optional) – Target class index for class-specific analysis
target_genes (list of str, optional) – Genes/features to test; if None, use all
fdr_correct (bool) – Apply Benjamini–Hochberg correction
use_signed (bool) – Whether to use signed weights for p-value calculation
show_progress (bool) – Display progress bar
random_state (int, optional) – Random seed
- Returns:
df_result (DataFrame) – DataFrame with [‘weight’, ‘weight_signed’, ‘pvalue’, (‘qvalue’)]
perm_matrix (np.ndarray or None) – Permutation matrix for SHAP (None for Grad or sign_test)
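For readers unfamiliar with the two statistical ingredients named above, the sketch below shows a two-sided binomial sign-consistency test and a Benjamini–Hochberg adjustment in plain Python. It illustrates the general techniques only. The exact statistics the library computes are not given on this page, so the function names and the handling of zero-valued attributions are assumptions.

```python
import math


def sign_test_pvalue(values):
    """Two-sided binomial test that positive and negative signs are equally likely."""
    nonzero = [v for v in values if v != 0]  # assumption: zeros carry no sign
    n = len(nonzero)
    if n == 0:
        return 1.0
    n_pos = sum(1 for v in nonzero if v > 0)
    tail = min(n_pos, n - n_pos)
    # P(X <= tail) for X ~ Binomial(n, 0.5), doubled for a two-sided test.
    p_tail = sum(math.comb(n, i) for i in range(tail + 1)) / 2 ** n
    return min(1.0, 2.0 * p_tail)


def benjamini_hochberg(pvalues):
    """Return BH-adjusted q-values in the same order as the input p-values."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])
    qvalues = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down, enforcing monotone q-values.
    for rank in range(m, 0, -1):
        idx = order[rank - 1]
        running_min = min(running_min, pvalues[idx] * m / rank)
        qvalues[idx] = running_min
    return qvalues


# Per-sample signed attributions for one feature: 9 positive, 1 negative.
p_consistent = sign_test_pvalue([0.3] * 9 + [-0.1])
qvals = benjamini_hochberg([0.01, 0.04, 0.03, 0.5])
```

A feature whose attribution has the same sign in nearly every sample receives a small p-value, and the q-values control the false discovery rate across all tested features.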
- CauTrigger2L.plot_train_losses(fig_size=(8, 8))#
Plot training loss curves for all recorded losses during training.
This method visualizes the evolution of each loss component over epochs using subplots. It requires that the model has been trained and that training history is available in the self.history attribute.
- CauTrigger2L.pretrain_attention(prior_probs=None, max_epochs=50, pretrain_lr=0.001, batch_size=128, use_gpu=None, train_size=1.0, validation_size=None)#
Pretrain attention network.
- Parameters:
prior_probs (Optional[ndarray], default: None) – Prior probabilities for attention weights
max_epochs (Optional[int], default: 50) – Maximum number of pretraining epochs
pretrain_lr (float, default: 0.001) – Learning rate for pretraining
batch_size (int, default: 128) – Number of samples per batch
use_gpu (Union[str, int, bool, None], default: None) – Whether to use GPU for pretraining
train_size (float, default: 1.0) – Proportion of data to use for pretraining
validation_size (Optional[float], default: None) – Proportion of data to use for validation
- CauTrigger2L.train(max_epochs=400, lr=0.0005, use_gpu=None, train_size=1.0, validation_size=None, batch_size=128, early_stopping=False, weight_decay=1e-06, n_x=5, n_alpha=25, n_beta=100, recons_weight=1.0, kl_weight=0.02, up_weight=1.0, down_weight=1.0, feat_l1_weight=0.05, dpd_weight=3.0, fide_kl_weight=0.05, causal_weight=1.0, down_fold=1.0, causal_fold=1.0, spurious_fold=1.0, stage_training=True, weight_scheme=None, im_factor=None, drop_last=10, **kwargs)#
Trains the model using fractal variational autoencoder.
- Parameters:
max_epochs (Optional[int], default: 400) – Maximum number of training epochs
lr (float, default: 0.0005) – Learning rate for optimizer
use_gpu (Union[str, int, bool, None], default: None) – Whether to use GPU for training
train_size (float, default: 1.0) – Proportion of data to use for training
validation_size (Optional[float], default: None) – Proportion of data to use for validation
batch_size (int, default: 128) – Number of samples per batch
early_stopping (bool, default: False) – Whether to use early stopping
weight_decay (float, default: 1e-06) – Weight decay for optimizer
n_x (int, default: 5) – Number of samples for causal effect computation
n_alpha (int, default: 25) – Monte Carlo samples per causal factor
n_beta (int, default: 100) – Monte Carlo samples per noncausal factor
recons_weight (float, default: 1.0) – Weight for reconstruction loss
kl_weight (float, default: 0.02) – Weight for KL divergence loss
up_weight (float, default: 1.0) – Weight for upstream reconstruction
down_weight (float, default: 1.0) – Weight for downstream reconstruction
feat_l1_weight (float, default: 0.05) – Weight for feature L1 loss
dpd_weight (float, default: 3.0) – Weight for DPD loss
fide_kl_weight (float, default: 0.05) – Weight for fidelity KL loss
causal_weight (float, default: 1.0) – Weight for causal loss
down_fold (float, default: 1.0) – Downstream loss scaling factor
causal_fold (float, default: 1.0) – Causal loss scaling factor
spurious_fold (float, default: 1.0) – Spurious loss scaling factor
stage_training (bool, default: True) – Whether to use staged training
weight_scheme (Optional[str], default: None) – Weight update scheme
im_factor (Optional[float], default: None) – Imbalance factor for loss computation
drop_last (int, default: 10) – Number of samples to drop from last batch
**kwargs – Additional arguments
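A short usage sketch, assuming model is an already constructed CauTrigger2L instance. The keyword names and values are copied from the train signature above; the helper fit_and_plot is hypothetical and simply chains train with plot_train_losses, which needs the training history that train records.

```python
# A subset of keyword arguments copied from the documented defaults of train().
TRAIN_KWARGS = {
    "max_epochs": 400,
    "lr": 0.0005,
    "batch_size": 128,
    "kl_weight": 0.02,
    "dpd_weight": 3.0,
    "stage_training": True,
}


def fit_and_plot(model, **overrides):
    """Hypothetical helper: train a constructed model, then plot its loss curves."""
    model.train(**{**TRAIN_KWARGS, **overrides})
    # plot_train_losses relies on the history populated during training.
    model.plot_train_losses(fig_size=(8, 8))
    return model
```

For a quick smoke test on a small dataset, fit_and_plot(model, max_epochs=5) shortens the run without touching the other settings.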