TF Activity and GRN Inference Tutorial¶
scRegulate v0.1¶
Compiled: May 13, 2025¶
In this tutorial, we demonstrate how to use scRegulate for in-silico transcription factor (TF) activity and gene regulatory network (GRN) inference on the PBMC 3K dataset. This dataset is a well-known 10X Genomics experiment and is readily accessible via scanpy.
The scRegulate pipeline consists of the following steps:
- Preprocessing: Filter low-quality cells and apply log-normalization.
- Prior GRN selection: Use TF-Link (default) or custom priors (this tutorial uses the bundled COLLECTRI network).
- Variational inference of TF activities using a VAE-based architecture, followed by global fine-tuning across all cells (no clustering needed).
- Downstream analysis: Extract TF activity embeddings, perform Leiden clustering, identify differentially active (DA) TFs, and reconstruct posterior GRNs.
- Optional, but recommended if cell-type labels are known: Cell-type-specific fine-tuning using the provided cluster labels.
A condensed sketch of these steps follows.
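The sketch below is only for orientation; it assumes that reg.train_model and reg.fine_tuning.fine_tune_clusters accept default values for the arguments omitted here, whereas the sections below pass explicit values and explain them.
import scanpy as sc
import scregulate as reg

adata = sc.datasets.pbmc3k_processed()   # load data; full preprocessing in Section 1
net = reg.collectri_prior("human")       # prior GRN (Section 2)
model, processed_adata, GRN = reg.train_model(rna_data=adata, net=net)  # training (Section 3)
processed_adata, tf_acts, fine_model, GRN = reg.fine_tuning.fine_tune_clusters(
    model=model, processed_adata=processed_adata, epochs=2000           # global fine-tuning
)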
1. Preprocessing scRNA data¶
We will need numpy, pandas, scanpy, matplotlib and scregulate, so we import these libraries first.
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import scregulate as reg
We then load the PBMC 3k dataset via scanpy, which provides convenient access to it:
rna_data_all = sc.datasets.pbmc3k_processed()
The dataset contains 2,638 cells. The processed expression matrix in .X includes 1,838 highly variable genes, while the raw counts matrix in .raw.X retains all 13,714 genes. Cell-type labels are stored in the .obs attribute under the louvain column.
print("🔍 AnnData summary:")
print(rna_data_all)
print("\n📐 Shape of processed matrix (.X):")
print(f"{rna_data_all.X.shape} # (cells, HVGs)")
print("\n📦 Shape of raw matrix (.raw.X):")
print(f"{rna_data_all.raw.X.shape} # (cells, all genes)")
print("\n🏷️ Available Louvain cluster labels:")
print(rna_data_all.obs['louvain'].cat.categories.tolist())
🔍 AnnData summary: AnnData object with n_obs × n_vars = 2638 × 1838 obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain' var: 'n_cells' uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups' obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr' varm: 'PCs' obsp: 'distances', 'connectivities' 📐 Shape of processed matrix (.X): (2638, 1838) # (cells, HVGs) 📦 Shape of raw matrix (.raw.X): (2638, 13714) # (cells, all genes) 🏷️ Available Louvain cluster labels: ['CD4 T cells', 'CD14+ Monocytes', 'B cells', 'CD8 T cells', 'NK cells', 'FCGR3A+ Monocytes', 'Dendritic cells', 'Megakaryocytes']
Converting to Raw Expression Matrix
By default, the rna_data_all.X matrix contains log-normalized expression of highly variable genes (HVGs). To work with the full raw gene expression matrix (13,714 genes), we extract the .raw.X layer and build a new AnnData object:
# Extract raw data matrix and relevant metadata
raw_X = rna_data_all.raw.X.copy()
obs = rna_data_all.obs.copy() # Cell-level metadata
var = rna_data_all.raw.var.copy() # Gene-level metadata
# Create a new AnnData object from raw expression
rna_data_all = sc.AnnData(X=raw_X, obs=obs, var=var)
# Verify the new AnnData object
print(rna_data_all)
AnnData object with n_obs × n_vars = 2638 × 13714 obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain' var: 'n_cells'
Quality Control and Preprocessing
After loading the raw PBMC dataset, we perform some essential QC steps to remove low-quality cells and rarely expressed genes.
rna_data_all.var_names_make_unique()
rna_data_all.obs_names_make_unique()
sc.pp.filter_cells(rna_data_all, min_genes=3)
sc.pp.filter_genes(rna_data_all, min_cells=10)
# Shuffle cells with a fixed seed
sc.pp.subsample(rna_data_all, fraction=1, random_state=42)
# Data is already normalized; otherwise, run the lines below
#sc.pp.normalize_total(rna_data_all)
#sc.pp.log1p(rna_data_all)
print('Final Data:')
print(f'rna_data_all.shape: {rna_data_all.shape}')
Final Data: rna_data_all.shape: (2638, 11095)
2. Prior GRN selection¶
Loading the prior GRN for TF inference:
To guide the model in inferring transcription factor (TF) activities, we load a prior gene regulatory network (GRN) from the collectri_human_net.csv file. This prior encodes known TF–target gene relationships, which serve as structural constraints for downstream inference. The COLLECTRI network is included in this package; however, any other prior GRN can be used as well.
net = reg.collectri_prior("human") # or "mouse"
print(net.head())
print("Network `collectri_human_net` loaded with shape:", net.shape)
source target weight PMID 0 MYC TERT 1 10022128;10491298;10606235;10637317;10723141;1... 1 SPI1 BGLAP 1 10022617 2 SMAD3 JUN 1 10022869;12374795 3 SMAD4 JUN 1 10022869;12374795 4 STAT5A IL2 1 10022878;11435608;17182565;17911616;22854263;2... Network `collectri_human_net` loaded with shape: (43178, 4)
⚠️ Note: This is the academic version of the COLLECTRI network. Please respect its academic use license. If you intend to use this data for commercial purposes, refer to the original COLLECTRI repository for licensing terms and contact the authors if needed. Alternatively, you are free to use your own GRN by providing a .csv file with two columns: TF and target.
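If you prefer your own prior, the sketch below shows one way to load such a .csv with pandas. It assumes a hypothetical file my_prior_grn.csv with TF and target columns, and that the table should end up in the same source/target/weight layout as the bundled COLLECTRI prior shown above.
import pandas as pd

# Hypothetical custom prior with columns: TF, target
custom_net = pd.read_csv("my_prior_grn.csv")

# Match the column layout of the bundled prior (source, target, weight)
custom_net = custom_net.rename(columns={"TF": "source"})
if "weight" not in custom_net.columns:
    custom_net["weight"] = 1  # uniform weights when no confidence scores are given

print(custom_net.head())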
3. Training and Global Fine-Tuning with scRegulate¶
We now train the scRegulate model to infer transcription factor (TF) activities and learn a regulatory embedding consistent with the provided GRN prior:
model, processed_adata, GRN = reg.train_model(
rna_data=rna_data_all, net=net,
encode_dims=[2048, 256, 64], decode_dims=[256], z_dim=40,
train_val_split_ratio=0.85,
epochs=10000,
freeze_epochs=4500,
learning_rate=1e-4,
alpha_max=1,
alpha_scale=0.0025,
early_stopping_patience=4500,
min_targets=10
)
[INFO - 2025-05-13 14:06:07] Using provided batch size: 3500 [INFO - 2025-05-13 14:06:07] ======================================== [INFO - 2025-05-13 14:06:07] Starting scRegulate TF inference Training Pipeline [INFO - 2025-05-13 14:06:07] ======================================== [INFO - 2025-05-13 14:06:07] Adapting prior and data... [INFO - 2025-05-13 14:06:07] Initial genes in RNA data: 11095 [INFO - 2025-05-13 14:06:07] Genes retained after intersection with network: 3644 [INFO - 2025-05-13 14:06:07] Initial TFs in GRN matrix: 1186 [INFO - 2025-05-13 14:06:08] Retained 3278 genes and 298 transcription factors. [INFO - 2025-05-13 14:06:08] Splitting data with train-validation split ratio=0.85 [INFO - 2025-05-13 14:06:08] Running ULM... [INFO - 2025-05-13 14:06:13] ULM completed in 5.17s [INFO - 2025-05-13 14:06:13] Validating ULM estimates... [INFO - 2025-05-13 14:06:13] ULM estimates validation passed. Shape: torch.Size([2638, 298]) [INFO - 2025-05-13 14:06:13] Transferring data to device cuda [INFO - 2025-05-13 14:06:13] Transfer complete... [INFO - 2025-05-13 14:06:13] Epoch 1: Avg Train Loss = 0.0003, Avg Val Loss = 0.0015, Alpha = 0.0025, Beta = 0.0000, Gamma = 0.0000, Mask Factor: 1.00 [INFO - 2025-05-13 14:06:24] Epoch 501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9500, Beta = 0.0333, Gamma = 0.2167, Mask Factor: 1.00 [INFO - 2025-05-13 14:06:35] Epoch 1001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9700, Beta = 0.0583, Gamma = 0.3787, Mask Factor: 1.00 [INFO - 2025-05-13 14:06:46] Epoch 1501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9725, Beta = 0.0835, Gamma = 0.5425, Mask Factor: 0.97 [INFO - 2025-05-13 14:06:57] Epoch 2001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9725, Beta = 0.0835, Gamma = 0.5425, Mask Factor: 0.75 [INFO - 2025-05-13 14:07:08] Epoch 2501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.25 [INFO - 2025-05-13 14:07:19] Epoch 3001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.03 [INFO - 2025-05-13 14:07:30] Epoch 3501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:07:41] Epoch 4001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:07:53] Epoch 4501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:08:04] Epoch 5001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:08:15] Epoch 5501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:08:26] Epoch 6001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:08:38] Epoch 6501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:08:49] Epoch 7001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:09:00] Epoch 7501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:09:11] 
Epoch 8001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:09:23] Epoch 8501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:09:34] Epoch 9001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:09:46] Epoch 9501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00 [INFO - 2025-05-13 14:09:57] Training completed in 223.83s [INFO - 2025-05-13 14:09:57] Default modality set to: RNA [INFO - 2025-05-13 14:09:57] Current modality: RNA [INFO - 2025-05-13 14:09:57] `GRN_prior` and `GRN_posterior` stored in the AnnData object under .uns [INFO - 2025-05-13 14:09:57] ======================================== [INFO - 2025-05-13 14:09:57] [FINAL SUMMARY] [INFO - 2025-05-13 14:09:57] Training stopped after 10000 epochs. [INFO - 2025-05-13 14:09:57] Final Train Loss: 150.7574 [INFO - 2025-05-13 14:09:57] Final Valid Loss: 184.3279 [INFO - 2025-05-13 14:09:57] Final Alpha: 0.9750, Beta: 0.1509, Gamma: 0.9806 [INFO - 2025-05-13 14:09:57] Total Training Time: 223.83s [INFO - 2025-05-13 14:09:57] Latent Space Shape: (2638, 40) [INFO - 2025-05-13 14:09:57] TF Space Shape: (2638, 298) [INFO - 2025-05-13 14:09:57] Reconstructed RNA Shape: (2638, 3278) [INFO - 2025-05-13 14:09:57] Original RNA Shape: (2638, 3644) [INFO - 2025-05-13 14:09:57] TFs: 298, Genes: 3644 [INFO - 2025-05-13 14:09:57] ========================================
This initializes and trains a VAE-based model where the encoder infers TF activities, constrained by the provided prior GRN. Training proceeds in two phases:
- Pretraining phase with a frozen GRN structure (freeze_epochs=4500)
- Joint learning phase with full backpropagation
Once trained, we perform global fine-tuning to better capture subtle regulatory cues:
processed_adata, fine_tuned_tf_activities, fine_model, GRN = reg.fine_tuning.fine_tune_clusters(
model=model, processed_adata=processed_adata, epochs=2000
)
2025-05-13 14:09:57,502 - finetune - INFO - Starting fine-tuning for cluster(s)... 2025-05-13 14:09:57,504 - finetune - INFO - Aligning data to 3278 genes matching the GRN. 2025-05-13 14:09:57,504 - finetune - INFO - Cluster key not provided, fine-tuning on all cells together... 2025-05-13 14:09:57,506 - finetune - INFO - Fine-tuning on all cells for 2000 epochs... 2025-05-13 14:09:57,555 - finetune - INFO - Epoch 1, Avg Loss: 102.7285 2025-05-13 14:09:59,848 - finetune - INFO - Epoch 101, Avg Loss: 100.3262 2025-05-13 14:10:02,079 - finetune - INFO - Epoch 201, Avg Loss: 99.4416 2025-05-13 14:10:04,436 - finetune - INFO - Epoch 301, Avg Loss: 98.7599 2025-05-13 14:10:06,745 - finetune - INFO - Epoch 401, Avg Loss: 98.1552 2025-05-13 14:10:09,058 - finetune - INFO - Epoch 501, Avg Loss: 97.5885 2025-05-13 14:10:11,358 - finetune - INFO - Epoch 601, Avg Loss: 97.0624 2025-05-13 14:10:13,637 - finetune - INFO - Epoch 701, Avg Loss: 96.6306 2025-05-13 14:10:15,917 - finetune - INFO - Epoch 801, Avg Loss: 96.1426 2025-05-13 14:10:18,260 - finetune - INFO - Epoch 901, Avg Loss: 95.7572 2025-05-13 14:10:20,509 - finetune - INFO - Epoch 1001, Avg Loss: 95.4072 2025-05-13 14:10:22,806 - finetune - INFO - Epoch 1101, Avg Loss: 95.1007 2025-05-13 14:10:25,167 - finetune - INFO - Epoch 1201, Avg Loss: 94.8288 2025-05-13 14:10:27,457 - finetune - INFO - Epoch 1301, Avg Loss: 94.3621 2025-05-13 14:10:29,771 - finetune - INFO - Epoch 1401, Avg Loss: 94.0332 2025-05-13 14:10:31,986 - finetune - INFO - Epoch 1501, Avg Loss: 93.6669 2025-05-13 14:10:34,301 - finetune - INFO - Epoch 1601, Avg Loss: 93.2813 2025-05-13 14:10:36,591 - finetune - INFO - Epoch 1701, Avg Loss: 92.9463 2025-05-13 14:10:38,861 - finetune - INFO - Epoch 1801, Avg Loss: 92.7483 2025-05-13 14:10:41,154 - finetune - INFO - Epoch 1901, Avg Loss: 92.4543 2025-05-13 14:10:43,585 - finetune - INFO - Epoch 2000, Avg Loss: 92.0715 2025-05-13 14:10:43,610 - finetune - INFO - Fine-tuning completed for all clusters.
This additional global fine-tuning step further refines TF activities across all cells jointly, enhancing biological signal and robustness. You can now proceed to analyze the TF activity embeddings, extract GRNs, perform clustering, and identify differentially active TFs across conditions or cell types.
🔔 Note: This is global fine-tuning. Cell-type-specific fine-tuning using cluster labels can be performed in a separate step, described later in this tutorial (section 5).
4. Downstream analysis¶
Once training and fine-tuning are complete, we treat the inferred transcription factor (TF) activities in scRegulate as we would normalized gene expression values. This makes the downstream workflow intuitive and fully compatible with standard single-cell analysis tools.
In fact, TF activity embeddings (processed_adata.obsm['TF_activity']) can be visualized, clustered, and compared across cell-type labels (e.g. louvain) or conditions using the same methods typically applied to gene expression.
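For example (a sketch, assuming the obsm key 'TF_activity' mentioned above is present and that its columns follow the TF order in fine_tuned_tf_activities.var_names), the activities can be pulled into a pandas DataFrame, and the GRN_prior/GRN_posterior objects that the training log reports under .uns can be inspected directly:
# TF activities as a cells-by-TFs DataFrame
tf_df = pd.DataFrame(
    processed_adata.obsm["TF_activity"],
    index=processed_adata.obs_names,
    columns=fine_tuned_tf_activities.var_names,
)
print(tf_df.iloc[:5, :5])

# Prior and posterior GRNs stored during training (see the training log above)
grn_posterior = processed_adata.uns.get("GRN_posterior")
print(type(grn_posterior), getattr(grn_posterior, "shape", None))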
🎻 Violin Plot of TF Activity
We can visualize TF activity distributions across clusters using violin plots — just like we do for marker genes:
sc.pl.violin(fine_tuned_tf_activities, keys=['PAX5','EOMES'], rotation=45, groupby='louvain', jitter=False)
UMAP Visualization of Inferred TF Activity
After fine-tuning the model, we can visualize cell states using the inferred transcription factor (TF) activities, just like we would with gene expression.
# Use the TF activity representation for UMAP
sc.pp.neighbors(processed_adata, use_rep="TF_finetuned", n_neighbors=15)
sc.tl.umap(processed_adata)
# Visualize clusters and selected TF activities
sc.pl.umap(processed_adata, color=["louvain", "PAX5", "FOS", "STAT1"], title="UMAP of TF Activity")
WARNING: The title list is shorter than the number of panels. Using 'color' value instead for some plots. WARNING: The title list is shorter than the number of panels. Using 'color' value instead for some plots. WARNING: The title list is shorter than the number of panels. Using 'color' value instead for some plots.
Heatmap of Differentially Active Transcription Factors (DA TFs)
We can identify and visualize differentially active transcription factors across clusters using the same approach as differential gene expression. Here, we use the wilcoxon method to rank TFs that vary across Louvain clusters.
sc.tl.rank_genes_groups(fine_tuned_tf_activities, groupby='louvain', method='wilcoxon')
sc.tl.dendrogram(fine_tuned_tf_activities, groupby='louvain', use_rep='X')
top_5_DA_TFs = fine_tuned_tf_activities.uns['rank_genes_groups']['names'][:5]
all_top_5_DA_TFs = np.concatenate([top_5_DA_TFs[col] for col in top_5_DA_TFs.dtype.names])
# Now plot
sc.pl.heatmap(
fine_tuned_tf_activities,
var_names=all_top_5_DA_TFs,
groupby='louvain',
dendrogram=False,
cmap='viridis',
show_gene_labels=True,
standard_scale="var",
swap_axes=True
)
This heatmap reveals regulatory programs that define each cluster. Although we're using rank_genes_groups, the values represent TF activity, not raw gene expression.
We can save the top 30 DA TFs per cluster:
top_DA_TFs = fine_tuned_tf_activities.uns['rank_genes_groups']['names'][:30]
top_DA_TFs = pd.DataFrame.from_records(top_DA_TFs, columns=top_DA_TFs.dtype.names)
top_DA_TFs.to_csv('top_TFs_pre.csv', index=False)
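If you also want the test statistics and adjusted p-values alongside the TF names, scanpy's sc.get.rank_genes_groups_df helper can export them in long format (a sketch; passing group=None to retrieve all clusters at once requires a reasonably recent scanpy version):
# Long-format table of DA TFs with scores, log fold-changes, and adjusted p-values
da_tf_stats = sc.get.rank_genes_groups_df(fine_tuned_tf_activities, group=None)
da_tf_stats.to_csv('top_TFs_with_stats.csv', index=False)
print(da_tf_stats.head())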
To summarize TF activity across clusters more compactly, we use a matrix plot, which shows average TF activity (scaled per TF) per cluster. This complements the detailed heatmap and is especially useful for presentations or reports:
sc.pl.matrixplot(
fine_tuned_tf_activities,
var_names=all_top_5_DA_TFs, # Use the top 5 DA TFs per cluster
groupby='louvain', # Group by cell types
#dendrogram=True, # Show dendrogram for clustering
standard_scale='var', # Scale each TF between 0 and 1 across clusters
cmap='RdBu_r', # Red-Blue reversed colormap
colorbar_title='Scaled scores', # Title for the colorbar
swap_axes=True,
figsize=(5, 6) # Adjust figure size
)
5. Cell-type-specific fine-tuning¶
Finally, we fine-tune the model separately for each cell type using the louvain cluster labels; this refines the TF activities and yields cluster-specific posterior GRN weights:
processed_adata, fine_tuned_tf_activities, final_model, GRN_final = reg.fine_tuning.fine_tune_clusters(
processed_adata=processed_adata,
model=fine_model,
cluster_key="louvain",
min_epochs=5000,
epochs=10000,
log_interval=1000,
tf_mapping_lr=4e-04, # Learning rate for tf_mapping layer
fc_output_lr=2e-05/100, # Learning rate for fc_output layer
default_lr=3.5e-05/100, # Default learning rate for other layers
)
2025-05-13 14:35:21,852 - finetune - INFO - Starting fine-tuning for cluster(s)... 2025-05-13 14:35:21,858 - finetune - INFO - Aligning data to 3278 genes matching the GRN. 2025-05-13 14:35:21,880 - finetune - INFO - Fine-tuning CD4 T cells for 10000 epochs... 2025-05-13 14:35:21,900 - finetune - INFO - Epoch 1, Avg Loss: 38.4518 2025-05-13 14:35:31,263 - finetune - INFO - Epoch 1001, Avg Loss: 36.6287 2025-05-13 14:35:40,492 - finetune - INFO - Epoch 2001, Avg Loss: 36.1222 2025-05-13 14:35:49,744 - finetune - INFO - Epoch 3001, Avg Loss: 35.8816 2025-05-13 14:35:59,151 - finetune - INFO - Epoch 4001, Avg Loss: 35.7924 2025-05-13 14:36:08,540 - finetune - INFO - Fine-tuning FCGR3A+ Monocytes for 10000 epochs... 2025-05-13 14:36:08,548 - finetune - INFO - Epoch 1, Avg Loss: 6.4385 2025-05-13 14:36:12,010 - finetune - INFO - Epoch 1001, Avg Loss: 3.7633 2025-05-13 14:36:15,472 - finetune - INFO - Epoch 2001, Avg Loss: 3.3483 2025-05-13 14:36:18,937 - finetune - INFO - Epoch 3001, Avg Loss: 3.3856 2025-05-13 14:36:22,401 - finetune - INFO - Epoch 4001, Avg Loss: 3.7342 2025-05-13 14:36:25,870 - finetune - INFO - Fine-tuning CD14+ Monocytes for 10000 epochs... 2025-05-13 14:36:25,883 - finetune - INFO - Epoch 1, Avg Loss: 17.2654 2025-05-13 14:36:31,209 - finetune - INFO - Epoch 1001, Avg Loss: 14.9461 2025-05-13 14:36:36,544 - finetune - INFO - Epoch 2001, Avg Loss: 14.2644 2025-05-13 14:36:41,891 - finetune - INFO - Epoch 3001, Avg Loss: 14.0190 2025-05-13 14:36:47,254 - finetune - INFO - Epoch 4001, Avg Loss: 14.0489 2025-05-13 14:36:52,615 - finetune - INFO - Fine-tuning CD8 T cells for 10000 epochs... 2025-05-13 14:36:52,624 - finetune - INFO - Epoch 1, Avg Loss: 10.9590 2025-05-13 14:36:56,982 - finetune - INFO - Epoch 1001, Avg Loss: 8.5738 2025-05-13 14:37:01,334 - finetune - INFO - Epoch 2001, Avg Loss: 8.0712 2025-05-13 14:37:05,673 - finetune - INFO - Epoch 3001, Avg Loss: 7.9476 2025-05-13 14:37:10,027 - finetune - INFO - Epoch 4001, Avg Loss: 8.1929 2025-05-13 14:37:14,373 - finetune - INFO - Fine-tuning B cells for 10000 epochs... 2025-05-13 14:37:14,382 - finetune - INFO - Epoch 1, Avg Loss: 10.1211 2025-05-13 14:37:18,967 - finetune - INFO - Epoch 1001, Avg Loss: 8.4657 2025-05-13 14:37:23,561 - finetune - INFO - Epoch 2001, Avg Loss: 8.1492 2025-05-13 14:37:28,148 - finetune - INFO - Epoch 3001, Avg Loss: 8.0634 2025-05-13 14:37:32,744 - finetune - INFO - Epoch 4001, Avg Loss: 8.1980 2025-05-13 14:37:37,336 - finetune - INFO - Fine-tuning NK cells for 10000 epochs... 2025-05-13 14:37:37,344 - finetune - INFO - Epoch 1, Avg Loss: 5.4363 2025-05-13 14:37:40,866 - finetune - INFO - Epoch 1001, Avg Loss: 3.2324 2025-05-13 14:37:44,381 - finetune - INFO - Epoch 2001, Avg Loss: 2.9269 2025-05-13 14:37:47,900 - finetune - INFO - Epoch 3001, Avg Loss: 3.1224 2025-05-13 14:37:51,414 - finetune - INFO - Epoch 4001, Avg Loss: 3.4899 2025-05-13 14:37:54,926 - finetune - INFO - Fine-tuning Dendritic cells for 10000 epochs... 2025-05-13 14:37:54,933 - finetune - INFO - Epoch 1, Avg Loss: 1.4314 2025-05-13 14:37:57,821 - finetune - INFO - Epoch 1001, Avg Loss: 0.3601 2025-05-13 14:38:00,704 - finetune - INFO - Epoch 2001, Avg Loss: 0.6294 2025-05-13 14:38:03,591 - finetune - INFO - Epoch 3001, Avg Loss: 0.9232 2025-05-13 14:38:06,470 - finetune - INFO - Epoch 4001, Avg Loss: 1.2135 2025-05-13 14:38:09,348 - finetune - INFO - Fine-tuning Megakaryocytes for 10000 epochs... 
2025-05-13 14:38:09,365 - finetune - INFO - Epoch 1, Avg Loss: 0.2003 2025-05-13 14:38:11,870 - finetune - INFO - Epoch 1001, Avg Loss: 0.2499 2025-05-13 14:38:14,375 - finetune - INFO - Epoch 2001, Avg Loss: 0.4438 2025-05-13 14:38:16,882 - finetune - INFO - Epoch 3001, Avg Loss: 0.6341 2025-05-13 14:38:19,390 - finetune - INFO - Epoch 4001, Avg Loss: 0.8167 2025-05-13 14:38:21,915 - finetune - INFO - Fine-tuning completed for all clusters.
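To compare the learned regulatory programs across cell types, we average the (absolute, min-max scaled) cluster-specific posterior weight matrices stored in processed_adata.uns["W_posteriors_per_cluster"] and visualize the pairwise cosine similarity between cell types as a clustered heatmap: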
from sklearn.preprocessing import minmax_scale
from sklearn.metrics.pairwise import cosine_similarity
import seaborn as sns
W_posteriors_per_cluster = processed_adata.uns["W_posteriors_per_cluster"]
cell_type_columns = list(W_posteriors_per_cluster.keys())
# Build average W matrices, following the exact cell_type_columns order
average_W_matrices = {
cell_type: minmax_scale(np.abs(W_posteriors_per_cluster[cell_type]).ravel()).reshape(W_posteriors_per_cluster[cell_type].shape).mean(axis=0)
for cell_type in cell_type_columns
}
# Combine average matrices into a DataFrame where rows = cell types and columns = TFs
combined_average_W = pd.DataFrame(average_W_matrices).T
# Compute cosine similarity across cell types
cosine_sim = np.clip(cosine_similarity(combined_average_W), 0, 1)
similarity_matrix = cosine_sim**8  # raising to a power sharpens the contrast between high and low similarities
# Create a DataFrame for the similarity matrix
similarity_df = pd.DataFrame(
similarity_matrix,
index=cell_type_columns,
columns=cell_type_columns
)
sns_plot = sns.clustermap(
similarity_df,
figsize=(12, 10),
annot=True,
fmt=".2f",
annot_kws={"size": 20},
cmap="RdBu_r",
cbar_kws={'label': 'Cosine Similarity'},
center=0,
cbar=False
)
sns_plot.cax.set_visible(False)
# Adjust the tick labels on the heatmap axes of the clustermap
sns_plot.ax_heatmap.set_xlabel("", fontsize=16)
sns_plot.ax_heatmap.set_ylabel("", fontsize=16)
sns_plot.ax_heatmap.set_xticklabels(sns_plot.ax_heatmap.get_xticklabels(), fontsize=18, rotation=45, ha='right')
sns_plot.ax_heatmap.set_yticklabels(sns_plot.ax_heatmap.get_yticklabels(), fontsize=18)
[Text(1, 0.5, 'Megakaryocytes'), Text(1, 1.5, 'B cells'), Text(1, 2.5, 'Dendritic cells'), Text(1, 3.5, 'NK cells'), Text(1, 4.5, 'CD4 T cells'), Text(1, 5.5, 'CD8 T cells'), Text(1, 6.5, 'FCGR3A+ Monocytes'), Text(1, 7.5, 'CD14+ Monocytes')]
Saving the model¶
import torch
import pickle
import os
import anndata as ad
processed_adata.modality
{'RNA': AnnData object with n_obs × n_vars = 2638 × 3644 obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain' var: 'n_cells' uns: 'type' obsm: 'ulm_estimate', 'ulm_pvals', 'TF': AnnData object with n_obs × n_vars = 2638 × 298 obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain' uns: 'type', 'recon_RNA': AnnData object with n_obs × n_vars = 2638 × 3278 obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain' uns: 'type', 'latent': AnnData object with n_obs × n_vars = 2638 × 40 obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain' uns: 'type'}
# We only need the RNA assay
# Paths to save components
save_dir = "saved_model_outputs"
os.makedirs(save_dir, exist_ok=True)
# Save PyTorch model
torch.save(fine_model.state_dict(), os.path.join(save_dir, "fine_model_PBMC.pt"))
# Save processed AnnData
# Convert .uns dictionary keys to plain strings before writing to .h5ad
if 'W_posteriors_per_cluster' in processed_adata.uns:
processed_adata.uns['W_posteriors_per_cluster'] = {
str(k): v for k, v in processed_adata.uns['W_posteriors_per_cluster'].items()
}
processed_adata.write(os.path.join(save_dir, "processed_adata_PBMC.h5ad"))
# Save TF activities (AnnData)
fine_tuned_tf_activities.write(os.path.join(save_dir, "tf_activities_PBMC.h5ad"))
# Save GRN as pickle (assuming it's a dict or matrix)
with open(os.path.join(save_dir, "GRN_PBMC.pkl"), "wb") as f:
pickle.dump(GRN, f)
Loading the model¶
# Load AnnData objects
save_dir = "./saved_model_outputs"
processed_adata = ad.read_h5ad(os.path.join(save_dir, "processed_adata_PBMC.h5ad"))
tf_activities = ad.read_h5ad(os.path.join(save_dir, "tf_activities_PBMC.h5ad"))
processed_adata.modality = {}
processed_adata.modality['RNA'] = processed_adata.copy()
processed_adata.modality['TF'] = tf_activities.copy()
# Load GRN
with open(os.path.join(save_dir, "GRN_PBMC.pkl"), "rb") as f:
GRN = pickle.load(f)
# Load model (make sure to reinitialize with the exact same architecture)
input_dim = GRN.shape[0]
tf_dim = GRN.shape[1]
encode_dims=[2048, 256, 64]
decode_dims=[256]
z_dim=40
fine_model = reg.scRNA_VAE(input_dim=input_dim, encode_dims=encode_dims, decode_dims=decode_dims, z_dim=z_dim, tf_dim=tf_dim)
fine_model.load_state_dict(torch.load(os.path.join(save_dir, "fine_model_PBMC.pt")))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fine_model = fine_model.to(device)
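After loading, switch the model to evaluation mode before running inference; this is standard PyTorch practice rather than anything specific to scRegulate:
# Disable dropout and other training-only behavior for deterministic inference
fine_model.eval()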