Genomic Data Analysis Transformer Encoder Example¶

A compact notebook showing a Transformer encoder workflow for genomic / single-cell data:

  • scRNA-seq data from healty and diabetic pancreas
  • expression binning
  • gene and bin embeddings
  • small Transformer encoder
  • classifier training
  • attention-based module discovery

Single-cell RNA-seq analysis of human pancreas from healthy individuals and type 2 diabetes patients¶

Segerstolpe, Åsa, et al. "Single-cell transcriptome profiling of human pancreatic islets in health and type 2 diabetes." Cell metabolism 24.4 (2016): 593-607.

Download the scRNA-seq data from¶

https://www.ebi.ac.uk/biostudies/arrayexpress/studies/E-MTAB-5061

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import random
import matplotlib.pyplot as plt

np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
Device: cpu
In [2]:
import os

current_directory = os.getcwd()
print(f"Current working directory: {current_directory}")
Current working directory: /Users/hoyen/Desktop/STAT718-2026/PythonCode/Transformer
In [3]:
from __future__ import annotations

from pathlib import Path
import pandas as pd


def load_expression_matrix(matrix_path):
    matrix_path = Path(matrix_path)

    # Read the first line to get sample IDs
    with matrix_path.open("r", encoding="utf-8") as fh:
        first_line = fh.readline().rstrip("\n")

    parts = first_line.split("\t")
    if not parts or not parts[0].startswith("#samples"):
        raise ValueError(f"Unexpected first line: {parts[:5]}")

    sample_ids = parts[1:]
    n = len(sample_ids)

    # Read the table, skipping the first comment line
    df = pd.read_csv(
        matrix_path,
        sep="\t",
        comment="#",
        header=None,
        engine="python",
    )

    print("Raw shape:", df.shape)

    # Expected: 2 ID columns + n RPKM + n counts
    expected_cols = 2 + 2 * n
    if df.shape[1] != expected_cols:
        raise ValueError(
            f"Expected {expected_cols} columns total "
            f"(2 IDs + {n} RPKM + {n} counts), but got {df.shape[1]}."
        )

    gene_symbol = df.iloc[:, 0].astype(str)
    transcript_id = df.iloc[:, 1].astype(str)

    rpkm = df.iloc[:, 2 : 2 + n].copy()
    counts = df.iloc[:, 2 + n : 2 + 2 * n].copy()

    rpkm.columns = sample_ids
    counts.columns = sample_ids

    gene_symbol = df.iloc[:, 0].astype(str)
    transcript_id = df.iloc[:, 1].astype(str)

    feature_id = gene_symbol + "|" + transcript_id 
    # feature_id = gene_symbol 
    rpkm.index = feature_id
    counts.index = feature_id

    return rpkm, counts

def load_sdrf_metadata(sdrf_path):
    sdrf_path = Path(sdrf_path)
    meta = pd.read_csv(sdrf_path, sep="\t", dtype=str, low_memory=False)

    def find_col(*keywords):
        for c in meta.columns:
            cl = c.lower()
            if all(k in cl for k in keywords):
                return c
        return None

    # ---------- Find sample column ----------
    sample_col = None
    for candidate in ["source name", "sample name", "assay name"]:
        if candidate in [c.lower().strip() for c in meta.columns]:
            sample_col = next(
                c for c in meta.columns if c.lower().strip() == candidate
            )
            break

    # ---------- Find cell type ----------
    celltype_col = None
    for c in meta.columns:
        if "inferred cell type" in c.lower():
            celltype_col = c
            break

    # ---------- Find disease ----------
    disease_col = None
    for c in meta.columns:
        if "disease" in c.lower():
            disease_col = c
            break

    # ---------- Find quality ----------
    quality_col = None
    for c in meta.columns:
        if "submitted single cell quality" in c.lower():
            quality_col = c
            break

    # ---------- Build cell_meta ----------
    cell_meta = pd.DataFrame(index=meta[sample_col].astype(str))
    cell_meta.index.name = "sample_id"

    cell_meta["cell_type"] = (
        meta[celltype_col].values if celltype_col is not None else None
    )
    cell_meta["disease"] = (
        meta[disease_col].values if disease_col is not None else None
    )
    cell_meta["quality"] = (
        meta[quality_col].values if quality_col is not None else None
    )

    return cell_meta


base_path = Path("/Users/hoyen/Desktop/STAT718-2026/PythonCode/Transformer/E-MTAB-5061")
matrix_file = base_path / "count.txt"
sdrf_file = base_path / "E-MTAB-5061.sdrf.txt"

rpkm, counts = load_expression_matrix(matrix_file)
cell_meta = load_sdrf_metadata(sdrf_file)
Raw shape: (26271, 7030)

Remove low quality cells¶

In [4]:
ok_mask = cell_meta["quality"].str.strip().str.lower().eq("ok")
cell_meta_ok = cell_meta.loc[ok_mask].copy()
rpkm_ok = rpkm.loc[:, rpkm.columns.intersection(cell_meta_ok.index)]
cell_meta_ok.to_csv(base_path / "E-MTAB-5061_cell_metadata.tsv", sep="\t", index=False)
In [5]:
cell_meta_ok.shape
Out[5]:
(2209, 3)
In [6]:
rpkm_ok.shape
Out[6]:
(26271, 2209)
In [7]:
cell_meta_ok.head
Out[7]:
<bound method NDFrame.head of                     cell_type                    disease quality
sample_id                                                       
AZ_A10             delta cell                     normal      OK
AZ_A11             alpha cell                     normal      OK
AZ_A12             delta cell                     normal      OK
AZ_A2              gamma cell                     normal      OK
AZ_A5             ductal cell                     normal      OK
...                       ...                        ...     ...
HP1526901T2D_P22   delta cell  type II diabetes mellitus      OK
HP1526901T2D_P23   gamma cell  type II diabetes mellitus      OK
HP1526901T2D_P4     beta cell  type II diabetes mellitus      OK
HP1526901T2D_P7     beta cell  type II diabetes mellitus      OK
HP1526901T2D_P9     beta cell  type II diabetes mellitus      OK

