Genomic Data Analysis Transformer Encoder Example¶
A compact notebook showing a Transformer encoder workflow for genomic / single-cell data:
- scRNA-seq data from healty and diabetic pancreas
- expression binning
- gene and bin embeddings
- small Transformer encoder
- classifier training
- attention-based module discovery
Single-cell RNA-seq analysis of human pancreas from healthy individuals and type 2 diabetes patients¶
Segerstolpe, Åsa, et al. "Single-cell transcriptome profiling of human pancreatic islets in health and type 2 diabetes." Cell metabolism 24.4 (2016): 593-607.
Download the scRNA-seq data from¶
https://www.ebi.ac.uk/biostudies/arrayexpress/studies/E-MTAB-5061
In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import random
import matplotlib.pyplot as plt
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
Device: cpu
In [2]:
import os
current_directory = os.getcwd()
print(f"Current working directory: {current_directory}")
Current working directory: /Users/hoyen/Desktop/STAT718-2026/PythonCode/Transformer
In [3]:
from __future__ import annotations
from pathlib import Path
import pandas as pd
def load_expression_matrix(matrix_path):
matrix_path = Path(matrix_path)
# Read the first line to get sample IDs
with matrix_path.open("r", encoding="utf-8") as fh:
first_line = fh.readline().rstrip("\n")
parts = first_line.split("\t")
if not parts or not parts[0].startswith("#samples"):
raise ValueError(f"Unexpected first line: {parts[:5]}")
sample_ids = parts[1:]
n = len(sample_ids)
# Read the table, skipping the first comment line
df = pd.read_csv(
matrix_path,
sep="\t",
comment="#",
header=None,
engine="python",
)
print("Raw shape:", df.shape)
# Expected: 2 ID columns + n RPKM + n counts
expected_cols = 2 + 2 * n
if df.shape[1] != expected_cols:
raise ValueError(
f"Expected {expected_cols} columns total "
f"(2 IDs + {n} RPKM + {n} counts), but got {df.shape[1]}."
)
gene_symbol = df.iloc[:, 0].astype(str)
transcript_id = df.iloc[:, 1].astype(str)
rpkm = df.iloc[:, 2 : 2 + n].copy()
counts = df.iloc[:, 2 + n : 2 + 2 * n].copy()
rpkm.columns = sample_ids
counts.columns = sample_ids
gene_symbol = df.iloc[:, 0].astype(str)
transcript_id = df.iloc[:, 1].astype(str)
feature_id = gene_symbol + "|" + transcript_id
# feature_id = gene_symbol
rpkm.index = feature_id
counts.index = feature_id
return rpkm, counts
def load_sdrf_metadata(sdrf_path):
sdrf_path = Path(sdrf_path)
meta = pd.read_csv(sdrf_path, sep="\t", dtype=str, low_memory=False)
def find_col(*keywords):
for c in meta.columns:
cl = c.lower()
if all(k in cl for k in keywords):
return c
return None
# ---------- Find sample column ----------
sample_col = None
for candidate in ["source name", "sample name", "assay name"]:
if candidate in [c.lower().strip() for c in meta.columns]:
sample_col = next(
c for c in meta.columns if c.lower().strip() == candidate
)
break
# ---------- Find cell type ----------
celltype_col = None
for c in meta.columns:
if "inferred cell type" in c.lower():
celltype_col = c
break
# ---------- Find disease ----------
disease_col = None
for c in meta.columns:
if "disease" in c.lower():
disease_col = c
break
# ---------- Find quality ----------
quality_col = None
for c in meta.columns:
if "submitted single cell quality" in c.lower():
quality_col = c
break
# ---------- Build cell_meta ----------
cell_meta = pd.DataFrame(index=meta[sample_col].astype(str))
cell_meta.index.name = "sample_id"
cell_meta["cell_type"] = (
meta[celltype_col].values if celltype_col is not None else None
)
cell_meta["disease"] = (
meta[disease_col].values if disease_col is not None else None
)
cell_meta["quality"] = (
meta[quality_col].values if quality_col is not None else None
)
return cell_meta
base_path = Path("/Users/hoyen/Desktop/STAT718-2026/PythonCode/Transformer/E-MTAB-5061")
matrix_file = base_path / "count.txt"
sdrf_file = base_path / "E-MTAB-5061.sdrf.txt"
rpkm, counts = load_expression_matrix(matrix_file)
cell_meta = load_sdrf_metadata(sdrf_file)
Raw shape: (26271, 7030)
Remove low quality cells¶
In [4]:
ok_mask = cell_meta["quality"].str.strip().str.lower().eq("ok")
cell_meta_ok = cell_meta.loc[ok_mask].copy()
rpkm_ok = rpkm.loc[:, rpkm.columns.intersection(cell_meta_ok.index)]
cell_meta_ok.to_csv(base_path / "E-MTAB-5061_cell_metadata.tsv", sep="\t", index=False)
In [5]:
cell_meta_ok.shape
Out[5]:
(2209, 3)
In [6]:
rpkm_ok.shape
Out[6]:
(26271, 2209)
In [7]:
cell_meta_ok.head
Out[7]:
<bound method NDFrame.head of cell_type disease quality sample_id AZ_A10 delta cell normal OK AZ_A11 alpha cell normal OK AZ_A12 delta cell normal OK AZ_A2 gamma cell normal OK AZ_A5 ductal cell normal OK ... ... ... ... HP1526901T2D_P22 delta cell type II diabetes mellitus OK HP1526901T2D_P23 gamma cell type II diabetes mellitus OK HP1526901T2D_P4 beta cell type II diabetes mellitus OK HP1526901T2D_P7 beta cell type II diabetes mellitus OK HP1526901T2D_P9 beta cell type II diabetes mellitus OK [2209 rows x 3 columns]>
In [8]:
# X: cells x genes
X = rpkm_ok.T
# y: cell types aligned to X
y = cell_meta_ok.loc[X.index, "cell_type"]
In [9]:
print(X.shape) # (num_cells, num_genes)
print(y.shape) # (num_cells,)
(2209, 26271) (2209,)
In [10]:
# check alignment
print((X.index == y.index).all()) # should be True
True
In [11]:
### select 2000 highly variable genes
mean = X.mean(axis=0)
var = X.var(axis=0)
dispersion = var / (mean + 1e-6)
top_genes = dispersion.sort_values(ascending=False).head(2000).index
X_hvg=X[top_genes]
X_hvg.shape
Out[11]:
(2209, 2001)
In [12]:
duplicates = X_hvg.columns[X_hvg.columns.duplicated()]
print("Number of duplicate columns:", len(duplicates))
print("Duplicate column names:", duplicates.tolist())
Number of duplicate columns: 1 Duplicate column names: ['CFC1B|NM_001079530']
In [13]:
import pandas as pd
import numpy as np
# variance for every column position, even if names repeat
col_var = X_hvg.var(axis=0)
# keep the position with the highest variance for each column name
best_pos = {}
for i, col in enumerate(X_hvg.columns):
v = col_var.iloc[i]
if col not in best_pos or v > best_pos[col][0]:
best_pos[col] = (v, i)
# preserve original order of the kept columns
keep_idx = sorted(pos for _, pos in best_pos.values())
X_hvg_nodup = X_hvg.iloc[:, keep_idx]
print("Before:", X_hvg.shape)
print("After:", X_hvg_nodup.shape)
print("Any duplicated columns left?", X_hvg_nodup.columns.duplicated().any())
Before: (2209, 2001) After: (2209, 2000) Any duplicated columns left? False
In [14]:
disease = cell_meta_ok.loc[X_hvg_nodup.index, "disease"].str.strip().str.lower()
disease.head
Out[14]:
<bound method NDFrame.head of HP1502401_H13 normal
HP1502401_J14 normal
HP1502401_B14 normal
HP1502401_A14 normal
HP1502401_C14 normal
...
HP1526901T2D_F7 type ii diabetes mellitus
HP1525301T2D_K3 type ii diabetes mellitus
HP1525301T2D_J10 type ii diabetes mellitus
HP1526901T2D_N8 type ii diabetes mellitus
HP1526901T2D_A8 type ii diabetes mellitus
Name: disease, Length: 2209, dtype: object>
In [15]:
normal_mask = disease.eq("normal")
diabetes_mask = disease.str.contains("diabetes", na=False)
X1 = X_hvg_nodup.loc[normal_mask]
y1 = y.loc[normal_mask]
X2 = X_hvg_nodup.loc[diabetes_mask]
y2 = y.loc[diabetes_mask]
In [16]:
print(X1.shape)
print(y1.shape)
(1097, 2000) (1097,)
In [17]:
print(X2.shape)
print(y2.shape)
(1112, 2000) (1112,)
In [18]:
print("Normal (y1):")
print(y1.value_counts())
print("\nDiabetes (y2):")
print(y2.value_counts())
Normal (y1): cell_type alpha cell 443 beta cell 171 ductal cell 135 acinar cell 112 gamma cell 75 delta cell 59 unclassified endocrine cell 29 co-expression cell 26 PSC cell 23 endothelial cell 13 epsilon cell 5 mast cell 4 MHC class II cell 1 unclassified cell 1 Name: count, dtype: int64 Diabetes (y2): cell_type alpha cell 443 ductal cell 251 gamma cell 122 beta cell 99 acinar cell 73 delta cell 55 PSC cell 31 co-expression cell 13 unclassified endocrine cell 12 MHC class II cell 4 endothelial cell 3 mast cell 3 epsilon cell 2 unclassified cell 1 Name: count, dtype: int64
In [19]:
### remove cell with less than 10 cells in healthy cells
counts = y1.value_counts()
valid_classes = counts[counts >= 10].index
mask1 = y1.isin(valid_classes)
y1 = y1[mask1]
X1 = X1[mask1]
print(y1.value_counts())
cell_type alpha cell 443 beta cell 171 ductal cell 135 acinar cell 112 gamma cell 75 delta cell 59 unclassified endocrine cell 29 co-expression cell 26 PSC cell 23 endothelial cell 13 Name: count, dtype: int64
In [20]:
### remove cell with less than 10 cells in healthy cells
counts = y2.value_counts()
valid_classes = counts[counts >= 10].index
mask2 = y2.isin(valid_classes)
y2 = y2[mask2]
X2 = X2[mask2]
print(y2.value_counts())
cell_type alpha cell 443 ductal cell 251 gamma cell 122 beta cell 99 acinar cell 73 delta cell 55 PSC cell 31 co-expression cell 13 unclassified endocrine cell 12 Name: count, dtype: int64
In [21]:
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
# fit on combined labels to ensure consistent encoding across both datasets
le.fit(np.concatenate([y1, y2]))
y1_encoded = le.transform(y1)
y2_encoded = le.transform(y2)
print(le.classes_)
['PSC cell' 'acinar cell' 'alpha cell' 'beta cell' 'co-expression cell' 'delta cell' 'ductal cell' 'endothelial cell' 'gamma cell' 'unclassified endocrine cell']
In [22]:
y1_encoded
Out[22]:
array([8, 2, 3, ..., 8, 8, 2])
In [23]:
y2_encoded
Out[23]:
array([8, 2, 2, ..., 5, 2, 2])
Binning and embeddings from healthy cells¶
In [24]:
def percentile_bin_global(expr_mat, n_bins=7):
expr_mat = expr_mat.to_numpy()
flat = expr_mat.flatten()
q = np.nanpercentile(flat, np.linspace(0, 100, n_bins + 1))
binned = np.digitize(expr_mat, q[1:-1], right=True).astype(np.int64)
return np.clip(binned, 0, n_bins - 1) # ← clamp to valid range [0, 6]
n_bins = 7
binned1 = percentile_bin_global(X1, n_bins=n_bins)
print(binned1.min(), binned1.max())
0 6
In [26]:
n_genes = X1.shape[1]
print(f"n_genes: {n_genes}, X1 shape: {X1.shape}, binned1 shape: {binned1.shape}")
n_genes: 2000, X1 shape: (1086, 2000), binned1 shape: (1086, 2000)
In [27]:
n_classes = len(np.unique(y1))
gene_emb_dim = 128
gene_emb1 = nn.Embedding(n_genes, gene_emb_dim).to(device)
bin_emb1 = nn.Embedding(n_bins, gene_emb_dim).to(device)
genes = X1.columns.tolist()
def tokens_from_bins(binned_np, gene_emb, bin_emb, device):
b = torch.from_numpy(binned_np).long().to(device)
batch, seq_len = b.shape
gene_idx = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch, 1)
return gene_emb(gene_idx) + bin_emb(b)
# call:
print(binned1.shape, tokens_from_bins(binned1[:2], gene_emb1, bin_emb1, device).shape)
(1086, 2000) torch.Size([2, 2000, 128])
Transformer encoder with attention weights¶
In [28]:
class TransformerLayerWithAttn(nn.Module):
def __init__(self, d_model, nhead, dim_ff=256, dropout=0.1):
super().__init__()
self.mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim_ff, d_model))
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
attn_out, attn_w = self.mha(x, x, x, need_weights=True)
x = self.norm1(x + attn_out)
x = self.norm2(x + self.ff(x))
return x, attn_w
class TinyTransformerEncoder(nn.Module):
def __init__(self, d_model, nhead, n_layers=2):
super().__init__()
self.layers = nn.ModuleList([TransformerLayerWithAttn(d_model, nhead) for _ in range(n_layers)])
def forward(self, x):
attns = []
for layer in self.layers:
x, aw = layer(x)
attns.append(aw)
return x, attns
class ClassifierHead(nn.Module):
def __init__(self, d_model, n_classes, hidden=256):
super().__init__()
self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden, n_classes))
def forward(self, enc_out):
return self.mlp(enc_out.mean(dim=1))
class ScBERTToy(nn.Module):
def __init__(self, encoder, classifier):
super().__init__()
self.encoder = encoder
self.classifier = classifier
def forward(self, tok_emb):
enc_out, attns = self.encoder(tok_emb)
logits = self.classifier(enc_out)
return logits, enc_out, attns
encoder = TinyTransformerEncoder(gene_emb_dim, nhead=4, n_layers=2).to(device)
classifier = ClassifierHead(gene_emb_dim, n_classes).to(device)
model = ScBERTToy(encoder, classifier).to(device)
optimizer = torch.optim.Adam(list(model.parameters()) + list(gene_emb1.parameters()) + list(bin_emb1.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()
print(model)
ScBERTToy(
(encoder): TinyTransformerEncoder(
(layers): ModuleList(
(0-1): 2 x TransformerLayerWithAttn(
(mha): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
)
(norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(ff): Sequential(
(0): Linear(in_features=128, out_features=256, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=256, out_features=128, bias=True)
)
(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
)
)
(classifier): ClassifierHead(
(mlp): Sequential(
(0): Linear(in_features=128, out_features=256, bias=True)
(1): ReLU()
(2): Dropout(p=0.2, inplace=False)
(3): Linear(in_features=256, out_features=10, bias=True)
)
)
)
Data loaders¶
In [29]:
class BinnedTokenDataset(Dataset):
def __init__(self, binned_matrix, labels):
self.binned = binned_matrix.astype(np.int64)
self.labels = labels.astype(np.int64)
def __len__(self):
return self.binned.shape[0]
def __getitem__(self, idx):
return self.binned[idx], int(self.labels[idx])
train_idx, test_idx = train_test_split(np.arange(len(y1_encoded)), test_size=0.3, stratify=y1_encoded, random_state=42)
train_ds = BinnedTokenDataset(binned1[train_idx], y1_encoded[train_idx])
test_ds = BinnedTokenDataset(binned1[test_idx], y1_encoded[test_idx])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)
print(len(train_ds), len(test_ds))
760 326
Training¶
In [30]:
def train_epoch():
model.train()
total = 0.0
preds_all = []
labels_all = []
for bins_np, labels in train_loader:
bins_np = np.array(bins_np)
labels = labels.to(device)
tok = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device)
optimizer.zero_grad()
logits, _, _ = model(tok)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total += float(loss.item()) * len(labels)
preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
labels_all.append(labels.detach().cpu().numpy())
preds_all = np.concatenate(preds_all)
labels_all = np.concatenate(labels_all)
return total / len(train_ds), accuracy_score(labels_all, preds_all)
@torch.no_grad()
def evaluate():
model.eval()
total = 0.0
preds_all = []
labels_all = []
for bins_np, labels in test_loader:
bins_np = np.array(bins_np)
labels = labels.to(device)
tok = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device)
logits, _, _ = model(tok)
loss = criterion(logits, labels)
total += float(loss.item()) * len(labels)
preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
labels_all.append(labels.detach().cpu().numpy())
preds_all = np.concatenate(preds_all)
labels_all = np.concatenate(labels_all)
return total / len(test_ds), accuracy_score(labels_all, preds_all), labels_all, preds_all
for epoch in range(1, 7):
tr_loss, tr_acc = train_epoch()
te_loss, te_acc, _, _ = evaluate()
print(f'Epoch {epoch:02d} | train_loss={tr_loss:.4f} acc={tr_acc:.4f} | val_loss={te_loss:.4f} acc={te_acc:.4f}')
Epoch 01 | train_loss=1.8642 acc=0.4000 | val_loss=1.7437 acc=0.4110 Epoch 02 | train_loss=1.6885 acc=0.4355 | val_loss=1.5680 acc=0.4693 Epoch 03 | train_loss=1.3955 acc=0.5184 | val_loss=1.3400 acc=0.5153 Epoch 04 | train_loss=1.2928 acc=0.5461 | val_loss=1.1991 acc=0.5675 Epoch 05 | train_loss=1.1484 acc=0.6026 | val_loss=1.1677 acc=0.6104 Epoch 06 | train_loss=1.1118 acc=0.6079 | val_loss=1.0369 acc=0.6411
Attention-based module discovery¶
In [31]:
model.eval()
attn_sums = []
labels_test = []
pooled = []
with torch.no_grad():
for bins_np, labels in test_loader:
bins_np = np.array(bins_np)
tok = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device)
logits, enc_out, attn_list = model(tok)
aw = torch.stack(attn_list, dim=0).cpu().numpy() # (layers, batch, heads, S, S)
aw_mean = aw.mean(axis=0).mean(axis=1) # (batch, S, S)
received = aw_mean.sum(axis=1) # (batch, S)
attn_sums.append(received)
labels_test.append(labels.numpy())
pooled.append(enc_out.mean(dim=1).cpu().numpy())
attn_sums = np.concatenate(attn_sums, axis=0)
labels_test = np.concatenate(labels_test, axis=0)
pooled = np.concatenate(pooled, axis=0)
def top_genes_for_class(attn_sums, labels, gene_names, class_id, topk=8):
gene_names = list(gene_names) # ensure it's a list
idxs = np.where(labels == class_id)[0]
avg_attn = attn_sums[idxs].mean(axis=0)
avg_attn = np.array(avg_attn).flatten() # ensure 1D array
top_idx = np.argsort(-avg_attn)[:topk]
return [(gene_names[i], float(avg_attn[i])) for i in top_idx]
for c in range(n_classes):
class_name = le.inverse_transform([c])[0]
print(f'Class {c} ({class_name}):', top_genes_for_class(attn_sums, labels_test, genes, c))
Class 0 (PSC cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 1 (acinar cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 2 (alpha cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 3 (beta cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 4 (co-expression cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 5 (delta cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 6 (ductal cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 7 (endothelial cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 0.9999999403953552)]
Class 8 (gamma cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 9 (unclassified endocrine cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0000001192092896)]
2D visualization of pooled embeddings¶
In [32]:
try:
import umap
z = umap.UMAP(n_components=2, random_state=42).fit_transform(pooled)
except Exception:
from sklearn.decomposition import PCA
z = PCA(n_components=2).fit_transform(pooled)
plt.figure(figsize=(6, 5))
plt.scatter(z[:, 0], z[:, 1], c=labels_test, cmap='tab10', s=8, alpha=0.7)
plt.title('Pooled encoder embeddings (test set)')
plt.colorbar()
plt.show()
/opt/anaconda3/envs/genomics-cnn/lib/python3.9/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism. warn(
In [33]:
import numpy as np
import matplotlib.pyplot as plt
def batch_attention_matrix(model, bins_np):
model.eval()
with torch.no_grad():
tok_emb = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device) # ← fix
logits, enc_out, attn_list = model(tok_emb)
attn_stack = torch.stack(attn_list, dim=0).detach().cpu().numpy()
if attn_stack.ndim == 5:
attn_mean = attn_stack.mean(axis=0).mean(axis=1)
elif attn_stack.ndim == 4:
attn_mean = attn_stack.mean(axis=0)
else:
raise ValueError(f"Unexpected attention shape: {attn_stack.shape}")
return attn_mean
def average_attention_by_class(model, loader):
"""
Average attention matrices by class label.
Returns:
class_attn: dict[class_id] = averaged attention matrix (S, S)
class_counts: dict[class_id] = number of cells used
"""
class_sum = {}
class_count = {}
model.eval()
with torch.no_grad():
for bins_np, labels in loader:
bins_np = np.array(bins_np)
labels_np = np.array(labels)
attn_mean = batch_attention_matrix(model, bins_np) # (batch, S, S)
for i, c in enumerate(labels_np):
if c not in class_sum:
class_sum[c] = attn_mean[i].copy()
class_count[c] = 1
else:
class_sum[c] += attn_mean[i]
class_count[c] += 1
class_attn = {c: class_sum[c] / class_count[c] for c in class_sum}
return class_attn, class_count
def plot_attention_heatmap(attn_mat, gene_names, top_k=30, title="Attention heatmap"):
"""
Plot a readable gene-gene attention heatmap using the top_k genes
with the highest total incoming attention.
"""
gene_scores = attn_mat.sum(axis=0)
top_idx = np.argsort(-gene_scores)[:top_k]
sub_mat = attn_mat[np.ix_(top_idx, top_idx)]
sub_genes = [gene_names[i] for i in top_idx]
plt.figure(figsize=(10, 8))
im = plt.imshow(sub_mat, aspect="auto", interpolation="nearest")
plt.colorbar(im, label="Attention weight")
plt.xticks(range(top_k), sub_genes, rotation=90)
plt.yticks(range(top_k), sub_genes)
plt.title(title)
plt.tight_layout()
plt.show()
In [34]:
genes = [g.split('|')[0] for g in genes]
print(genes[:5]) # ['INS', 'SST', 'PPY', 'CLPS', 'GHRL']
['INS', 'SST', 'PPY', 'CLPS', 'GHRL']
In [35]:
##### print attention heatmap for all cell types
class_attn, class_counts = average_attention_by_class(model, test_loader)
for c in sorted(class_attn.keys()):
name = le.inverse_transform([c])[0]
plot_attention_heatmap(
class_attn[c],
genes,
top_k=30,
title=f"{name} attention heatmap (n={class_counts[c]})"
)
In [36]:
target_name = "beta" # adjust if your label is "Beta cells", etc.
for c in sorted(class_attn.keys()):
name = le.inverse_transform([c])[0]
if target_name.lower() in name.lower():
plot_attention_heatmap(
class_attn[c],
genes,
top_k=30,
title=f"{name} attention heatmap (n={class_counts[c]})"
)
In [37]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
def attention_to_gene_modules(attn, gene_names, top_k=30, n_modules=4):
"""
Convert an attention heatmap into gene modules by clustering the top-k genes.
Parameters
----------
attn : array-like
1D gene importance vector or 2D gene x gene attention matrix.
gene_names : list-like
Gene names in the same order as attn.
top_k : int
Number of top genes to keep before clustering.
n_modules : int
Number of modules to cut the dendrogram into.
Returns
-------
modules : dict
{module_id: [genes...]}
sub_attn : np.ndarray
Top-k attention submatrix.
top_genes : np.ndarray
Selected top-k genes.
labels : np.ndarray
Cluster labels for top-k genes.
"""
attn = np.asarray(attn)
gene_names = np.asarray(gene_names)
# If 1D, just rank genes by attention score
if attn.ndim == 1:
top_idx = np.argsort(attn)[::-1][:top_k]
top_genes = gene_names[top_idx]
modules = {1: list(top_genes)}
return modules, None, top_genes, np.ones(len(top_genes), dtype=int)
# If 2D, use a symmetrized attention matrix
sym_attn = (attn + attn.T) / 2.0
# Score genes by average attention, then keep top_k
gene_scores = sym_attn.mean(axis=0)
top_idx = np.argsort(gene_scores)[::-1][:top_k]
top_genes = gene_names[top_idx]
sub_attn = sym_attn[np.ix_(top_idx, top_idx)]
# Convert similarity to distance for clustering
sub_min = sub_attn.min()
sub_max = sub_attn.max()
sub_norm = (sub_attn - sub_min) / (sub_max - sub_min + 1e-8)
dist = 1.0 - sub_norm
np.fill_diagonal(dist, 0.0)
# Hierarchical clustering
condensed = squareform(dist, checks=False)
Z = linkage(condensed, method="average")
labels = fcluster(Z, t=n_modules, criterion="maxclust")
# Build modules
modules = {}
for m in sorted(np.unique(labels)):
modules[f"Module {m}"] = list(top_genes[labels == m])
return modules, sub_attn, top_genes, labels
In [38]:
beta_name = "beta cell" # change to your exact label if needed
beta_class = le.transform([beta_name])[0]
modules, sub_attn, top_genes, labels = attention_to_gene_modules(
class_attn[beta_class],
genes,
top_k=30,
n_modules=3
)
for mod_name, mod_genes in modules.items():
print(mod_name, ":", mod_genes)
Module 1 : ['CPE', 'PRSS3P2', 'SLC44A3', 'CLK1', 'SOX4', 'NDUFA13', 'DYNLRB1', 'TIMP1', 'CTSD', 'RPL38', 'GC', 'ATP5J2', 'NEU1', 'SQSTM1', 'G6PC2', 'SURF4', 'CTRC', 'SRP14', 'EIF4EBP1', 'PRDX4', 'SLIRP', 'ATP6V0E1', 'SPINT2', 'ERCC_3750:mix1_7500:mix2', 'LRRC75A-AS1', 'BEX1', 'CYSTM1', 'CD63'] Module 2 : ['PAM'] Module 3 : ['GHITM']
In [39]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
def plot_gene_module_network(sub_attn, top_genes, labels, modules=None, edge_quantile=0.85, title="Gene module network"):
"""
sub_attn : 2D attention matrix for top genes
top_genes : list/array of gene names in the same order as sub_attn
labels : cluster labels for each gene in top_genes
modules : optional dict like {"Module 1": [...], ...}
edge_quantile : threshold for drawing edges
"""
top_genes = list(top_genes)
sub_attn = np.asarray(sub_attn)
G = nx.Graph()
# module assignment
gene_to_module = {}
if modules is not None:
for mod_name, genes_in_mod in modules.items():
for g in genes_in_mod:
gene_to_module[g] = mod_name
else:
for g, lab in zip(top_genes, labels):
gene_to_module[g] = f"Module {lab}"
# add nodes
for g in top_genes:
G.add_node(g, module=gene_to_module.get(g, "Module 0"))
# add edges for strong attention values
vals = sub_attn[np.triu_indices_from(sub_attn, k=1)]
thresh = np.quantile(vals, edge_quantile)
for i in range(len(top_genes)):
for j in range(i + 1, len(top_genes)):
w = sub_attn[i, j]
if w >= thresh:
G.add_edge(top_genes[i], top_genes[j], weight=float(w))
# layout
pos = nx.spring_layout(G, seed=42, k=1.0)
# colors by module
module_names = sorted(set(nx.get_node_attributes(G, "module").values()))
cmap = plt.cm.tab10
module_color_map = {m: cmap(i % 10) for i, m in enumerate(module_names)}
node_colors = [module_color_map[G.nodes[n]["module"]] for n in G.nodes()]
# edge widths from weights
edge_weights = [G[u][v]["weight"] for u, v in G.edges()]
if edge_weights:
wmin, wmax = min(edge_weights), max(edge_weights)
edge_widths = [1 + 4 * (w - wmin) / (wmax - wmin + 1e-8) for w in edge_weights]
else:
edge_widths = []
plt.figure(figsize=(12, 10))
nx.draw_networkx_edges(G, pos, alpha=0.35, width=edge_widths)
nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=900, linewidths=1, edgecolors="black")
nx.draw_networkx_labels(G, pos, font_size=9)
handles = [
plt.Line2D([0], [0], marker='o', color='w',
markerfacecolor=module_color_map[m], markeredgecolor='black',
markersize=10, label=m)
for m in module_names
]
plt.legend(handles=handles, title="Modules", loc="best")
plt.title(title)
plt.axis("off")
plt.tight_layout()
plt.show()
In [40]:
beta_name = "beta cell"
beta_class = le.transform([beta_name])[0]
modules, sub_attn, top_genes, labels = attention_to_gene_modules(
class_attn[beta_class],
genes,
top_k=30,
n_modules=3
)
plot_gene_module_network(
sub_attn=sub_attn,
top_genes=top_genes,
labels=labels,
modules=modules,
edge_quantile=0.85,
title="Beta-cell gene module network"
)
Getting beta cell gene modules from diabetes cells¶
In [41]:
def average_attention_by_class(model, loader, gene_emb, bin_emb): # updated signature
class_sum = {}
class_count = {}
model.eval()
with torch.no_grad():
for bins_np, labels in loader:
bins_np = np.array(bins_np, copy=True)
labels_np = np.array(labels)
attn_mean = batch_attention_matrix(model, bins_np, gene_emb, bin_emb) # ✅
for i, c in enumerate(labels_np):
if c not in class_sum:
class_sum[c] = attn_mean[i].copy()
class_count[c] = 1
else:
class_sum[c] += attn_mean[i]
class_count[c] += 1
class_attn = {c: class_sum[c] / class_count[c] for c in class_sum}
return class_attn, class_count
def batch_attention_matrix(model, bins_np, gene_emb, bin_emb): # updated signature
model.eval()
with torch.no_grad():
tok_emb = tokens_from_bins(bins_np, gene_emb, bin_emb, device) # ✅
logits, enc_out, attn_list = model(tok_emb)
attn_stack = torch.stack(attn_list, dim=0).detach().cpu().numpy()
if attn_stack.ndim == 5:
attn_mean = attn_stack.mean(axis=0).mean(axis=1)
elif attn_stack.ndim == 4:
attn_mean = attn_stack.mean(axis=0)
else:
raise ValueError(f"Unexpected attention shape: {attn_stack.shape}")
return attn_mean
In [42]:
def run_transformer_pipeline(
X, y,
n_bins=7,
gene_emb_dim=128,
nhead=4,
n_layers=2,
lr=1e-3,
n_epochs=6,
batch_size_train=32,
batch_size_test=64,
test_size=0.3,
random_state=42,
device=device
):
"""
Full pipeline: binning → embeddings → model → training → attention extraction.
Returns
-------
results : dict with keys:
'model' : trained ScBERTToy
'le' : fitted LabelEncoder (0-indexed, no gaps)
'gene_emb' : nn.Embedding for genes
'bin_emb' : nn.Embedding for bins
'genes' : list of gene names
'class_attn' : dict[class_id] → averaged attention matrix (S, S)
'class_counts' : dict[class_id] → number of cells
'attn_sums' : np.ndarray (n_test_cells, n_genes) — per-cell attention
'labels_test' : np.ndarray of test labels (0-indexed)
'pooled' : np.ndarray (n_test_cells, gene_emb_dim) — encoder embeddings
'train_history' : list of dicts per epoch
"""
# ── 1. Label encoding ───────────────────────────────────────────────────
le = LabelEncoder()
y_encoded = le.fit_transform(y)
n_classes = len(le.classes_)
print(f"[pipeline] n_classes={n_classes}, classes={le.classes_}")
# ── 2. Binning ──────────────────────────────────────────────────────────
binned = percentile_bin_global(X, n_bins=n_bins)
actual_max_bin = int(binned.max())
n_bin_embeddings = actual_max_bin + 1 # safe embedding size
print(f"[pipeline] bin range: {binned.min()} – {actual_max_bin} "
f"(using {n_bin_embeddings} embedding slots)")
# ── 3. Embeddings ───────────────────────────────────────────────────────
genes = X.columns.tolist()
n_genes = len(genes)
gene_emb = nn.Embedding(n_genes, gene_emb_dim).to(device)
bin_emb = nn.Embedding(n_bin_embeddings, gene_emb_dim).to(device)
# ── 4. DataLoaders ──────────────────────────────────────────────────────
train_idx, test_idx = train_test_split(
np.arange(len(y_encoded)),
test_size=test_size,
stratify=y_encoded,
random_state=random_state
)
train_ds = BinnedTokenDataset(binned[train_idx], y_encoded[train_idx])
test_ds = BinnedTokenDataset(binned[test_idx], y_encoded[test_idx])
train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size_test, shuffle=False)
print(f"[pipeline] train={len(train_ds)}, test={len(test_ds)}")
# ── 5. Model ────────────────────────────────────────────────────────────
encoder = TinyTransformerEncoder(gene_emb_dim, nhead=nhead, n_layers=n_layers).to(device)
classifier = ClassifierHead(gene_emb_dim, n_classes).to(device)
model = ScBERTToy(encoder, classifier).to(device)
optimizer = torch.optim.Adam(
list(model.parameters()) + list(gene_emb.parameters()) + list(bin_emb.parameters()),
lr=lr
)
criterion = nn.CrossEntropyLoss()
# ── 6. Training ─────────────────────────────────────────────────────────
def _train_epoch():
model.train()
total, preds_all, labels_all = 0.0, [], []
for bins_np, labels in train_loader:
bins_np = np.array(bins_np, copy=True)
labels = labels.to(device)
tok = tokens_from_bins(bins_np, gene_emb, bin_emb, device)
optimizer.zero_grad()
logits, _, _ = model(tok)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
total += float(loss.item()) * len(labels)
preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
labels_all.append(labels.detach().cpu().numpy())
return total / len(train_ds), accuracy_score(
np.concatenate(labels_all), np.concatenate(preds_all)
)
@torch.no_grad()
def _evaluate():
model.eval()
total, preds_all, labels_all = 0.0, [], []
for bins_np, labels in test_loader:
bins_np = np.array(bins_np, copy=True)
labels = labels.to(device)
tok = tokens_from_bins(bins_np, gene_emb, bin_emb, device)
logits, _, _ = model(tok)
loss = criterion(logits, labels)
total += float(loss.item()) * len(labels)
preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
labels_all.append(labels.detach().cpu().numpy())
return total / len(test_ds), accuracy_score(
np.concatenate(labels_all), np.concatenate(preds_all)
)
train_history = []
for epoch in range(1, n_epochs + 1):
tr_loss, tr_acc = _train_epoch()
te_loss, te_acc = _evaluate()
train_history.append(dict(epoch=epoch,
tr_loss=tr_loss, tr_acc=tr_acc,
te_loss=te_loss, te_acc=te_acc))
print(f" Epoch {epoch:02d} | "
f"train_loss={tr_loss:.4f} acc={tr_acc:.4f} | "
f"val_loss={te_loss:.4f} acc={te_acc:.4f}")
# ── 7. Attention extraction ─────────────────────────────────────────────
model.eval()
attn_sums, labels_out, pooled_out = [], [], []
with torch.no_grad():
for bins_np, labels in test_loader:
bins_np = np.array(bins_np, copy=True)
tok = tokens_from_bins(bins_np, gene_emb, bin_emb, device)
logits, enc_out, attn_list = model(tok)
aw = torch.stack(attn_list, dim=0).cpu().numpy() # (layers, B, heads, S, S)
aw_mean = aw.mean(axis=0).mean(axis=1) # (B, S, S)
received = aw_mean.sum(axis=1) # (B, S)
attn_sums.append(received)
labels_out.append(labels.numpy())
pooled_out.append(enc_out.mean(dim=1).cpu().numpy())
attn_sums = np.concatenate(attn_sums, axis=0)
labels_out = np.concatenate(labels_out, axis=0)
pooled_out = np.concatenate(pooled_out, axis=0)
# ── 8. Per-class averaged attention matrices ────────────────────────────
class_attn, class_counts = average_attention_by_class(
model, test_loader, gene_emb, bin_emb # always uses the correct embeddings
)
print(f"[pipeline] class_attn keys: {sorted(class_attn.keys())}")
assert max(class_attn.keys()) < n_classes, \
f"Label {max(class_attn.keys())} out of bounds for le with {n_classes} classes!"
return dict(
model = model,
le = le,
gene_emb = gene_emb,
bin_emb = bin_emb,
genes = genes,
class_attn = class_attn,
class_counts = class_counts,
attn_sums = attn_sums,
labels_test = labels_out,
pooled = pooled_out,
train_history = train_history,
)
In [43]:
# Healthy cells
#res1 = run_transformer_pipeline(X2, y2, n_epochs=6)
# Diabetes cells
res2 = run_transformer_pipeline(X2, y2, n_epochs=6)
# Access anything cleanly — no more le vs le2 confusion
beta_class = res2['le'].transform(["beta cell"])[0]
modules, sub_attn, top_genes, labels = attention_to_gene_modules(
res2['class_attn'][beta_class],
res2['genes'],
top_k=30,
n_modules=4
)
[pipeline] n_classes=9, classes=['PSC cell' 'acinar cell' 'alpha cell' 'beta cell' 'co-expression cell' 'delta cell' 'ductal cell' 'gamma cell' 'unclassified endocrine cell'] [pipeline] bin range: 0 – 6 (using 7 embedding slots) [pipeline] train=769, test=330 Epoch 01 | train_loss=1.7512 acc=0.3875 | val_loss=1.7062 acc=0.4030 Epoch 02 | train_loss=1.6894 acc=0.4148 | val_loss=1.6148 acc=0.5788 Epoch 03 | train_loss=1.5637 acc=0.5293 | val_loss=1.1452 acc=0.6212 Epoch 04 | train_loss=1.1089 acc=0.6203 | val_loss=0.9440 acc=0.6788 Epoch 05 | train_loss=0.9092 acc=0.6957 | val_loss=0.8799 acc=0.6727 Epoch 06 | train_loss=0.8112 acc=0.7321 | val_loss=0.8328 acc=0.7000 [pipeline] class_attn keys: [0, 1, 2, 3, 4, 5, 6, 7, 8]
In [44]:
top_genes
Out[44]:
array(['LCN2|NM_005564', 'SDC4|NM_002999',
'PMEPA1|NM_199171+NM_199170+NM_020182+NM_001255976+NM_199169',
'SERINC2|NM_001199038+NM_001199037+NM_178865+NM_018565+NM_001199039',
'LITAF|NM_004862+NM_001136473+NM_001136472+NR_024320',
'KRT19|NM_002276', 'GCG|NM_002054',
'ALDH1A3|NM_000693+NM_001293815', 'CD9|NM_001769',
'ISYNA1|NR_045573+NM_001170938+NM_016368+NM_001253389+NR_045574',
'STOM|NM_001270527+NM_004099+NM_001270526+NM_198194+NR_073037',
'DUOX2|NM_014080', 'RNF123|NM_022064', 'SOD3|NM_003102',
'DARS2|NM_018122',
'GPSM1|NM_015597+NM_001145638+NM_001145639+NM_001200003',
'SLC43A3|NM_199329+NM_014096+NM_001278201+NM_017611+NM_001278206',
'DNPH1|NM_199184+NM_006443',
'ERCC_1875:mix1_468.75:mix2|ERCC-00136',
'DEPDC1B|NM_001145208+NM_018369', 'ANPEP|NM_001150',
'CHCHD10|NM_213720', 'PPY|NM_002722',
'ERCC_937.5:mix1_937.5:mix2|ERCC-00009',
'C10orf53|NM_001042427+NM_182554', 'JUNB|NM_002229',
'SYNGR2|NM_004710', 'CHGB|NM_001819', 'SCAMP2|NM_005697',
'NOL4L|NM_080616+NM_001256798'], dtype='<U331')
In [45]:
plot_gene_module_network(sub_attn=sub_attn, top_genes=top_genes,
labels=labels, modules=modules,
title="Beta-cell gene module network (diabetes)")
In [46]:
# ── Pull results from diabetes pipeline ────────────────────────────────────
le2 = res2['le']
class_attn2 = res2['class_attn']
class_counts2= res2['class_counts']
genes2 = [g.split('|')[0] for g in res2['genes']] # strip transcript suffix
# ── Find beta cell class index ──────────────────────────────────────────────
print("Available classes:", le2.classes_) # confirm exact label spelling
beta_class = le2.transform(["beta cell"])[0]
print(f"Beta cell → class index {beta_class}, n={class_counts2[beta_class]} cells")
# ── Plot ────────────────────────────────────────────────────────────────────
plot_attention_heatmap(
class_attn2[beta_class],
genes2,
top_k=30,
title=f"Beta cell attention heatmap — diabetes (n={class_counts2[beta_class]})"
)
Available classes: ['PSC cell' 'acinar cell' 'alpha cell' 'beta cell' 'co-expression cell' 'delta cell' 'ductal cell' 'gamma cell' 'unclassified endocrine cell'] Beta cell → class index 3, n=30 cells
Extensions¶
- Try different bin counts.
- Compare attention-derived genes to marker genes or differential expression.
- Replace the encoder with Performer for larger gene sets.
In [ ]: