TF Activity and GRN Inference Tutorial¶

scRegulate v0.1¶

Compiled: May 13, 2025¶

In this tutorial, we demonstrate how to use scRegulate for in-silico transcription factor (TF) activity and gene regulatory network (GRN) inference using the PBMC 3K dataset. This dataset is a well-known 10X Genomics experiment and is readily accessible via scanpy.

The scRegulate pipeline consists of the following steps:

  1. Preprocessing: Filter low-quality cells and apply log-normalization.
  2. Prior GRN selection: Use a default prior (e.g. TF-Link) or a custom prior; this tutorial uses the bundled COLLECTRI network.
  3. Variational inference of TF activities using a VAE-based architecture, followed by global fine-tuning across all cells (no clustering needed).
  4. Downstream analysis: Extract TF activity embeddings, perform Leiden clustering, identify differentially active (DA) TFs, and reconstruct posterior GRNs.
  5. Cell-type-specific fine-tuning using provided cluster labels (optional, but recommended when labels are known).

1. Preprocessing scRNA data¶

We will need numpy, pandas, scanpy, matplotlib, and scregulate, so we import these libraries first.

In [10]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import scregulate as reg

We then load the PBMC 3k dataset through scanpy, which provides convenient access to it:

In [2]:
rna_data_all = sc.datasets.pbmc3k_processed()

The dataset contains 2,638 cells. The processed expression matrix in .X includes 1,838 highly variable genes, while the .raw.X matrix retains all 13,714 genes (log-normalized, prior to HVG selection).

Cell-type labels are stored in the .obs attribute under the louvain column.

In [3]:
print("🔍 AnnData summary:")
print(rna_data_all)

print("\n📐 Shape of processed matrix (.X):")
print(f"{rna_data_all.X.shape}  # (cells, HVGs)")

print("\n📦 Shape of raw matrix (.raw.X):")
print(f"{rna_data_all.raw.X.shape}  # (cells, all genes)")

print("\n🏷️ Available Louvain cluster labels:")
print(rna_data_all.obs['louvain'].cat.categories.tolist())
🔍 AnnData summary:
AnnData object with n_obs × n_vars = 2638 × 1838
    obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
    var: 'n_cells'
    uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
    obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

📐 Shape of processed matrix (.X):
(2638, 1838)  # (cells, HVGs)

📦 Shape of raw matrix (.raw.X):
(2638, 13714)  # (cells, all genes)

🏷️ Available Louvain cluster labels:
['CD4 T cells', 'CD14+ Monocytes', 'B cells', 'CD8 T cells', 'NK cells', 'FCGR3A+ Monocytes', 'Dendritic cells', 'Megakaryocytes']

Converting to Raw Expression Matrix

By default, the rna_data_all.X matrix contains log-normalized expression of highly variable genes (HVGs). To work with the full gene set (13,714 genes), we extract the .raw.X layer and build a new AnnData object:

In [4]:
# Extract raw data matrix and relevant metadata
raw_X = rna_data_all.raw.X.copy()
obs = rna_data_all.obs.copy()        # Cell-level metadata
var = rna_data_all.raw.var.copy()    # Gene-level metadata

# Create a new AnnData object from raw expression
rna_data_all = sc.AnnData(X=raw_X, obs=obs, var=var)

# Verify the new AnnData object
print(rna_data_all)
AnnData object with n_obs × n_vars = 2638 × 13714
    obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
    var: 'n_cells'

Quality Control and Preprocessing

With the full-gene PBMC matrix in place, we perform essential QC steps to remove low-quality cells and genes detected in too few cells.

In [5]:
rna_data_all.var_names_make_unique()
rna_data_all.obs_names_make_unique()

sc.pp.filter_cells(rna_data_all, min_genes=3)
sc.pp.filter_genes(rna_data_all, min_cells=10)

# Shuffle cells with a fixed seed (fraction=1 keeps all cells)
sc.pp.subsample(rna_data_all, fraction=1, random_state=42)

# The data is already log-normalized; otherwise, run the code below
#sc.pp.normalize_total(rna_data_all)
#sc.pp.log1p(rna_data_all)

print('Final Data:')
print(f'rna_data_all.shape: {rna_data_all.shape}')
Final Data:
rna_data_all.shape: (2638, 11095)

2. Prior GRN selection¶

Loading the Prior GRN for TF Inference

To guide the model in inferring transcription factor (TF) activities, we load a prior gene regulatory network (GRN) from the collectri_human_net.csv file. This prior encodes known TF–target gene relationships, which serve as structural constraints for downstream inference. The COLLECTRI network is included with this package; however, any other prior GRN can be used as well.

In [6]:
net = reg.collectri_prior("human")  # or "mouse"
print(net.head())
print("Network `collectri_human_net` loaded with shape:", net.shape)
   source target  weight                                               PMID
0     MYC   TERT       1  10022128;10491298;10606235;10637317;10723141;1...
1    SPI1  BGLAP       1                                           10022617
2   SMAD3    JUN       1                                  10022869;12374795
3   SMAD4    JUN       1                                  10022869;12374795
4  STAT5A    IL2       1  10022878;11435608;17182565;17911616;22854263;2...
Network `collectri_human_net` loaded with shape: (43178, 4)

⚠️ Note: This is the academic version of the COLLECTRI network. Please respect its academic use license. If you intend to use this data for commercial purposes, refer to the original COLLECTRI repository for licensing terms and contact the authors if needed. Alternatively, you are free to use your own GRN by providing a .csv file with two columns: TF and target.
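
If you prefer to supply your own prior instead, the cell below shows one way to prepare it with pandas. This is a minimal, hedged sketch: the file name is a placeholder, and renaming the TF column to match the source/target layout shown above is an assumption about what train_model expects.

In [ ]:
# Hedged sketch: load a custom prior GRN from a CSV file (file name is hypothetical).
custom_net = pd.read_csv("my_prior_grn.csv")                # e.g. columns: TF, target
custom_net = custom_net.rename(columns={"TF": "source"})    # assumed to match the source/target layout above
if "weight" not in custom_net.columns:
    custom_net["weight"] = 1                                # default to unsigned edges
print(custom_net.head())

# It can then be passed to training in place of the COLLECTRI prior, e.g.:
# model, processed_adata, GRN = reg.train_model(rna_data=rna_data_all, net=custom_net, ...)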

3. Training and Global Fine-Tuning with scRegulate¶

We now train the scRegulate model to infer transcription factor (TF) activities and learn a regulatory embedding consistent with the provided GRN prior:

In [7]:
model, processed_adata, GRN = reg.train_model(  
    rna_data=rna_data_all, net=net,
    encode_dims=[2048, 256, 64], decode_dims=[256], z_dim=40,
    train_val_split_ratio=0.85,
    epochs=10000,
    freeze_epochs=4500,
    learning_rate=1e-4,
    alpha_max=1,
    alpha_scale=0.0025,
    early_stopping_patience=4500,
    min_targets=10 
)
[INFO - 2025-05-13 14:06:07] Using provided batch size: 3500
[INFO - 2025-05-13 14:06:07] ========================================
[INFO - 2025-05-13 14:06:07] Starting scRegulate TF inference Training Pipeline
[INFO - 2025-05-13 14:06:07] ========================================
[INFO - 2025-05-13 14:06:07] Adapting prior and data...
[INFO - 2025-05-13 14:06:07] Initial genes in RNA data: 11095
[INFO - 2025-05-13 14:06:07] Genes retained after intersection with network: 3644
[INFO - 2025-05-13 14:06:07] Initial TFs in GRN matrix: 1186
[INFO - 2025-05-13 14:06:08] Retained 3278 genes and 298 transcription factors.
[INFO - 2025-05-13 14:06:08] Splitting data with train-validation split ratio=0.85
[INFO - 2025-05-13 14:06:08] Running ULM...
[INFO - 2025-05-13 14:06:13] ULM completed in 5.17s
[INFO - 2025-05-13 14:06:13] Validating ULM estimates...
[INFO - 2025-05-13 14:06:13] ULM estimates validation passed. Shape: torch.Size([2638, 298])
[INFO - 2025-05-13 14:06:13] Transferring data to device cuda
[INFO - 2025-05-13 14:06:13] Transfer complete...
[INFO - 2025-05-13 14:06:13] Epoch 1: Avg Train Loss = 0.0003, Avg Val Loss = 0.0015, Alpha = 0.0025, Beta = 0.0000, Gamma = 0.0000, Mask Factor: 1.00
[INFO - 2025-05-13 14:06:24] Epoch 501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9500, Beta = 0.0333, Gamma = 0.2167, Mask Factor: 1.00
[INFO - 2025-05-13 14:06:35] Epoch 1001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9700, Beta = 0.0583, Gamma = 0.3787, Mask Factor: 1.00
[INFO - 2025-05-13 14:06:46] Epoch 1501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9725, Beta = 0.0835, Gamma = 0.5425, Mask Factor: 0.97
[INFO - 2025-05-13 14:06:57] Epoch 2001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9725, Beta = 0.0835, Gamma = 0.5425, Mask Factor: 0.75
[INFO - 2025-05-13 14:07:08] Epoch 2501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.25
[INFO - 2025-05-13 14:07:19] Epoch 3001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.03
[INFO - 2025-05-13 14:07:30] Epoch 3501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:07:41] Epoch 4001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:07:53] Epoch 4501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:08:04] Epoch 5001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:08:15] Epoch 5501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:08:26] Epoch 6001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:08:38] Epoch 6501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:08:49] Epoch 7001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:09:00] Epoch 7501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:09:11] Epoch 8001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:09:23] Epoch 8501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:09:34] Epoch 9001: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:09:46] Epoch 9501: Avg Train Loss = 0.0000, Avg Val Loss = 0.0001, Alpha = 0.9750, Beta = 0.1509, Gamma = 0.9806, Mask Factor: 0.00
[INFO - 2025-05-13 14:09:57] Training completed in 223.83s
[INFO - 2025-05-13 14:09:57] Default modality set to: RNA
[INFO - 2025-05-13 14:09:57] Current modality: RNA
[INFO - 2025-05-13 14:09:57] `GRN_prior` and `GRN_posterior` stored in the AnnData object under .uns
[INFO - 2025-05-13 14:09:57] ========================================
[INFO - 2025-05-13 14:09:57] [FINAL SUMMARY]
[INFO - 2025-05-13 14:09:57] Training stopped after 10000 epochs.
[INFO - 2025-05-13 14:09:57] Final Train Loss: 150.7574
[INFO - 2025-05-13 14:09:57] Final Valid Loss: 184.3279
[INFO - 2025-05-13 14:09:57] Final Alpha: 0.9750, Beta: 0.1509, Gamma: 0.9806
[INFO - 2025-05-13 14:09:57] Total Training Time: 223.83s
[INFO - 2025-05-13 14:09:57] Latent Space Shape: (2638, 40)
[INFO - 2025-05-13 14:09:57] TF Space Shape: (2638, 298)
[INFO - 2025-05-13 14:09:57] Reconstructed RNA Shape: (2638, 3278)
[INFO - 2025-05-13 14:09:57] Original RNA Shape: (2638, 3644)
[INFO - 2025-05-13 14:09:57] TFs: 298, Genes: 3644
[INFO - 2025-05-13 14:09:57] ========================================

This initializes and trains a VAE-based model where the encoder infers TF activities, constrained by the provided prior GRN. Training proceeds in two phases:

  1. Pretraining phase with frozen GRN structure (freeze_epochs=4500)

  2. Joint learning phase with full backpropagation

Once trained, we perform global fine-tuning to better capture subtle regulatory cues:

In [8]:
processed_adata, fine_tuned_tf_activities, fine_model, GRN = reg.fine_tuning.fine_tune_clusters(
    model=model, processed_adata=processed_adata, epochs=2000
)
2025-05-13 14:09:57,502 - finetune - INFO - Starting fine-tuning for cluster(s)...
2025-05-13 14:09:57,504 - finetune - INFO - Aligning data to 3278 genes matching the GRN.
2025-05-13 14:09:57,504 - finetune - INFO - Cluster key not provided, fine-tuning on all cells together...
2025-05-13 14:09:57,506 - finetune - INFO - Fine-tuning on all cells for 2000 epochs...
2025-05-13 14:09:57,555 - finetune - INFO - Epoch 1, Avg Loss: 102.7285
2025-05-13 14:09:59,848 - finetune - INFO - Epoch 101, Avg Loss: 100.3262
2025-05-13 14:10:02,079 - finetune - INFO - Epoch 201, Avg Loss: 99.4416
2025-05-13 14:10:04,436 - finetune - INFO - Epoch 301, Avg Loss: 98.7599
2025-05-13 14:10:06,745 - finetune - INFO - Epoch 401, Avg Loss: 98.1552
2025-05-13 14:10:09,058 - finetune - INFO - Epoch 501, Avg Loss: 97.5885
2025-05-13 14:10:11,358 - finetune - INFO - Epoch 601, Avg Loss: 97.0624
2025-05-13 14:10:13,637 - finetune - INFO - Epoch 701, Avg Loss: 96.6306
2025-05-13 14:10:15,917 - finetune - INFO - Epoch 801, Avg Loss: 96.1426
2025-05-13 14:10:18,260 - finetune - INFO - Epoch 901, Avg Loss: 95.7572
2025-05-13 14:10:20,509 - finetune - INFO - Epoch 1001, Avg Loss: 95.4072
2025-05-13 14:10:22,806 - finetune - INFO - Epoch 1101, Avg Loss: 95.1007
2025-05-13 14:10:25,167 - finetune - INFO - Epoch 1201, Avg Loss: 94.8288
2025-05-13 14:10:27,457 - finetune - INFO - Epoch 1301, Avg Loss: 94.3621
2025-05-13 14:10:29,771 - finetune - INFO - Epoch 1401, Avg Loss: 94.0332
2025-05-13 14:10:31,986 - finetune - INFO - Epoch 1501, Avg Loss: 93.6669
2025-05-13 14:10:34,301 - finetune - INFO - Epoch 1601, Avg Loss: 93.2813
2025-05-13 14:10:36,591 - finetune - INFO - Epoch 1701, Avg Loss: 92.9463
2025-05-13 14:10:38,861 - finetune - INFO - Epoch 1801, Avg Loss: 92.7483
2025-05-13 14:10:41,154 - finetune - INFO - Epoch 1901, Avg Loss: 92.4543
2025-05-13 14:10:43,585 - finetune - INFO - Epoch 2000, Avg Loss: 92.0715
2025-05-13 14:10:43,610 - finetune - INFO - Fine-tuning completed for all clusters.

This additional global fine-tuning step further refines TF activities across all cells jointly, enhancing biological signal and robustness. You can now proceed to analyze the TF activity embeddings, extract GRNs, perform clustering, and identify differentially active TFs across conditions or cell types.

🔔 Note: This is global fine-tuning. Cell-type-specific fine-tuning using cluster labels can be performed in a separate step, described later in this tutorial (section 5).

4. Downstream analysis¶

Once training and fine-tuning are complete, we treat the inferred transcription factor (TF) activities in scRegulate as we would normalized gene expression values. This makes the downstream workflow intuitive and fully compatible with standard single-cell analysis tools.

In fact, TF activity embeddings (processed_adata.obsm['TF_activity']) can be visualized, clustered, and compared across cell type labels (e.g. louvain) or conditions using the same methods typically applied to gene expression.
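
For example, a fresh clustering can be computed directly in TF-activity space. The cell below is a minimal sketch: it assumes the embedding is stored under the obsm key mentioned above ('TF_activity'; some runs may instead use 'TF_finetuned', as in the UMAP cell later in this section) and that the leidenalg dependency for sc.tl.leiden is installed.

In [ ]:
# Hedged sketch: Leiden clustering on the TF activity embedding.
# Adjust use_rep if your run stores the embedding under a different obsm key.
sc.pp.neighbors(processed_adata, use_rep="TF_activity", n_neighbors=15)
sc.tl.leiden(processed_adata, resolution=1.0, key_added="leiden_TF")
sc.tl.umap(processed_adata)
sc.pl.umap(processed_adata, color=["leiden_TF", "louvain"])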

🎻 Violin Plot of TF Activity

We can visualize TF activity distributions across clusters using violin plots — just like we do for marker genes:

In [9]:
sc.pl.violin(fine_tuned_tf_activities, keys=['PAX5','EOMES'], rotation=45, groupby='louvain', jitter=False)
[Figure: violin plots of PAX5 and EOMES TF activity across louvain clusters]

UMAP Visualization of Inferred TF Activity

After fine-tuning the model, we can visualize cell states using the inferred transcription factor (TF) activities, just like we would with gene expression.

In [ ]:
# Use the TF activity representation for UMAP
sc.pp.neighbors(processed_adata, use_rep="TF_finetuned", n_neighbors=15)
sc.tl.umap(processed_adata)

# Visualize clusters and selected TF activities
sc.pl.umap(processed_adata, color=["louvain", "PAX5", "FOS", "STAT1"], title="UMAP of TF Activity")
WARNING: The title list is shorter than the number of panels. Using 'color' value instead for some plots.
[Figure: UMAPs of TF activity colored by louvain cluster and by PAX5, FOS, and STAT1 activity]

Heatmap of Differentially Active Transcription Factors (DA TFs)

We can identify and visualize differentially active transcription factors across clusters using the same approach as differential gene expression. Here, we use the wilcoxon method to rank TFs that vary across Louvain clusters.

In [12]:
sc.tl.rank_genes_groups(fine_tuned_tf_activities, groupby='louvain', method='wilcoxon')
sc.tl.dendrogram(fine_tuned_tf_activities, groupby='louvain', use_rep='X')
top_5_DA_TFs = fine_tuned_tf_activities.uns['rank_genes_groups']['names'][:5]
all_top_5_DA_TFs = np.concatenate([top_5_DA_TFs[col] for col in top_5_DA_TFs.dtype.names])


# Now plot
sc.pl.heatmap(
    fine_tuned_tf_activities,
    var_names=all_top_5_DA_TFs,
    groupby='louvain',
    dendrogram=False,
    cmap='viridis',
    show_gene_labels=True,
    standard_scale="var",
    swap_axes=True
)
[Figure: heatmap of the top 5 DA TFs per louvain cluster, TF activity scaled per TF]

This heatmap reveals regulatory programs that define each cluster. Although we're using rank_genes_groups, the values represent TF activity, not raw gene expression.
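
Beyond the TF names, the full ranking (scores and adjusted p-values) can be exported with scanpy's get module. A minimal sketch, assuming scanpy ≥ 1.7; 'B cells' is one of the louvain labels listed earlier:

In [ ]:
# Export the full Wilcoxon ranking of TFs for a single cluster as a DataFrame.
da_tfs_b = sc.get.rank_genes_groups_df(fine_tuned_tf_activities, group='B cells')
print(da_tfs_b.head(10))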

We can save the top 30 DA TFs per cluster:

In [ ]:
top_DA_TFs = fine_tuned_tf_activities.uns['rank_genes_groups']['names'][:30]
top_DA_TFs = pd.DataFrame.from_records(top_DA_TFs, columns=top_DA_TFs.dtype.names)
top_DA_TFs.to_csv('top_TFs_pre.csv', index=False) 

To summarize TF activity across clusters more compactly, we use a matrix plot, which shows the average TF activity per cluster (scaled per TF). This complements the detailed heatmap and is especially useful for presentations or reports:

In [16]:
sc.pl.matrixplot(
    fine_tuned_tf_activities,
    var_names=all_top_5_DA_TFs,  # Top 5 DA TFs per cluster (defined above)
    groupby='louvain',        # Group by cell types
    #dendrogram=True,           # Show dendrogram for clustering
    standard_scale='var',      # Scale each TF to the [0, 1] range
    cmap='RdBu_r',             # Red-Blue reversed colormap
    colorbar_title='Z-scaled scores',  # Title for the colorbar
    swap_axes=True,
    figsize=(5, 6)            # Adjust figure size
)
[Figure: matrix plot of average TF activity per louvain cluster, scaled per TF]
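
Step 4 of the pipeline also covers posterior GRN reconstruction. The training log reported that GRN_prior and GRN_posterior are stored in .uns, and train_model returns a GRN matrix whose rows are the retained genes and whose columns are the TFs (the same orientation is assumed in the model-loading cell at the end of this tutorial). The cell below is a minimal, hedged sketch; it assumes GRN is an array-like on the CPU, that the TF order matches fine_tuned_tf_activities.var_names, and that the gene order matches processed_adata.modality['recon_RNA'].var_names.

In [ ]:
# Hedged sketch: rank the putative targets of one TF from the posterior GRN matrix.
# Assumptions: GRN is (genes x TFs) on CPU; TF/gene orderings match the objects named below.
W = np.asarray(GRN)
grn_df = pd.DataFrame(
    W,
    index=processed_adata.modality['recon_RNA'].var_names,
    columns=fine_tuned_tf_activities.var_names,
)
tf = 'PAX5'  # any TF present in the prior
print(grn_df[tf].abs().sort_values(ascending=False).head(15))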

5. Cell-type-specific fine-tuning¶

If reliable cluster labels are available (here, the louvain annotation), we can fine-tune the model separately for each cell type by passing cluster_key="louvain". This yields cluster-specific TF activities and posterior regulatory weights, which are stored in processed_adata.uns['W_posteriors_per_cluster'] and used below:

In [19]:
processed_adata, fine_tuned_tf_activities, final_model, GRN_final = reg.fine_tuning.fine_tune_clusters(
    processed_adata=processed_adata,
    model=fine_model,
    cluster_key="louvain",
    min_epochs=5000,
    epochs=10000,
    log_interval=1000,
    tf_mapping_lr=4e-04,  # Learning rate for tf_mapping layer
    fc_output_lr=2e-05/100,   # Learning rate for fc_output layer
    default_lr=3.5e-05/100,     # Default learning rate for other layers
)
2025-05-13 14:35:21,852 - finetune - INFO - Starting fine-tuning for cluster(s)...
2025-05-13 14:35:21,858 - finetune - INFO - Aligning data to 3278 genes matching the GRN.
2025-05-13 14:35:21,880 - finetune - INFO - Fine-tuning CD4 T cells for 10000 epochs...
2025-05-13 14:35:21,900 - finetune - INFO - Epoch 1, Avg Loss: 38.4518
2025-05-13 14:35:31,263 - finetune - INFO - Epoch 1001, Avg Loss: 36.6287
2025-05-13 14:35:40,492 - finetune - INFO - Epoch 2001, Avg Loss: 36.1222
2025-05-13 14:35:49,744 - finetune - INFO - Epoch 3001, Avg Loss: 35.8816
2025-05-13 14:35:59,151 - finetune - INFO - Epoch 4001, Avg Loss: 35.7924
2025-05-13 14:36:08,540 - finetune - INFO - Fine-tuning FCGR3A+ Monocytes for 10000 epochs...
2025-05-13 14:36:08,548 - finetune - INFO - Epoch 1, Avg Loss: 6.4385
2025-05-13 14:36:12,010 - finetune - INFO - Epoch 1001, Avg Loss: 3.7633
2025-05-13 14:36:15,472 - finetune - INFO - Epoch 2001, Avg Loss: 3.3483
2025-05-13 14:36:18,937 - finetune - INFO - Epoch 3001, Avg Loss: 3.3856
2025-05-13 14:36:22,401 - finetune - INFO - Epoch 4001, Avg Loss: 3.7342
2025-05-13 14:36:25,870 - finetune - INFO - Fine-tuning CD14+ Monocytes for 10000 epochs...
2025-05-13 14:36:25,883 - finetune - INFO - Epoch 1, Avg Loss: 17.2654
2025-05-13 14:36:31,209 - finetune - INFO - Epoch 1001, Avg Loss: 14.9461
2025-05-13 14:36:36,544 - finetune - INFO - Epoch 2001, Avg Loss: 14.2644
2025-05-13 14:36:41,891 - finetune - INFO - Epoch 3001, Avg Loss: 14.0190
2025-05-13 14:36:47,254 - finetune - INFO - Epoch 4001, Avg Loss: 14.0489
2025-05-13 14:36:52,615 - finetune - INFO - Fine-tuning CD8 T cells for 10000 epochs...
2025-05-13 14:36:52,624 - finetune - INFO - Epoch 1, Avg Loss: 10.9590
2025-05-13 14:36:56,982 - finetune - INFO - Epoch 1001, Avg Loss: 8.5738
2025-05-13 14:37:01,334 - finetune - INFO - Epoch 2001, Avg Loss: 8.0712
2025-05-13 14:37:05,673 - finetune - INFO - Epoch 3001, Avg Loss: 7.9476
2025-05-13 14:37:10,027 - finetune - INFO - Epoch 4001, Avg Loss: 8.1929
2025-05-13 14:37:14,373 - finetune - INFO - Fine-tuning B cells for 10000 epochs...
2025-05-13 14:37:14,382 - finetune - INFO - Epoch 1, Avg Loss: 10.1211
2025-05-13 14:37:18,967 - finetune - INFO - Epoch 1001, Avg Loss: 8.4657
2025-05-13 14:37:23,561 - finetune - INFO - Epoch 2001, Avg Loss: 8.1492
2025-05-13 14:37:28,148 - finetune - INFO - Epoch 3001, Avg Loss: 8.0634
2025-05-13 14:37:32,744 - finetune - INFO - Epoch 4001, Avg Loss: 8.1980
2025-05-13 14:37:37,336 - finetune - INFO - Fine-tuning NK cells for 10000 epochs...
2025-05-13 14:37:37,344 - finetune - INFO - Epoch 1, Avg Loss: 5.4363
2025-05-13 14:37:40,866 - finetune - INFO - Epoch 1001, Avg Loss: 3.2324
2025-05-13 14:37:44,381 - finetune - INFO - Epoch 2001, Avg Loss: 2.9269
2025-05-13 14:37:47,900 - finetune - INFO - Epoch 3001, Avg Loss: 3.1224
2025-05-13 14:37:51,414 - finetune - INFO - Epoch 4001, Avg Loss: 3.4899
2025-05-13 14:37:54,926 - finetune - INFO - Fine-tuning Dendritic cells for 10000 epochs...
2025-05-13 14:37:54,933 - finetune - INFO - Epoch 1, Avg Loss: 1.4314
2025-05-13 14:37:57,821 - finetune - INFO - Epoch 1001, Avg Loss: 0.3601
2025-05-13 14:38:00,704 - finetune - INFO - Epoch 2001, Avg Loss: 0.6294
2025-05-13 14:38:03,591 - finetune - INFO - Epoch 3001, Avg Loss: 0.9232
2025-05-13 14:38:06,470 - finetune - INFO - Epoch 4001, Avg Loss: 1.2135
2025-05-13 14:38:09,348 - finetune - INFO - Fine-tuning Megakaryocytes for 10000 epochs...
2025-05-13 14:38:09,365 - finetune - INFO - Epoch 1, Avg Loss: 0.2003
2025-05-13 14:38:11,870 - finetune - INFO - Epoch 1001, Avg Loss: 0.2499
2025-05-13 14:38:14,375 - finetune - INFO - Epoch 2001, Avg Loss: 0.4438
2025-05-13 14:38:16,882 - finetune - INFO - Epoch 3001, Avg Loss: 0.6341
2025-05-13 14:38:19,390 - finetune - INFO - Epoch 4001, Avg Loss: 0.8167
2025-05-13 14:38:21,915 - finetune - INFO - Fine-tuning completed for all clusters.

Having fine-tuned per cluster, we can compare the resulting regulatory programs across cell types. Below, we average the min-max-scaled absolute posterior weights (W) per cluster, compute the cosine similarity between cell types, and visualize it as a clustermap:

In [24]:
from sklearn.preprocessing import minmax_scale
from sklearn.metrics.pairwise import cosine_similarity
import seaborn as sns

W_posteriors_per_cluster = processed_adata.uns["W_posteriors_per_cluster"]
cell_type_columns = list(W_posteriors_per_cluster.keys())
# Build average W matrices (min-max scaled |W|, averaged over rows), following the exact cell_type_columns order
average_W_matrices = {
    cell_type: minmax_scale(np.abs(W_posteriors_per_cluster[cell_type]).ravel()).reshape(W_posteriors_per_cluster[cell_type].shape).mean(axis=0)
    for cell_type in cell_type_columns
}

# Combine average matrices into a DataFrame where rows = cell types and columns = TFs
combined_average_W = pd.DataFrame(average_W_matrices).T

# Compute cosine similarity across cell types
cosine_sim = np.clip(cosine_similarity(combined_average_W), 0, 1)
similarity_matrix = cosine_sim**8  # Raise to a power to sharpen the contrast between cell types

# Create a DataFrame for the similarity matrix
similarity_df = pd.DataFrame(
    similarity_matrix, 
    index=cell_type_columns, 
    columns=cell_type_columns
)

sns_plot = sns.clustermap(
    similarity_df, 
    figsize=(12, 10),
    annot=True, 
    fmt=".2f", 
    annot_kws={"size": 20},
    cmap="RdBu_r",
    cbar_kws={'label': 'Cosine Similarity'},
    center=0,
    cbar=False
)
sns_plot.cax.set_visible(False)

# Adjust labels on the clustermap's heatmap axes (sns_plot.ax_heatmap)
sns_plot.ax_heatmap.set_xlabel("", fontsize=16)
sns_plot.ax_heatmap.set_ylabel("", fontsize=16)

sns_plot.ax_heatmap.set_xticklabels(sns_plot.ax_heatmap.get_xticklabels(), fontsize=18, rotation=45, ha='right')
sns_plot.ax_heatmap.set_yticklabels(sns_plot.ax_heatmap.get_yticklabels(), fontsize=18)
Out[24]:
[Text(1, 0.5, 'Megakaryocytes'),
 Text(1, 1.5, 'B cells'),
 Text(1, 2.5, 'Dendritic cells'),
 Text(1, 3.5, 'NK cells'),
 Text(1, 4.5, 'CD4 T cells'),
 Text(1, 5.5, 'CD8 T cells'),
 Text(1, 6.5, 'FCGR3A+ Monocytes'),
 Text(1, 7.5, 'CD14+ Monocytes')]
[Figure: clustermap of cosine similarity between cell-type-averaged posterior regulatory weights]

Saving the model¶

In [25]:
import torch
import pickle
import os
import anndata as ad
In [26]:
processed_adata.modality
Out[26]:
{'RNA': AnnData object with n_obs × n_vars = 2638 × 3644
     obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
     var: 'n_cells'
     uns: 'type'
     obsm: 'ulm_estimate', 'ulm_pvals',
 'TF': AnnData object with n_obs × n_vars = 2638 × 298
     obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
     uns: 'type',
 'recon_RNA': AnnData object with n_obs × n_vars = 2638 × 3278
     obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
     uns: 'type',
 'latent': AnnData object with n_obs × n_vars = 2638 × 40
     obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
     uns: 'type'}
In [27]:
# Only the main (RNA) assay is serialized to .h5ad; the .modality dict is rebuilt at load time (see below)

# Paths to save components
save_dir = "saved_model_outputs"
os.makedirs(save_dir, exist_ok=True)

# Save PyTorch model
torch.save(fine_model.state_dict(), os.path.join(save_dir, "fine_model_PBMC.pt"))

# Save processed AnnData
if 'W_posteriors_per_cluster' in processed_adata.uns:
    processed_adata.uns['W_posteriors_per_cluster'] = {
        str(k): v for k, v in processed_adata.uns['W_posteriors_per_cluster'].items()
    }
processed_adata.write(os.path.join(save_dir, "processed_adata_PBMC.h5ad"))

# Save TF activities (AnnData)
fine_tuned_tf_activities.write(os.path.join(save_dir, "tf_activities_PBMC.h5ad"))

# Save GRN as pickle (assuming it's a dict or matrix)
with open(os.path.join(save_dir, "GRN_PBMC.pkl"), "wb") as f:
    pickle.dump(GRN, f)

Loading the model¶

In [29]:
# Load AnnData objects
save_dir = "./saved_model_outputs"
processed_adata = ad.read_h5ad(os.path.join(save_dir, "processed_adata_PBMC.h5ad"))
tf_activities = ad.read_h5ad(os.path.join(save_dir, "tf_activities_PBMC.h5ad"))
processed_adata.modality = {}
processed_adata.modality['RNA'] = processed_adata.copy()
processed_adata.modality['TF'] = tf_activities.copy()

# Load GRN
with open(os.path.join(save_dir, "GRN_PBMC.pkl"), "rb") as f:
    GRN = pickle.load(f)

# Load model (make sure to reinitialize with the exact same architecture)
input_dim = GRN.shape[0]
tf_dim = GRN.shape[1]
encode_dims=[2048, 256, 64]
decode_dims=[256]
z_dim=40

fine_model = reg.scRNA_VAE(input_dim=input_dim, encode_dims=encode_dims, decode_dims=decode_dims, z_dim=z_dim, tf_dim=tf_dim)
fine_model.load_state_dict(torch.load(os.path.join(save_dir, "fine_model_PBMC.pt")))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fine_model = fine_model.to(device)
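
As a quick sanity check after reloading (plain PyTorch only; no scRegulate-specific API assumed), we can switch the model to inference mode and confirm the parameter count and device:

In [ ]:
# Confirm the reloaded model: set eval mode and report parameter count and device.
fine_model.eval()
n_params = sum(p.numel() for p in fine_model.parameters())
print(f"Reloaded fine_model with {n_params:,} parameters on {next(fine_model.parameters()).device}")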