hESC differentiation causal analysis

In this tutorial, we demonstrate how to apply CauTrigger to a real single-cell RNA-seq dataset profiling the differentiation of human embryonic stem cells (hESCs) toward definitive endoderm.

Cells collected at 0 h (state 0) and 96 h (state 1) are treated as two distinct system states; cells from any other time point are left unlabeled and excluded from training.

We will:

  • preprocess the data (select highly variable genes),

  • construct TF/TG inputs,

  • train CauTrigger,

  • examine top causal drivers,

  • compare against a curated gold standard,

  • visualize information flow between molecular layers.

Import libraries and set working directory

import os
import sys
import warnings
import logging

# Silence matplotlib warnings (including font-lookup messages from
# font_manager). The filter is installed before matplotlib is imported so
# that it also covers import-time warnings.
warnings.filterwarnings("ignore", module="matplotlib")
logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
from cautrigger.model import CauTrigger2L
from cautrigger.utils import set_seed

# Figure style, set once. The previous version assigned these twice, with
# conflicting font lists. 'pdf.fonttype' = 42 embeds TrueType (Type 42)
# fonts in PDF output. Arial is preferred; DejaVu Sans (bundled with
# matplotlib) is listed explicitly as the fallback when Arial is missing.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']

# Resolve the case-study directory relative to the current working directory.
# This assumes the notebook is run from a direct subdirectory of 03_hesc/.
BASE_DIR = os.path.abspath("..")
case_path = BASE_DIR
data_path = os.path.join(case_path, 'data/')
output_path = os.path.join(case_path, 'output/')
os.makedirs(output_path, exist_ok=True)
data_path, output_path
('/mnt/e/Project_Research/CauTrigger_Project/CauTrigger-reproducibility/03_hesc/data/',
 '/mnt/e/Project_Research/CauTrigger_Project/CauTrigger-reproducibility/03_hesc/output/')

Prepare data and run CauTrigger

We load the hESC expression matrix and select highly variable genes (HVGs). Using TRRUST, we split them into transcription factors (TFs) and downstream target genes (TGs; genes that TRRUST lists only as targets, never as regulators). We then construct aligned TF/TG inputs for CauTrigger and train a two-layer model.

# Load the expression matrix. The CSV is stored genes x cells, so transpose
# it to the cells x genes layout AnnData expects.
expData = pd.read_csv(os.path.join(data_path, 'hESC_ExpressionData.csv'), index_col=0).transpose()
# Cast before construction: the `dtype=` argument of the AnnData constructor
# is deprecated in recent anndata releases.
adata = sc.AnnData(X=expData.astype(np.float32))
sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor='cell_ranger')
# .copy() materializes the HVG subset, so the .obs assignment below does not
# implicitly modify an AnnData view.
hESC = adata[:, adata.var.highly_variable].copy()

# TF prior (TRRUST v2). Column 0 is used as the TF list and column 1 as
# targets. Genes that only ever appear as targets form the downstream layer.
Trrust = pd.read_table(os.path.join(data_path, 'trrust_rawdata.human.tsv'), header=None)
Trrust_TF = Trrust.iloc[:, 0].dropna().unique()
Trrust_nonTF = np.setdiff1d(Trrust.iloc[:, 1].dropna().unique(), Trrust_TF)

# Pseudotime table -> binary state labels derived from the cell names:
# '00h' cells -> 0, '96h' cells -> 1, any other cell -> NaN (excluded below).
hESC_pt = pd.read_csv(os.path.join(data_path, 'hESC_PseudoTime.csv'), index_col=0)
hESC_pt['cell_type_1'] = np.where(
    hESC_pt.index.str.contains('00h'), 0,
    np.where(hESC_pt.index.str.contains('96h'), 1, np.nan),
)
# Assigning a DataFrame to .obs adopts its index as obs_names WITHOUT
# reordering X. Align the table to the expression matrix by cell name first;
# otherwise a CSV stored in a different row order would silently mislabel cells.
missing_cells = hESC.obs_names.difference(hESC_pt.index)
if len(missing_cells) > 0:
    raise ValueError(
        f"{len(missing_cells)} cells in the expression matrix have no pseudotime entry, "
        f"e.g. {list(missing_cells[:5])}"
    )
hESC.obs = hESC_pt.reindex(hESC.obs_names)

# Split TF / downstream genes by state. Cells with a NaN label match neither mask.
tf_genes = np.intersect1d(hESC.var_names, Trrust_TF)
down_genes = np.intersect1d(hESC.var_names, Trrust_nonTF)
is_start = hESC.obs['cell_type_1'] == 0
is_end = hESC.obs['cell_type_1'] == 1
start_TF = hESC[is_start, tf_genes]
end_TF = hESC[is_end, tf_genes]
start_down = hESC[is_start, down_genes]
end_down = hESC[is_end, down_genes]

