cautrigger.model.CauTrigger3L
- class cautrigger.model.CauTrigger3L(adata, n_latent=10, n_causal=2, n_state=2, **model_kwargs)
Third-layer hierarchical causal decomposition model. CauTrigger3L wraps a DualVAE3L module and supports third-stage causal decomposition (e.g. x3 → xc2 → xc1 → y). It is intended for modelling complex cascades in which effects propagate through multiple intermediate regulatory layers (for instance, multi-omic cascades).
- Parameters:
adata (AnnData) – Annotated data matrix. Expects upstream features in adata.X and downstream representations in adata.obsm['X_down1'] and adata.obsm['X_down2'].
n_latent (int (default: 10)) – Latent dimensionality.
n_causal (int (default: 2)) – Number of causal latent factors.
n_state (int (default: 2)) – Number of discrete states.
**model_kwargs – Extra keyword arguments passed to DualVAE3L.
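The expected input layout can be checked before constructing the model. The following is a minimal sketch and not part of cautrigger: `check_cautrigger3l_inputs` is a hypothetical helper, and the stand-in object only mimics the `X` and `obsm` attributes of an AnnData so that the example runs without extra dependencies. It also assumes that all three matrices have one row per observation, which the description above does not state explicitly.

```python
from types import SimpleNamespace

REQUIRED_OBSM_KEYS = ("X_down1", "X_down2")


def check_cautrigger3l_inputs(adata):
    """Return a list of layout problems (hypothetical helper, not a cautrigger API)."""
    problems = []
    n_obs = len(adata.X)
    for key in REQUIRED_OBSM_KEYS:
        if key not in adata.obsm:
            problems.append(f"missing adata.obsm[{key!r}]")
        elif len(adata.obsm[key]) != n_obs:  # assumption: one row per observation
            n_rows = len(adata.obsm[key])
            problems.append(f"adata.obsm[{key!r}] has {n_rows} rows, expected {n_obs}")
    return problems


# Stand-in for an AnnData object: only the attributes used above are mimicked.
toy = SimpleNamespace(
    X=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],  # upstream features
    obsm={"X_down1": [[1.0], [0.0]]},      # 'X_down2' is deliberately missing
)
print(check_cautrigger3l_inputs(toy))
```

With a real AnnData object the same attribute names apply, so the helper could be called on it unchanged before passing the object to the constructor.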
Methods table
| compute_information_flow | Compute information flow for latent dimensions. |
| get_3to1_ig | Compute IG values from TFs to TGs. |
| get_3to2_ig | Compute IG values from TFs to REs. |
| get_3to2_shap | Compute SHAP values from TFs to REs using either DeepExplainer or GradientExplainer. |
| get_model_output | Obtain model predictions and latent representations for a given dataset. |
| get_up_feature_weights | Obtain model predictions and latent representations for a given dataset. |
| get_up_significance | |
| plot_train_losses | Plot training loss curves for all recorded losses during training. |
| pretrain_attention | Pretrain attention network. |
| train | Trains the model using fractal variational autoencoder. |
Methods
- CauTrigger3L.compute_information_flow(adata=None, dims=None, zero_floor=False, plot_info_flow=True, skip_single_info=True, save_fig=False, save_dir=None)
Compute information flow for latent dimensions.
- Parameters:
adata (Optional[AnnData] (default: None)) – AnnData object with input data.
dims (Optional[List[int]] (default: None)) – Dimensions to compute information flow for.
zero_floor (bool (default: False)) – Whether to subtract the minimum value.
plot_info_flow (Optional[bool] (default: True)) – Whether to plot information flow.
skip_single_info (Optional[bool] (default: True)) – Whether to skip single-dimension plots.
save_fig (Optional[bool] (default: False)) – Whether to save figures.
save_dir (Optional[str] (default: None)) – Directory in which to save figures.
- Returns:
- info_flow
Information flow for each dimension
- info_flow_cat
Categorical information flow (causal vs spurious)
- CauTrigger3L.get_3to1_ig(adata=None, key='prob', celltype=None, baseline=None)
Compute integrated gradients (IG) values from TFs to TGs.
- CauTrigger3L.get_3to2_ig(adata=None, key='prob', celltype=None, baseline=None)
Compute integrated gradients (IG) values from TFs to REs.
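Assuming IG here stands for integrated gradients (the `baseline` argument suggests so), the idea behind both IG methods can be illustrated on a toy function. This sketch is independent of cautrigger: `f`, `grad_f`, and `integrated_gradients` are hypothetical names, the gradient is written by hand rather than obtained by automatic differentiation, and the path integral is approximated with a midpoint Riemann sum.

```python
def f(x):
    """Toy differentiable model: f(x) = x0^2 + 3 * x0 * x1."""
    return x[0] ** 2 + 3.0 * x[0] * x[1]


def grad_f(x):
    """Analytic gradient of f."""
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]


def integrated_gradients(x, baseline, grad, steps=200):
    """Approximate IG_i = (x_i - b_i) * integral over a in [0, 1] of grad_i(b + a * (x - b))."""
    n = len(x)
    avg_grad = [0.0] * n
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad(point)
        for i in range(n):
            avg_grad[i] += g[i] / steps
    return [(xi - b) * gi for xi, b, gi in zip(x, baseline, avg_grad)]


x = [1.0, 2.0]
baseline = [0.0, 0.0]
attributions = integrated_gradients(x, baseline, grad_f)

# Completeness: the attributions sum (approximately) to f(x) - f(baseline).
print(attributions, sum(attributions), f(x) - f(baseline))
```

The completeness check at the end is a useful sanity test for any IG implementation: if the attributions do not sum to the difference between the output at the input and at the baseline, the number of integration steps is probably too small.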
- CauTrigger3L.get_3to2_shap(adata=None, n_bg_samples=5, key='prob', celltype=None, explainer_type='gradient', data_dir=None)
Compute SHAP values from TFs to REs using either DeepExplainer or GradientExplainer.
- CauTrigger3L.get_model_output(adata=None, batch_size=None)
Obtain model predictions and latent representations for a given dataset.
This method runs the trained model in evaluation mode and returns:
- Concatenated latent embeddings from two latent spaces,
- Logits and predicted probabilities from the downstream classifier,
- Binary class predictions (thresholded at 0.5).
If no adata is provided, the method uses the internal self.adata.
- Parameters:
adata (AnnData, optional) – Annotated data matrix to generate outputs for. If None, defaults to self.adata. Default is None.
batch_size (int, optional) – Number of samples per batch during inference. If None, uses self.batch_size. Default is None.
- Returns:
output (dict) – Dictionary containing the following keys:
- 'latent': numpy.ndarray of shape (n_samples, n_latent1 + n_latent2), concatenated latent vectors from both latent modules.
- 'logits': numpy.ndarray of shape (n_samples,) or (n_samples, n_classes), raw classifier logits.
- 'probs': numpy.ndarray of same shape as 'logits', predicted probabilities after sigmoid/softmax activation.
- 'preds': numpy.ndarray of shape (n_samples,), binary predictions (1 if probability > 0.5, else 0).
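The relationship between the 'logits', 'probs', and 'preds' entries can be sketched for the binary case as follows. This is an illustration under stated assumptions rather than the library's code: it assumes a sigmoid activation and the strict `> 0.5` threshold described above, uses plain Python lists instead of numpy arrays, and `sigmoid` and `logits_to_output` are hypothetical names.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_output(logits, threshold=0.5):
    """Map raw binary logits to probabilities and thresholded predictions."""
    probs = [sigmoid(z) for z in logits]
    preds = [1 if p > threshold else 0 for p in probs]  # strict inequality
    return {"logits": list(logits), "probs": probs, "preds": preds}


out = logits_to_output([-2.0, 0.0, 1.5])
print(out["preds"])  # [0, 0, 1]
```

Note that a logit of exactly 0 gives a probability of exactly 0.5, which the strict threshold maps to class 0.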
- CauTrigger3L.get_up_feature_weights(method='SHAP', n_bg_samples=100, grad_source='prob', normalize=True, sort_by_weight=False, class_idx=None)
Obtain model predictions and latent representations for a given dataset.
This method runs the trained model in evaluation mode and returns:
- Concatenated latent embeddings from two latent spaces,
- Logits and predicted probabilities from the downstream classifier,
- Binary class predictions (thresholded at 0.5).
If no adata is provided, the method uses the internal self.adata.
- Parameters:
adata (AnnData, optional) – Annotated data matrix to generate outputs for. If None, defaults to self.adata. Default is None.
batch_size (int, optional) – Number of samples per batch during inference. If None, uses self.batch_size. Default is None.
- Returns:
output (dict) – Dictionary containing the following keys:
- 'latent': numpy.ndarray of shape (n_samples, n_latent1 + n_latent2), concatenated latent vectors from both latent modules.
- 'logits': numpy.ndarray of shape (n_samples,) or (n_samples, n_classes), raw classifier logits.
- 'probs': numpy.ndarray of same shape as 'logits', predicted probabilities after sigmoid/softmax activation.
- 'preds': numpy.ndarray of shape (n_samples,), binary predictions (1 if probability > 0.5, else 0).
- CauTrigger3L.get_up_significance(*args, **kwargs)
- CauTrigger3L.plot_train_losses(fig_size=(8, 8))
Plot training loss curves for all recorded losses during training.
This method visualizes the evolution of each loss component over epochs using subplots. It requires that the model has been trained and that training history is available in the self.history attribute.
- CauTrigger3L.pretrain_attention(prior_probs=None, max_epochs=50, pretrain_lr=0.001, batch_size=128, use_gpu=None, train_size=1.0, validation_size=None)
Pretrain attention network.
- Parameters:
prior_probs (Optional[ndarray] (default: None)) – Prior probabilities for attention weights.
max_epochs (Optional[int] (default: 50)) – Maximum number of pretraining epochs.
pretrain_lr (float (default: 0.001)) – Learning rate for pretraining.
batch_size (int (default: 128)) – Number of samples per batch.
use_gpu (Union[str, int, bool, None] (default: None)) – Whether to use GPU for pretraining.
train_size (float (default: 1.0)) – Proportion of data to use for pretraining.
validation_size (Optional[float] (default: None)) – Proportion of data to use for validation.
- CauTrigger3L.train(max_epochs=400, lr=0.0005, use_gpu=None, train_size=1.0, validation_size=None, batch_size=128, early_stopping=False, weight_decay=1e-06, n_x=5, n_alpha=25, n_beta=100, recons_weight=1.0, kl_weight=0.02, up_weight=1.0, down_weight=1.0, feat_l1_weight=0.05, dpd_weight=3.0, fide_kl_weight=0.05, causal_weight=1.0, down_fold=1.0, causal_fold=1.0, spurious_fold=1.0, stage_training=True, weight_scheme=None, im_factor=None, drop_last=False, **kwargs)
Trains the model using fractal variational autoencoder.
- Parameters:
max_epochs (Optional[int] (default: 400)) – Maximum number of training epochs.
lr (float (default: 0.0005)) – Learning rate for the optimizer.
use_gpu (Union[str, int, bool, None] (default: None)) – Whether to use GPU for training.
train_size (float (default: 1.0)) – Proportion of data to use for training.
validation_size (Optional[float] (default: None)) – Proportion of data to use for validation.
batch_size (int (default: 128)) – Number of samples per batch.
early_stopping (bool (default: False)) – Whether to use early stopping.
weight_decay (float (default: 1e-06)) – Weight decay for the optimizer.
n_x (int (default: 5)) – Number of samples for causal effect computation.
n_alpha (int (default: 25)) – Monte Carlo samples per causal factor.
n_beta (int (default: 100)) – Monte Carlo samples per noncausal factor.
recons_weight (float (default: 1.0)) – Weight for reconstruction loss.
kl_weight (float (default: 0.02)) – Weight for KL divergence loss.
up_weight (float (default: 1.0)) – Weight for upstream reconstruction.
down_weight (float (default: 1.0)) – Weight for downstream reconstruction.
feat_l1_weight (float (default: 0.05)) – Weight for feature L1 loss.
dpd_weight (float (default: 3.0)) – Weight for DPD loss.
fide_kl_weight (float (default: 0.05)) – Weight for fidelity KL loss.
causal_weight (float (default: 1.0)) – Weight for causal loss.
down_fold (float (default: 1.0)) – Downstream loss scaling factor.
causal_fold (float (default: 1.0)) – Causal loss scaling factor.
spurious_fold (float (default: 1.0)) – Spurious loss scaling factor.
stage_training (bool (default: True)) – Whether to use staged training.
weight_scheme (Optional[str] (default: None)) – Weight update scheme.
im_factor (Optional[float] (default: None)) – Imbalance factor for loss computation.
drop_last (Union[bool, int] (default: False)) – Number of samples to drop from last batch.
**kwargs – Additional arguments.
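A typical call sequence might look like the sketch below. It uses only method names and keyword arguments that appear in the signatures on this page; the order (attention pretraining, then training, then inference) is an assumption, and `run_cautrigger3l` and `RecordingStub` are hypothetical names introduced so that the example runs without cautrigger installed. In real use, `model` would be a `CauTrigger3L(adata, ...)` instance.

```python
def run_cautrigger3l(model, max_epochs=400, pretrain_epochs=50):
    """Hypothetical driver: pretrain attention, train, then collect outputs."""
    model.pretrain_attention(max_epochs=pretrain_epochs, pretrain_lr=0.001, batch_size=128)
    model.train(max_epochs=max_epochs, lr=0.0005, batch_size=128, stage_training=True)
    return model.get_model_output()


class RecordingStub:
    """Stand-in for CauTrigger3L that records calls instead of training anything."""

    def __init__(self):
        self.calls = []

    def pretrain_attention(self, **kwargs):
        self.calls.append(("pretrain_attention", kwargs))

    def train(self, **kwargs):
        self.calls.append(("train", kwargs))

    def get_model_output(self, adata=None, batch_size=None):
        self.calls.append(("get_model_output", {}))
        # Placeholder values only; the documented method returns numpy arrays.
        return {"latent": [], "logits": [], "probs": [], "preds": []}


stub = RecordingStub()
output = run_cautrigger3l(stub, max_epochs=10)
print([name for name, _ in stub.calls])
print(sorted(output))
```

Passing a short `max_epochs` first, as in the stub call, is a cheap way to confirm that the whole sequence runs before committing to a full training run.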