[2209 rows x 3 columns]>
In [8]:
# X: cells x genes
X = rpkm_ok.T
# y: cell types aligned to X
y = cell_meta_ok.loc[X.index, "cell_type"]
In [9]:
print(X.shape)   # (num_cells, num_genes)
print(y.shape)   # (num_cells,)
(2209, 26271)
(2209,)
In [10]:
# check alignment
print((X.index == y.index).all())   # should be True
True
In [11]:
### select 2000 highly variable genes 
mean = X.mean(axis=0)
var = X.var(axis=0)
dispersion = var / (mean + 1e-6)
top_genes = dispersion.sort_values(ascending=False).head(2000).index
X_hvg=X[top_genes]
X_hvg.shape
Out[11]:
(2209, 2001)
In [12]:
duplicates = X_hvg.columns[X_hvg.columns.duplicated()]

print("Number of duplicate columns:", len(duplicates))
print("Duplicate column names:", duplicates.tolist())
Number of duplicate columns: 1
Duplicate column names: ['CFC1B|NM_001079530']
In [13]:
import pandas as pd
import numpy as np

# variance for every column position, even if names repeat
col_var = X_hvg.var(axis=0)

# keep the position with the highest variance for each column name
best_pos = {}
for i, col in enumerate(X_hvg.columns):
    v = col_var.iloc[i]
    if col not in best_pos or v > best_pos[col][0]:
        best_pos[col] = (v, i)

# preserve original order of the kept columns
keep_idx = sorted(pos for _, pos in best_pos.values())

X_hvg_nodup = X_hvg.iloc[:, keep_idx]

print("Before:", X_hvg.shape)
print("After:", X_hvg_nodup.shape)
print("Any duplicated columns left?", X_hvg_nodup.columns.duplicated().any())
Before: (2209, 2001)
After: (2209, 2000)
Any duplicated columns left? False
In [14]:
disease = cell_meta_ok.loc[X_hvg_nodup.index, "disease"].str.strip().str.lower()
disease.head
Out[14]:
<bound method NDFrame.head of HP1502401_H13                          normal
HP1502401_J14                          normal
HP1502401_B14                          normal
HP1502401_A14                          normal
HP1502401_C14                          normal
                              ...            
HP1526901T2D_F7     type ii diabetes mellitus
HP1525301T2D_K3     type ii diabetes mellitus
HP1525301T2D_J10    type ii diabetes mellitus
HP1526901T2D_N8     type ii diabetes mellitus
HP1526901T2D_A8     type ii diabetes mellitus
Name: disease, Length: 2209, dtype: object>
In [15]:
normal_mask = disease.eq("normal")
diabetes_mask = disease.str.contains("diabetes", na=False)

X1 = X_hvg_nodup.loc[normal_mask]
y1 = y.loc[normal_mask]

X2 = X_hvg_nodup.loc[diabetes_mask]
y2 = y.loc[diabetes_mask]
In [16]:
print(X1.shape)
print(y1.shape)
(1097, 2000)
(1097,)
In [17]:
print(X2.shape)
print(y2.shape)
(1112, 2000)
(1112,)
In [18]:
print("Normal (y1):")
print(y1.value_counts())

print("\nDiabetes (y2):")
print(y2.value_counts())
Normal (y1):
cell_type
alpha cell                     443
beta cell                      171
ductal cell                    135
acinar cell                    112
gamma cell                      75
delta cell                      59
unclassified endocrine cell     29
co-expression cell              26
PSC cell                        23
endothelial cell                13
epsilon cell                     5
mast cell                        4
MHC class II cell                1
unclassified cell                1
Name: count, dtype: int64

Diabetes (y2):
cell_type
alpha cell                     443
ductal cell                    251
gamma cell                     122
beta cell                       99
acinar cell                     73
delta cell                      55
PSC cell                        31
co-expression cell              13
unclassified endocrine cell     12
MHC class II cell                4
endothelial cell                 3
mast cell                        3
epsilon cell                     2
unclassified cell                1
Name: count, dtype: int64
In [19]:
### remove cell with less than 10 cells in healthy cells
counts = y1.value_counts()
valid_classes = counts[counts >= 10].index
mask1 = y1.isin(valid_classes)
y1 = y1[mask1]
X1 = X1[mask1]
print(y1.value_counts())
cell_type
alpha cell                     443
beta cell                      171
ductal cell                    135
acinar cell                    112
gamma cell                      75
delta cell                      59
unclassified endocrine cell     29
co-expression cell              26
PSC cell                        23
endothelial cell                13
Name: count, dtype: int64
In [20]:
### remove cell with less than 10 cells in healthy cells
counts = y2.value_counts()
valid_classes = counts[counts >= 10].index
mask2 = y2.isin(valid_classes)
y2 = y2[mask2]
X2 = X2[mask2]
print(y2.value_counts())
cell_type
alpha cell                     443
ductal cell                    251
gamma cell                     122
beta cell                       99
acinar cell                     73
delta cell                      55
PSC cell                        31
co-expression cell              13
unclassified endocrine cell     12
Name: count, dtype: int64
In [21]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
# fit on combined labels to ensure consistent encoding across both datasets
le.fit(np.concatenate([y1, y2]))