# Build the AnnData for CauTrigger:
#   - TF expression goes in X;
#   - downstream expression goes in obsm['X_down'];
#   - start cells (label 0) come first, followed by end cells (label 1), in
#     the same order in both layers.
adata = anndata.concat([start_TF.copy(), end_TF.copy()])
adata.obs['labels'] = np.repeat([0, 1], [start_TF.shape[0], end_TF.shape[0]])
adata.obsm['X_down'] = anndata.concat([start_down, end_down]).X.copy()

# Train CauTrigger (seeded for reproducibility)
set_seed(42)
model = CauTrigger2L(
    adata,
    n_causal=2,
    n_latent=10,
)
model.train(max_epochs=300, stage_training=True, im_factor=1)

# Extract ranked causal drivers, highest weight first. These are presumably
# the upstream (TF) layer's feature weights; confirm against the CauTrigger API.
weight_df = model.get_up_feature_weights(normalize=True, method="Model", sort_by_weight=False)[0]
model_res = pd.DataFrame({'weight_value': weight_df['weight']}).sort_values('weight_value', ascending=False)
model_res.head(10)
weight_value
SMAD2 0.027519
ZIC2 0.027348
EOMES 0.027201
OTX2 0.027046
APC 0.027044
GATA6 0.026952
ZIC3 0.026899
NANOG 0.026859
PRDM1 0.026833
WWP1 0.026580

Venn overlap with the ground truth

def load_ground_truth(data_path):
    """Load the curated hESC driver-gene sets used as ground truth.

    Three GO gene lists are read (first column of each file). Each list is
    restricted to genes that appear in column 0 (regulators) of the TRRUST
    human table. Two literature tables (``ESC_Cell2011.csv``,
    ``ESC_Reproduction2008.csv``) contribute their ``TFs`` column without
    TRRUST filtering. Missing (NaN) entries are dropped everywhere. Otherwise
    they would crash ``np.intersect1d``, which cannot sort mixed float/str, or
    leak into the sets.

    Parameters
    ----------
    data_path : str
        Directory containing ``trrust_rawdata.human.tsv``, the three
        ``GO_*_my.txt`` lists and the two literature CSV files.

    Returns
    -------
    dict[str, set[str]]
        Keys ``'cell_fate_commitment'``, ``'stem_cell_population_maintenance'``,
        ``'endoderm_development'``, ``'literature_curated'`` and ``'all'``
        (the union of the other four sets).
    """
    trrust_human = pd.read_table(os.path.join(data_path, 'trrust_rawdata.human.tsv'), header=None)
    trrust_human_tf = trrust_human.iloc[:, 0].dropna().unique()

    go_files = [
        ('cell_fate_commitment', 'GO_CELL_FATE_COMMITMENT_my.txt'),
        ('stem_cell_population_maintenance', 'GO_STEM_CELL_POPULATION_MAINTENANCE_my.txt'),
        ('endoderm_development', 'GO_ENDODERM_DEVELOPMENT_my.txt'),
    ]
    ground_truth = {}
    for name, file_name in go_files:
        go_genes = pd.read_csv(os.path.join(data_path, file_name)).iloc[:, 0].dropna().astype(str)
        # .tolist() yields plain str elements rather than numpy string scalars.
        ground_truth[name] = set(np.intersect1d(go_genes, trrust_human_tf).tolist())

    cell2011_tfs = pd.read_csv(os.path.join(data_path, 'ESC_Cell2011.csv'), encoding='latin1')['TFs']
    reproduction2008_tfs = pd.read_csv(os.path.join(data_path, 'ESC_Reproduction2008.csv'))['TFs']
    ground_truth['literature_curated'] = set(cell2011_tfs.dropna()) | set(reproduction2008_tfs.dropna())

    # Union computed before the 'all' key exists, so it covers only the four sets above.
    ground_truth['all'] = set().union(*ground_truth.values())
    return ground_truth
    
from matplotlib_venn import venn2

# Compare the top-ranked predicted drivers with the union of every curated
# gene set ('all' from load_ground_truth, defined earlier in the notebook).
top_k = 10
gt = load_ground_truth(data_path)['all']
pred = set(model_res.head(top_k).index)

plt.figure(figsize=(3, 3))
venn2(
    [pred, gt],
    set_labels=(f'Top {top_k} predicted', 'ground truths'),
)
plt.show()
../_images/8ee9136543db0cbc91ab5f76693324e0b09051a709ab842059e8989b6883d3da.png

Visualize causal information flow

# Plot the information flow between molecular layers of the trained model.
# save_fig=False presumably keeps the figure from being written to disk;
# confirm the semantics of zero_floor / skip_single_info in the CauTrigger API.
flow = model.compute_information_flow(
    zero_floor=True,
    save_fig=False,
    skip_single_info=True,
)
../_images/30bf934fe9d9087bbf548a771dbf9d9462da9f16f29bacb14ce6197b33de2529.png