y1_encoded = le.transform(y1)
y2_encoded = le.transform(y2)

print(le.classes_)
['PSC cell' 'acinar cell' 'alpha cell' 'beta cell' 'co-expression cell'
 'delta cell' 'ductal cell' 'endothelial cell' 'gamma cell'
 'unclassified endocrine cell']
In [22]:
y1_encoded
Out[22]:
array([8, 2, 3, ..., 8, 8, 2])
In [23]:
y2_encoded
Out[23]:
array([8, 2, 2, ..., 5, 2, 2])

Binning and embeddings from healthy cells¶

In [24]:
def percentile_bin_global(expr_mat, n_bins=7):
    expr_mat = expr_mat.to_numpy()
    flat = expr_mat.flatten()
    q = np.nanpercentile(flat, np.linspace(0, 100, n_bins + 1))
    binned = np.digitize(expr_mat, q[1:-1], right=True).astype(np.int64)
    return np.clip(binned, 0, n_bins - 1)  # ← clamp to valid range [0, 6]
    
n_bins = 7
binned1 = percentile_bin_global(X1, n_bins=n_bins)
print(binned1.min(), binned1.max())  
0 6
In [26]:
n_genes = X1.shape[1]
print(f"n_genes: {n_genes}, X1 shape: {X1.shape}, binned1 shape: {binned1.shape}")
n_genes: 2000, X1 shape: (1086, 2000), binned1 shape: (1086, 2000)
In [27]:
n_classes = len(np.unique(y1))
gene_emb_dim = 128
gene_emb1 = nn.Embedding(n_genes, gene_emb_dim).to(device)
bin_emb1 = nn.Embedding(n_bins, gene_emb_dim).to(device)
genes = X1.columns.tolist()

def tokens_from_bins(binned_np, gene_emb, bin_emb, device):
    b = torch.from_numpy(binned_np).long().to(device)
    batch, seq_len = b.shape
    gene_idx = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch, 1)
    return gene_emb(gene_idx) + bin_emb(b)

# call:
print(binned1.shape, tokens_from_bins(binned1[:2], gene_emb1, bin_emb1, device).shape)
(1086, 2000) torch.Size([2, 2000, 128])

Transformer encoder with attention weights¶

In [28]:
class TransformerLayerWithAttn(nn.Module):
    def __init__(self, d_model, nhead, dim_ff=256, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim_ff, d_model))
        self.norm2 = nn.LayerNorm(d_model)
    def forward(self, x):
        attn_out, attn_w = self.mha(x, x, x, need_weights=True)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x, attn_w

class TinyTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([TransformerLayerWithAttn(d_model, nhead) for _ in range(n_layers)])
    def forward(self, x):
        attns = []
        for layer in self.layers:
            x, aw = layer(x)
            attns.append(aw)
        return x, attns

class ClassifierHead(nn.Module):
    def __init__(self, d_model, n_classes, hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden, n_classes))
    def forward(self, enc_out):
        return self.mlp(enc_out.mean(dim=1))

class ScBERTToy(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier
    def forward(self, tok_emb):
        enc_out, attns = self.encoder(tok_emb)
        logits = self.classifier(enc_out)
        return logits, enc_out, attns

encoder = TinyTransformerEncoder(gene_emb_dim, nhead=4, n_layers=2).to(device)
classifier = ClassifierHead(gene_emb_dim, n_classes).to(device)
model = ScBERTToy(encoder, classifier).to(device)
optimizer = torch.optim.Adam(list(model.parameters()) + list(gene_emb1.parameters()) + list(bin_emb1.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()
print(model)
ScBERTToy(
  (encoder): TinyTransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerLayerWithAttn(
        (mha): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (ff): Sequential(
          (0): Linear(in_features=128, out_features=256, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=256, out_features=128, bias=True)
        )
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (classifier): ClassifierHead(
    (mlp): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=256, out_features=10, bias=True)
    )
  )
)

Data loaders¶

In [29]:
class BinnedTokenDataset(Dataset):
    def __init__(self, binned_matrix, labels):
        self.binned = binned_matrix.astype(np.int64)
        self.labels = labels.astype(np.int64)
    def __len__(self):
        return self.binned.shape[0]
    def __getitem__(self, idx):
        return self.binned[idx], int(self.labels[idx])

train_idx, test_idx = train_test_split(np.arange(len(y1_encoded)), test_size=0.3, stratify=y1_encoded, random_state=42)
train_ds = BinnedTokenDataset(binned1[train_idx], y1_encoded[train_idx])
test_ds = BinnedTokenDataset(binned1[test_idx], y1_encoded[test_idx])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)
print(len(train_ds), len(test_ds))
760 326

Training¶

In [30]:
def train_epoch():
    model.train()
    total = 0.0
    preds_all = []
    labels_all = []
    for bins_np, labels in train_loader:
        bins_np = np.array(bins_np)
        labels = labels.to(device)
        tok = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device)
        optimizer.zero_grad()
        logits, _, _ = model(tok)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total += float(loss.item()) * len(labels)
        preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
        labels_all.append(labels.detach().cpu().numpy())
    preds_all = np.concatenate(preds_all)
    labels_all = np.concatenate(labels_all)
    return total / len(train_ds), accuracy_score(labels_all, preds_all)

@torch.no_grad()
def evaluate():
    model.eval()
    total = 0.0
    preds_all = []
    labels_all = []
    for bins_np, labels in test_loader:
        bins_np = np.array(bins_np)
        labels = labels.to(device)
        tok = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device)
        logits, _, _ = model(tok)
        loss = criterion(logits, labels)
        total += float(loss.item()) * len(labels)
        preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
        labels_all.append(labels.detach().cpu().numpy())
    preds_all = np.concatenate(preds_all)
    labels_all = np.concatenate(labels_all)
    return total / len(test_ds), accuracy_score(labels_all, preds_all), labels_all, preds_all

for epoch in range(1, 7):
    tr_loss, tr_acc = train_epoch()
    te_loss, te_acc, _, _ = evaluate()
    print(f'Epoch {epoch:02d} | train_loss={tr_loss:.4f} acc={tr_acc:.4f} | val_loss={te_loss:.4f} acc={te_acc:.4f}')
Epoch 01 | train_loss=1.8642 acc=0.4000 | val_loss=1.7437 acc=0.4110
Epoch 02 | train_loss=1.6885 acc=0.4355 | val_loss=1.5680 acc=0.4693
Epoch 03 | train_loss=1.3955 acc=0.5184 | val_loss=1.3400 acc=0.5153
Epoch 04 | train_loss=1.2928 acc=0.5461 | val_loss=1.1991 acc=0.5675
Epoch 05 | train_loss=1.1484 acc=0.6026 | val_loss=1.1677 acc=0.6104
Epoch 06 | train_loss=1.1118 acc=0.6079 | val_loss=1.0369 acc=0.6411

Attention-based module discovery¶

In [31]:
model.eval()
attn_sums = []
labels_test = []
pooled = []

with torch.no_grad():
    for bins_np, labels in test_loader:
        bins_np = np.array(bins_np)
        tok = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device)
        logits, enc_out, attn_list = model(tok)
        aw = torch.stack(attn_list, dim=0).cpu().numpy()  # (layers, batch, heads, S, S)
        aw_mean = aw.mean(axis=0).mean(axis=1)            # (batch, S, S)
        received = aw_mean.sum(axis=1)                    # (batch, S)
        attn_sums.append(received)
        labels_test.append(labels.numpy())
        pooled.append(enc_out.mean(dim=1).cpu().numpy())

attn_sums = np.concatenate(attn_sums, axis=0)
labels_test = np.concatenate(labels_test, axis=0)
pooled = np.concatenate(pooled, axis=0)

def top_genes_for_class(attn_sums, labels, gene_names, class_id, topk=8):
    gene_names = list(gene_names)  # ensure it's a list
    idxs = np.where(labels == class_id)[0]
    avg_attn = attn_sums[idxs].mean(axis=0)
    avg_attn = np.array(avg_attn).flatten()  # ensure 1D array
    top_idx = np.argsort(-avg_attn)[:topk]
    return [(gene_names[i], float(avg_attn[i])) for i in top_idx]

for c in range(n_classes):
    class_name = le.inverse_transform([c])[0]
    print(f'Class {c} ({class_name}):', top_genes_for_class(attn_sums, labels_test, genes, c))
Class 0 (PSC cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 1 (acinar cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 2 (alpha cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 3 (beta cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 4 (co-expression cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 5 (delta cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 6 (ductal cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 7 (endothelial cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 0.9999999403953552)]
Class 8 (gamma cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0)]
Class 9 (unclassified endocrine cell): [('INS|NM_001185097+NM_000207+NM_001185098+NM_001291897', 1.0000001192092896)]

2D visualization of pooled embeddings¶

In [32]:
try:
    import umap
    z = umap.UMAP(n_components=2, random_state=42).fit_transform(pooled)
except Exception:
    from sklearn.decomposition import PCA
    z = PCA(n_components=2).fit_transform(pooled)

plt.figure(figsize=(6, 5))
plt.scatter(z[:, 0], z[:, 1], c=labels_test, cmap='tab10', s=8, alpha=0.7)
plt.title('Pooled encoder embeddings (test set)')
plt.colorbar()
plt.show()
/opt/anaconda3/envs/genomics-cnn/lib/python3.9/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
No description has been provided for this image
In [33]:
import numpy as np
import matplotlib.pyplot as plt

def batch_attention_matrix(model, bins_np):
    model.eval()
    with torch.no_grad():
        tok_emb = tokens_from_bins(bins_np, gene_emb1, bin_emb1, device)  # ← fix
        logits, enc_out, attn_list = model(tok_emb)
        attn_stack = torch.stack(attn_list, dim=0).detach().cpu().numpy()
        if attn_stack.ndim == 5:
            attn_mean = attn_stack.mean(axis=0).mean(axis=1)
        elif attn_stack.ndim == 4:
            attn_mean = attn_stack.mean(axis=0)
        else:
            raise ValueError(f"Unexpected attention shape: {attn_stack.shape}")
        return attn_mean
        

def average_attention_by_class(model, loader):
    """
    Average attention matrices by class label.
    Returns:
        class_attn: dict[class_id] = averaged attention matrix (S, S)
        class_counts: dict[class_id] = number of cells used
    """
    class_sum = {}
    class_count = {}

    model.eval()
    with torch.no_grad():
        for bins_np, labels in loader:
            bins_np = np.array(bins_np)
            labels_np = np.array(labels)

            attn_mean = batch_attention_matrix(model, bins_np)  # (batch, S, S)

            for i, c in enumerate(labels_np):
                if c not in class_sum:
                    class_sum[c] = attn_mean[i].copy()
                    class_count[c] = 1
                else:
                    class_sum[c] += attn_mean[i]
                    class_count[c] += 1

    class_attn = {c: class_sum[c] / class_count[c] for c in class_sum}
    return class_attn, class_count


def plot_attention_heatmap(attn_mat, gene_names, top_k=30, title="Attention heatmap"):
    """
    Plot a readable gene-gene attention heatmap using the top_k genes
    with the highest total incoming attention.
    """
    gene_scores = attn_mat.sum(axis=0)
    top_idx = np.argsort(-gene_scores)[:top_k]

    sub_mat = attn_mat[np.ix_(top_idx, top_idx)]
    sub_genes = [gene_names[i] for i in top_idx]

    plt.figure(figsize=(10, 8))
    im = plt.imshow(sub_mat, aspect="auto", interpolation="nearest")
    plt.colorbar(im, label="Attention weight")
    plt.xticks(range(top_k), sub_genes, rotation=90)
    plt.yticks(range(top_k), sub_genes)
    plt.title(title)
    plt.tight_layout()
    plt.show()
In [34]:
genes = [g.split('|')[0] for g in genes]
print(genes[:5])  # ['INS', 'SST', 'PPY', 'CLPS', 'GHRL']
['INS', 'SST', 'PPY', 'CLPS', 'GHRL']
In [35]:
##### print attention heatmap for all cell types

class_attn, class_counts = average_attention_by_class(model, test_loader)
for c in sorted(class_attn.keys()):
    name = le.inverse_transform([c])[0]
    plot_attention_heatmap(
        class_attn[c],
        genes,
        top_k=30,
        title=f"{name} attention heatmap (n={class_counts[c]})"
    )
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [36]:
target_name = "beta"  # adjust if your label is "Beta cells", etc.

for c in sorted(class_attn.keys()):
    name = le.inverse_transform([c])[0]
    
    if target_name.lower() in name.lower():
        plot_attention_heatmap(
            class_attn[c],
            genes,
            top_k=30,
            title=f"{name} attention heatmap (n={class_counts[c]})"
        )
No description has been provided for this image
In [37]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

def attention_to_gene_modules(attn, gene_names, top_k=30, n_modules=4):
    """
    Convert an attention heatmap into gene modules by clustering the top-k genes.

    Parameters
    ----------
    attn : array-like
        1D gene importance vector or 2D gene x gene attention matrix.
    gene_names : list-like
        Gene names in the same order as attn.
    top_k : int
        Number of top genes to keep before clustering.
    n_modules : int
        Number of modules to cut the dendrogram into.

    Returns
    -------
    modules : dict
        {module_id: [genes...]}
    sub_attn : np.ndarray
        Top-k attention submatrix.
    top_genes : np.ndarray
        Selected top-k genes.
    labels : np.ndarray
        Cluster labels for top-k genes.
    """
    attn = np.asarray(attn)
    gene_names = np.asarray(gene_names)

    # If 1D, just rank genes by attention score
    if attn.ndim == 1:
        top_idx = np.argsort(attn)[::-1][:top_k]
        top_genes = gene_names[top_idx]
        modules = {1: list(top_genes)}
        return modules, None, top_genes, np.ones(len(top_genes), dtype=int)

    # If 2D, use a symmetrized attention matrix
    sym_attn = (attn + attn.T) / 2.0

    # Score genes by average attention, then keep top_k
    gene_scores = sym_attn.mean(axis=0)
    top_idx = np.argsort(gene_scores)[::-1][:top_k]
    top_genes = gene_names[top_idx]
    sub_attn = sym_attn[np.ix_(top_idx, top_idx)]

    # Convert similarity to distance for clustering
    sub_min = sub_attn.min()
    sub_max = sub_attn.max()
    sub_norm = (sub_attn - sub_min) / (sub_max - sub_min + 1e-8)
    dist = 1.0 - sub_norm
    np.fill_diagonal(dist, 0.0)

    # Hierarchical clustering
    condensed = squareform(dist, checks=False)
    Z = linkage(condensed, method="average")
    labels = fcluster(Z, t=n_modules, criterion="maxclust")

    # Build modules
    modules = {}
    for m in sorted(np.unique(labels)):
        modules[f"Module {m}"] = list(top_genes[labels == m])

    return modules, sub_attn, top_genes, labels
In [38]:
beta_name = "beta cell"   # change to your exact label if needed
beta_class = le.transform([beta_name])[0]

modules, sub_attn, top_genes, labels = attention_to_gene_modules(
    class_attn[beta_class],
    genes,
    top_k=30,
    n_modules=3
)

for mod_name, mod_genes in modules.items():
    print(mod_name, ":", mod_genes)
Module 1 : ['CPE', 'PRSS3P2', 'SLC44A3', 'CLK1', 'SOX4', 'NDUFA13', 'DYNLRB1', 'TIMP1', 'CTSD', 'RPL38', 'GC', 'ATP5J2', 'NEU1', 'SQSTM1', 'G6PC2', 'SURF4', 'CTRC', 'SRP14', 'EIF4EBP1', 'PRDX4', 'SLIRP', 'ATP6V0E1', 'SPINT2', 'ERCC_3750:mix1_7500:mix2', 'LRRC75A-AS1', 'BEX1', 'CYSTM1', 'CD63']
Module 2 : ['PAM']
Module 3 : ['GHITM']
In [39]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

def plot_gene_module_network(sub_attn, top_genes, labels, modules=None, edge_quantile=0.85, title="Gene module network"):
    """
    sub_attn : 2D attention matrix for top genes
    top_genes : list/array of gene names in the same order as sub_attn
    labels : cluster labels for each gene in top_genes
    modules : optional dict like {"Module 1": [...], ...}
    edge_quantile : threshold for drawing edges
    """
    top_genes = list(top_genes)
    sub_attn = np.asarray(sub_attn)

    G = nx.Graph()

    # module assignment
    gene_to_module = {}
    if modules is not None:
        for mod_name, genes_in_mod in modules.items():
            for g in genes_in_mod:
                gene_to_module[g] = mod_name
    else:
        for g, lab in zip(top_genes, labels):
            gene_to_module[g] = f"Module {lab}"

    # add nodes
    for g in top_genes:
        G.add_node(g, module=gene_to_module.get(g, "Module 0"))

    # add edges for strong attention values
    vals = sub_attn[np.triu_indices_from(sub_attn, k=1)]
    thresh = np.quantile(vals, edge_quantile)

    for i in range(len(top_genes)):
        for j in range(i + 1, len(top_genes)):
            w = sub_attn[i, j]
            if w >= thresh:
                G.add_edge(top_genes[i], top_genes[j], weight=float(w))

    # layout
    pos = nx.spring_layout(G, seed=42, k=1.0)

    # colors by module
    module_names = sorted(set(nx.get_node_attributes(G, "module").values()))
    cmap = plt.cm.tab10
    module_color_map = {m: cmap(i % 10) for i, m in enumerate(module_names)}
    node_colors = [module_color_map[G.nodes[n]["module"]] for n in G.nodes()]

    # edge widths from weights
    edge_weights = [G[u][v]["weight"] for u, v in G.edges()]
    if edge_weights:
        wmin, wmax = min(edge_weights), max(edge_weights)
        edge_widths = [1 + 4 * (w - wmin) / (wmax - wmin + 1e-8) for w in edge_weights]
    else:
        edge_widths = []

    plt.figure(figsize=(12, 10))
    nx.draw_networkx_edges(G, pos, alpha=0.35, width=edge_widths)
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=900, linewidths=1, edgecolors="black")
    nx.draw_networkx_labels(G, pos, font_size=9)

    handles = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=module_color_map[m], markeredgecolor='black',
                   markersize=10, label=m)
        for m in module_names
    ]
    plt.legend(handles=handles, title="Modules", loc="best")
    plt.title(title)
    plt.axis("off")
    plt.tight_layout()
    plt.show()
In [40]:
beta_name = "beta cell"
beta_class = le.transform([beta_name])[0]

modules, sub_attn, top_genes, labels = attention_to_gene_modules(
    class_attn[beta_class],
    genes,
    top_k=30,
    n_modules=3
)

plot_gene_module_network(
    sub_attn=sub_attn,
    top_genes=top_genes,
    labels=labels,
    modules=modules,
    edge_quantile=0.85,
    title="Beta-cell gene module network"
)
No description has been provided for this image

Getting beta cell gene modules from diabetes cells¶

In [41]:
def average_attention_by_class(model, loader, gene_emb, bin_emb):  # updated signature
    class_sum = {}
    class_count = {}
    model.eval()
    with torch.no_grad():
        for bins_np, labels in loader:
            bins_np   = np.array(bins_np, copy=True)
            labels_np = np.array(labels)
            attn_mean = batch_attention_matrix(model, bins_np, gene_emb, bin_emb)  # ✅
            for i, c in enumerate(labels_np):
                if c not in class_sum:
                    class_sum[c]  = attn_mean[i].copy()
                    class_count[c] = 1
                else:
                    class_sum[c]  += attn_mean[i]
                    class_count[c] += 1
    class_attn = {c: class_sum[c] / class_count[c] for c in class_sum}
    return class_attn, class_count


def batch_attention_matrix(model, bins_np, gene_emb, bin_emb):  # updated signature
    model.eval()
    with torch.no_grad():
        tok_emb    = tokens_from_bins(bins_np, gene_emb, bin_emb, device)  # ✅
        logits, enc_out, attn_list = model(tok_emb)
        attn_stack = torch.stack(attn_list, dim=0).detach().cpu().numpy()
        if attn_stack.ndim == 5:
            attn_mean = attn_stack.mean(axis=0).mean(axis=1)
        elif attn_stack.ndim == 4:
            attn_mean = attn_stack.mean(axis=0)
        else:
            raise ValueError(f"Unexpected attention shape: {attn_stack.shape}")
        return attn_mean
In [42]:
def run_transformer_pipeline(
    X, y,
    n_bins=7,
    gene_emb_dim=128,
    nhead=4,
    n_layers=2,
    lr=1e-3,
    n_epochs=6,
    batch_size_train=32,
    batch_size_test=64,
    test_size=0.3,
    random_state=42,
    device=device
):
    """
    Full pipeline: binning → embeddings → model → training → attention extraction.

    Returns
    -------
    results : dict with keys:
        'model'         : trained ScBERTToy
        'le'            : fitted LabelEncoder (0-indexed, no gaps)
        'gene_emb'      : nn.Embedding for genes
        'bin_emb'       : nn.Embedding for bins
        'genes'         : list of gene names
        'class_attn'    : dict[class_id] → averaged attention matrix (S, S)
        'class_counts'  : dict[class_id] → number of cells
        'attn_sums'     : np.ndarray (n_test_cells, n_genes) — per-cell attention
        'labels_test'   : np.ndarray of test labels (0-indexed)
        'pooled'        : np.ndarray (n_test_cells, gene_emb_dim) — encoder embeddings
        'train_history' : list of dicts per epoch
    """

    # ── 1. Label encoding ───────────────────────────────────────────────────
    le = LabelEncoder()
    y_encoded = le.fit_transform(y)
    n_classes = len(le.classes_)
    print(f"[pipeline] n_classes={n_classes}, classes={le.classes_}")

    # ── 2. Binning ──────────────────────────────────────────────────────────
    binned = percentile_bin_global(X, n_bins=n_bins)
    actual_max_bin = int(binned.max())
    n_bin_embeddings = actual_max_bin + 1          # safe embedding size
    print(f"[pipeline] bin range: {binned.min()} – {actual_max_bin}  "
          f"(using {n_bin_embeddings} embedding slots)")

    # ── 3. Embeddings ───────────────────────────────────────────────────────
    genes = X.columns.tolist()
    n_genes = len(genes)
    gene_emb = nn.Embedding(n_genes, gene_emb_dim).to(device)
    bin_emb  = nn.Embedding(n_bin_embeddings, gene_emb_dim).to(device)

    # ── 4. DataLoaders ──────────────────────────────────────────────────────
    train_idx, test_idx = train_test_split(
        np.arange(len(y_encoded)),
        test_size=test_size,
        stratify=y_encoded,
        random_state=random_state
    )
    train_ds = BinnedTokenDataset(binned[train_idx], y_encoded[train_idx])
    test_ds  = BinnedTokenDataset(binned[test_idx],  y_encoded[test_idx])
    train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size_test,  shuffle=False)
    print(f"[pipeline] train={len(train_ds)}, test={len(test_ds)}")

    # ── 5. Model ────────────────────────────────────────────────────────────
    encoder    = TinyTransformerEncoder(gene_emb_dim, nhead=nhead, n_layers=n_layers).to(device)
    classifier = ClassifierHead(gene_emb_dim, n_classes).to(device)
    model      = ScBERTToy(encoder, classifier).to(device)
    optimizer  = torch.optim.Adam(
        list(model.parameters()) + list(gene_emb.parameters()) + list(bin_emb.parameters()),
        lr=lr
    )
    criterion = nn.CrossEntropyLoss()

    # ── 6. Training ─────────────────────────────────────────────────────────
    def _train_epoch():
        model.train()
        total, preds_all, labels_all = 0.0, [], []
        for bins_np, labels in train_loader:
            bins_np = np.array(bins_np, copy=True)
            labels  = labels.to(device)
            tok     = tokens_from_bins(bins_np, gene_emb, bin_emb, device)
            optimizer.zero_grad()
            logits, _, _ = model(tok)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total += float(loss.item()) * len(labels)
            preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
            labels_all.append(labels.detach().cpu().numpy())
        return total / len(train_ds), accuracy_score(
            np.concatenate(labels_all), np.concatenate(preds_all)
        )

    @torch.no_grad()
    def _evaluate():
        model.eval()
        total, preds_all, labels_all = 0.0, [], []
        for bins_np, labels in test_loader:
            bins_np = np.array(bins_np, copy=True)
            labels  = labels.to(device)
            tok     = tokens_from_bins(bins_np, gene_emb, bin_emb, device)
            logits, _, _ = model(tok)
            loss = criterion(logits, labels)
            total += float(loss.item()) * len(labels)
            preds_all.append(logits.argmax(dim=1).detach().cpu().numpy())
            labels_all.append(labels.detach().cpu().numpy())
        return total / len(test_ds), accuracy_score(
            np.concatenate(labels_all), np.concatenate(preds_all)
        )

    train_history = []
    for epoch in range(1, n_epochs + 1):
        tr_loss, tr_acc = _train_epoch()
        te_loss, te_acc = _evaluate()
        train_history.append(dict(epoch=epoch,
                                  tr_loss=tr_loss, tr_acc=tr_acc,
                                  te_loss=te_loss, te_acc=te_acc))
        print(f"  Epoch {epoch:02d} | "
              f"train_loss={tr_loss:.4f} acc={tr_acc:.4f} | "
              f"val_loss={te_loss:.4f} acc={te_acc:.4f}")

    # ── 7. Attention extraction ─────────────────────────────────────────────
    model.eval()
    attn_sums, labels_out, pooled_out = [], [], []

    with torch.no_grad():
        for bins_np, labels in test_loader:
            bins_np = np.array(bins_np, copy=True)
            tok = tokens_from_bins(bins_np, gene_emb, bin_emb, device)
            logits, enc_out, attn_list = model(tok)
            aw       = torch.stack(attn_list, dim=0).cpu().numpy()  # (layers, B, heads, S, S)
            aw_mean  = aw.mean(axis=0).mean(axis=1)                 # (B, S, S)
            received = aw_mean.sum(axis=1)                          # (B, S)
            attn_sums.append(received)
            labels_out.append(labels.numpy())
            pooled_out.append(enc_out.mean(dim=1).cpu().numpy())

    attn_sums  = np.concatenate(attn_sums,  axis=0)
    labels_out = np.concatenate(labels_out, axis=0)
    pooled_out = np.concatenate(pooled_out, axis=0)

    # ── 8. Per-class averaged attention matrices ────────────────────────────
    class_attn, class_counts = average_attention_by_class(
        model, test_loader, gene_emb, bin_emb   # always uses the correct embeddings
    )
    print(f"[pipeline] class_attn keys: {sorted(class_attn.keys())}")
    assert max(class_attn.keys()) < n_classes, \
        f"Label {max(class_attn.keys())} out of bounds for le with {n_classes} classes!"

    return dict(
        model         = model,
        le            = le,
        gene_emb      = gene_emb,
        bin_emb       = bin_emb,
        genes         = genes,
        class_attn    = class_attn,
        class_counts  = class_counts,
        attn_sums     = attn_sums,
        labels_test   = labels_out,
        pooled        = pooled_out,
        train_history = train_history,
    )
In [43]:
# Healthy cells
#res1 = run_transformer_pipeline(X2, y2, n_epochs=6)

# Diabetes cells
res2 = run_transformer_pipeline(X2, y2, n_epochs=6)

# Access anything cleanly — no more le vs le2 confusion
beta_class = res2['le'].transform(["beta cell"])[0]
modules, sub_attn, top_genes, labels = attention_to_gene_modules(
    res2['class_attn'][beta_class],
    res2['genes'],
    top_k=30,
    n_modules=4
)
[pipeline] n_classes=9, classes=['PSC cell' 'acinar cell' 'alpha cell' 'beta cell' 'co-expression cell'
 'delta cell' 'ductal cell' 'gamma cell' 'unclassified endocrine cell']
[pipeline] bin range: 0 – 6  (using 7 embedding slots)
[pipeline] train=769, test=330
  Epoch 01 | train_loss=1.7512 acc=0.3875 | val_loss=1.7062 acc=0.4030
  Epoch 02 | train_loss=1.6894 acc=0.4148 | val_loss=1.6148 acc=0.5788
  Epoch 03 | train_loss=1.5637 acc=0.5293 | val_loss=1.1452 acc=0.6212
  Epoch 04 | train_loss=1.1089 acc=0.6203 | val_loss=0.9440 acc=0.6788
  Epoch 05 | train_loss=0.9092 acc=0.6957 | val_loss=0.8799 acc=0.6727
  Epoch 06 | train_loss=0.8112 acc=0.7321 | val_loss=0.8328 acc=0.7000
[pipeline] class_attn keys: [0, 1, 2, 3, 4, 5, 6, 7, 8]
In [44]:
top_genes
Out[44]:
array(['LCN2|NM_005564', 'SDC4|NM_002999',
       'PMEPA1|NM_199171+NM_199170+NM_020182+NM_001255976+NM_199169',
       'SERINC2|NM_001199038+NM_001199037+NM_178865+NM_018565+NM_001199039',
       'LITAF|NM_004862+NM_001136473+NM_001136472+NR_024320',
       'KRT19|NM_002276', 'GCG|NM_002054',
       'ALDH1A3|NM_000693+NM_001293815', 'CD9|NM_001769',
       'ISYNA1|NR_045573+NM_001170938+NM_016368+NM_001253389+NR_045574',
       'STOM|NM_001270527+NM_004099+NM_001270526+NM_198194+NR_073037',
       'DUOX2|NM_014080', 'RNF123|NM_022064', 'SOD3|NM_003102',
       'DARS2|NM_018122',
       'GPSM1|NM_015597+NM_001145638+NM_001145639+NM_001200003',
       'SLC43A3|NM_199329+NM_014096+NM_001278201+NM_017611+NM_001278206',
       'DNPH1|NM_199184+NM_006443',
       'ERCC_1875:mix1_468.75:mix2|ERCC-00136',
       'DEPDC1B|NM_001145208+NM_018369', 'ANPEP|NM_001150',
       'CHCHD10|NM_213720', 'PPY|NM_002722',
       'ERCC_937.5:mix1_937.5:mix2|ERCC-00009',
       'C10orf53|NM_001042427+NM_182554', 'JUNB|NM_002229',
       'SYNGR2|NM_004710', 'CHGB|NM_001819', 'SCAMP2|NM_005697',
       'NOL4L|NM_080616+NM_001256798'], dtype='<U331')
In [45]:
plot_gene_module_network(sub_attn=sub_attn, top_genes=top_genes,
                         labels=labels, modules=modules,
                         title="Beta-cell gene module network (diabetes)")
No description has been provided for this image
In [46]:
# ── Pull results from diabetes pipeline ────────────────────────────────────
le2          = res2['le']
class_attn2  = res2['class_attn']
class_counts2= res2['class_counts']
genes2       = [g.split('|')[0] for g in res2['genes']]  # strip transcript suffix

# ── Find beta cell class index ──────────────────────────────────────────────
print("Available classes:", le2.classes_)   # confirm exact label spelling

beta_class = le2.transform(["beta cell"])[0]
print(f"Beta cell → class index {beta_class}, n={class_counts2[beta_class]} cells")

# ── Plot ────────────────────────────────────────────────────────────────────
plot_attention_heatmap(
    class_attn2[beta_class],
    genes2,
    top_k=30,
    title=f"Beta cell attention heatmap — diabetes (n={class_counts2[beta_class]})"
)
Available classes: ['PSC cell' 'acinar cell' 'alpha cell' 'beta cell' 'co-expression cell'
 'delta cell' 'ductal cell' 'gamma cell' 'unclassified endocrine cell']
Beta cell → class index 3, n=30 cells
No description has been provided for this image

Extensions¶

  • Try different bin counts.
  • Compare attention-derived genes to marker genes or differential expression.
  • Replace the encoder with Performer for larger gene sets.
In [ ]: