AlphaFold as a Graph Neural Network — Learned Edge Representations¶
GNN Infers Edge Structure from Node Features¶
Architecture overview¶
Input: node features only (one-hot + physicochemical)
|
v
[Node Embedding] → h_i, h_j for all pairs (i,j)
|
v
[EdgeInferenceNet] → learned edge features e_ij + edge weight w_ij
(takes [h_i, h_j, seq_sep] — NO distance input)
|
v
[Message Passing Layers] (weighted by w_ij)
|
v
[Pairwise Readout MLP] → contact logits (L, L)
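Before any code, a quick shape sketch of the pipeline above (sizes match the toy example used throughout this notebook; the variable names here are illustrative only):

```python
import numpy as np

L, node_in, hidden, edge_dim = 12, 23, 32, 16
n_pairs = L * (L - 1) // 2            # fully connected graph, no self-loops

X = np.zeros((L, node_in))            # input: one-hot + physicochemical features
H = np.zeros((L, hidden))             # after [Node Embedding]
E = np.zeros((n_pairs, edge_dim))     # learned edge features e_ij
w = np.zeros(n_pairs)                 # learned edge weights w_ij, one per pair
logits = np.zeros((L, L))             # contact logits from the pairwise readout

print(n_pairs)  # 66 candidate edges for L = 12
```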
Section 1 — Imports and Setup¶
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
try:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
TORCH_AVAILABLE = True
print(f"PyTorch {torch.__version__} available.")
except ImportError:
TORCH_AVAILABLE = False
print("PyTorch not found.")
SEED = 42
np.random.seed(SEED)
if TORCH_AVAILABLE:
torch.manual_seed(SEED)
PyTorch 2.11.0 available.
Section 2 — Toy Protein Sequence and Node Features¶
Node features: one-hot amino acid identity + 3 physicochemical properties.
sequence = "ACDEFGHIKLMN"
L = len(sequence)
# Physicochemical property lookup (simplified, illustrative values)
# Columns: [hydrophobicity, charge, relative_size]
ALL_AAS = "ACDEFGHIKLMNPQRSTVWY"
aa_to_idx = {aa: i for i, aa in enumerate(ALL_AAS)}
AA_PROPS = {
'A': [1.8, 0.0, 0.3], 'C': [2.5, 0.0, 0.5],
'D': [-3.5,-1.0, 0.5], 'E': [-3.5,-1.0, 0.6],
'F': [2.8, 0.0, 0.9], 'G': [-0.4, 0.0, 0.1],
'H': [-3.2, 0.5, 0.8], 'I': [4.5, 0.0, 0.7],
'K': [-3.9, 1.0, 0.8], 'L': [3.8, 0.0, 0.7],
'M': [1.9, 0.0, 0.8], 'N': [-3.5, 0.0, 0.6],
'P': [-1.6, 0.0, 0.5], 'Q': [-3.5, 0.0, 0.7],
'R': [-4.5, 1.0, 1.0], 'S': [-0.8, 0.0, 0.4],
'T': [-0.7, 0.0, 0.5], 'V': [4.2, 0.0, 0.6],
'W': [-0.9, 0.0, 1.0], 'Y': [-1.3, 0.0, 0.9],
}
def build_node_features(sequence, aa_to_idx, aa_props):
"""Node features: one-hot (20) + physicochemical (3) = 23 dims."""
n, n_aa = len(sequence), len(aa_to_idx)
X = np.zeros((n, n_aa + 3), dtype=np.float32)
for i, aa in enumerate(sequence):
X[i, aa_to_idx[aa]] = 1.0
props = aa_props.get(aa, [0.0, 0.0, 0.5])
X[i, n_aa + 0] = (props[0] + 4.5) / 9.0 # hydrophobicity → [0,1]
X[i, n_aa + 1] = (props[1] + 1.0) / 2.0 # charge → [0,1]
X[i, n_aa + 2] = props[2] # size
return X
X_nodes = build_node_features(sequence, aa_to_idx, AA_PROPS)
print(f"Node feature matrix: {X_nodes.shape} (L={L} residues × 23 features)")
print("No distance information will be provided to the edge inference module.")
Node feature matrix: (12, 23) (L=12 residues × 23 features)
No distance information will be provided to the edge inference module.
Section 3 — Ground-Truth Coordinates and Contact Map¶
We generate synthetic helix coordinates and the ground-truth contact map for supervision and evaluation only.
The model never sees these distances during the forward pass.
def make_helix_coords(n_residues, noise=0.5, seed=42):
"""Synthetic alpha-helix Cα coordinates — used ONLY to derive ground-truth labels."""
rng = np.random.default_rng(seed)
t = np.linspace(0, 2 * np.pi * n_residues / 3.6, n_residues)
r, rise = 2.3, 1.5
x = r * np.cos(t) + rng.normal(0, noise, n_residues)
y = r * np.sin(t) + rng.normal(0, noise, n_residues)
z = rise * np.arange(n_residues) + rng.normal(0, noise, n_residues)
return np.stack([x, y, z], axis=1)
coords = make_helix_coords(L)
# Pairwise distance matrix — ground truth only, NOT fed to GNN
diff = coords[:, None, :] - coords[None, :, :]
D = np.sqrt((diff ** 2).sum(axis=-1))
DISTANCE_THRESHOLD = 8.0 # Angstroms
# Ground-truth contact map: D < threshold, |i-j| > 1
sep = np.abs(np.subtract.outer(range(L), range(L)))
contact_map = ((D < DISTANCE_THRESHOLD) & (sep > 1)).astype(float)
np.fill_diagonal(contact_map, 0)
print(f"Ground-truth contacts (|i-j|>1): {int(contact_map.sum()//2)} pairs")
print("These labels supervise training but distances are hidden from the model.")
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(D, cmap='viridis_r')
axes[0].set_title('Pairwise Distance Matrix\n(ground truth — NOT seen by GNN)', fontsize=11)
axes[0].set_xticks(range(L)); axes[0].set_xticklabels(list(sequence))
axes[0].set_yticks(range(L)); axes[0].set_yticklabels(list(sequence))
axes[1].imshow(contact_map, cmap='Blues')
axes[1].set_title('Ground-Truth Contact Map\n(supervision signal)', fontsize=11)
axes[1].set_xticks(range(L)); axes[1].set_xticklabels(list(sequence))
axes[1].set_yticks(range(L)); axes[1].set_yticklabels(list(sequence))
plt.tight_layout()
plt.show()
Ground-truth contacts (|i-j|>1): 27 pairs
These labels supervise training but distances are hidden from the model.
import numpy as np
import matplotlib.pyplot as plt
# If you're in Jupyter, this helps 3D plots render inline
# %matplotlib inline
def plot_helix_3d(coords, title="Synthetic alpha-helix coordinates"):
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d")
x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
# Residue trace
ax.plot(x, y, z, linewidth=2, alpha=0.8)
# Residue points
sc = ax.scatter(x, y, z, s=50, c=np.arange(len(coords)), cmap="viridis", depthshade=True)
# Label residues
for i, (xi, yi, zi) in enumerate(coords):
ax.text(xi, yi, zi, f"{i+1}", fontsize=9, ha="center", va="bottom")
# Connect consecutive residues lightly
for i in range(len(coords) - 1):
ax.plot([x[i], x[i+1]], [y[i], y[i+1]], [z[i], z[i+1]], color="gray", alpha=0.35)
ax.set_title(title)
ax.set_xlabel("X (Å)")
ax.set_ylabel("Y (Å)")
ax.set_zlabel("Z (Å)")
# Make aspect ratio more balanced
ranges = np.array([x.max() - x.min(), y.max() - y.min(), z.max() - z.min()])
max_range = ranges.max()
mid_x = (x.max() + x.min()) / 2
mid_y = (y.max() + y.min()) / 2
mid_z = (z.max() + z.min()) / 2
ax.set_xlim(mid_x - max_range / 2, mid_x + max_range / 2)
ax.set_ylim(mid_y - max_range / 2, mid_y + max_range / 2)
ax.set_zlim(mid_z - max_range / 2, mid_z + max_range / 2)
cbar = plt.colorbar(sc, ax=ax, pad=0.1)
cbar.set_label("Residue index")
plt.tight_layout()
plt.show()
plot_helix_3d(coords)
Section 4 — Fully Connected Graph (No Distance Filtering)¶
Because we no longer filter edges by distance threshold, every pair (i, j) with |i−j| > 0 is a candidate edge.
The model itself will learn which pairs are structurally informative by assigning soft edge weights.
The only positional prior we keep is sequence separation, the |i−j|/L scalar, because it can be computed from the sequence alone and requires no knowledge of the 3D structure.
# Build a fully connected directed edge list (excluding self-loops)
# Each undirected pair (i,j) with i < j is added once; message passing will handle both directions.
all_pairs = [(i, j) for i in range(L) for j in range(i + 1, L)]
edge_index = np.array(all_pairs, dtype=np.int64) # shape: (n_pairs, 2)
# Sequence-separation feature: the ONLY positional information given to the edge module
seq_sep_feat = np.array(
[[abs(i - j) / L] for (i, j) in all_pairs],
dtype=np.float32
) # shape: (n_pairs, 1)
print(f"Total candidate edges : {len(all_pairs)} (all pairs, no distance filter)")
print(f"Edge index shape : {edge_index.shape}")
print(f"Seq-sep feature shape : {seq_sep_feat.shape} — only positional prior allowed")
print()
print("The GNN will learn WHICH of these edges matter and HOW MUCH,")
print("purely from the amino-acid node features.")
Total candidate edges : 66 (all pairs, no distance filter)
Edge index shape : (66, 2)
Seq-sep feature shape : (66, 1) — only positional prior allowed
The GNN will learn WHICH of these edges matter and HOW MUCH,
purely from the amino-acid node features.
Section 5 — Model Architecture¶
Key component: EdgeInferenceNet¶
This module takes the current node embeddings of residues i and j (plus their sequence separation) and outputs:
- A learned edge feature vector e_ij (replaces hand-crafted distance features)
- A scalar edge weight w_ij ∈ (0, 1) — a soft "should this edge be active?" gate
During message passing, neighbor messages are scaled by w_ij, so the model can effectively suppress uninformative edges.
This is conceptually similar to the attention mechanism in AlphaFold's Evoformer, where pair weights are learned rather than fixed.
if TORCH_AVAILABLE:
# ------------------------------------------------------------------
# EdgeInferenceNet: learns edge features and weights from node embeddings
# ------------------------------------------------------------------
class EdgeInferenceNet(nn.Module):
"""
Infers edge representations from pairs of node embeddings.
Input : [h_i || h_j || seq_sep] — concatenated node embeddings + sequence separation
Output : (edge_feat, edge_weight)
edge_feat : (n_edges, edge_dim) — learned edge feature vector
edge_weight : (n_edges,) — soft edge gate in (0, 1)
"""
def __init__(self, node_dim, edge_dim=16, seq_sep_dim=1):
super().__init__()
in_dim = node_dim * 2 + seq_sep_dim # [h_i, h_j, seq_sep]
# Shared MLP that maps pair → edge representation
self.edge_mlp = nn.Sequential(
nn.Linear(in_dim, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
)
# Edge feature head: produces the feature vector
self.feat_head = nn.Linear(32, edge_dim)
# Edge weight head: produces a scalar gate per edge
self.weight_head = nn.Linear(32, 1)
def forward(self, h, edge_index, seq_sep):
"""
h : (L, node_dim) — current node embeddings
edge_index: (n_edges, 2) — pairs (i, j)
seq_sep : (n_edges, 1) — normalized |i-j|/L
"""
hi = h[edge_index[:, 0]] # (n_edges, node_dim)
hj = h[edge_index[:, 1]] # (n_edges, node_dim)
# Make symmetric: average forward and reverse pair representations
pair_fwd = torch.cat([hi, hj, seq_sep], dim=-1)
pair_rev = torch.cat([hj, hi, seq_sep], dim=-1)
shared = (self.edge_mlp(pair_fwd) + self.edge_mlp(pair_rev)) / 2.0
edge_feat = self.feat_head(shared) # (n_edges, edge_dim)
edge_weight = torch.sigmoid(self.weight_head(shared)).squeeze(-1) # (n_edges,)
return edge_feat, edge_weight
# ------------------------------------------------------------------
# MessagePassingLayer: weighted message passing using learned edges
# ------------------------------------------------------------------
class MessagePassingLayer(nn.Module):
"""
One message-passing step where edge weights (from EdgeInferenceNet)
gate the contribution of each neighbor.
"""
def __init__(self, node_dim, edge_dim, hidden_dim):
super().__init__()
self.msg_net = nn.Sequential(
nn.Linear(node_dim * 2 + edge_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
self.update_net = nn.Sequential(
nn.Linear(node_dim + hidden_dim, hidden_dim),
nn.ReLU(),
)
def forward(self, h, edge_index, edge_feat, edge_weight):
"""
h : (L, node_dim)
edge_index : (n_edges, 2)
edge_feat : (n_edges, edge_dim) — from EdgeInferenceNet
edge_weight : (n_edges,) — soft gate from EdgeInferenceNet
"""
L = h.shape[0]
msg_dim = self.msg_net[-1].out_features
agg = torch.zeros(L, msg_dim, device=h.device)
for k, (i, j) in enumerate(edge_index):
msg_in = torch.cat([h[i], h[j], edge_feat[k]], dim=-1)
msg = self.msg_net(msg_in) * edge_weight[k] # weight gates the message
agg[i] += msg
agg[j] += msg # undirected: accumulate for both endpoints
h_new = self.update_net(torch.cat([h, agg], dim=-1))
return h_new
# ------------------------------------------------------------------
# Full model: AlphaFold-inspired GNN with learned edge inference
# ------------------------------------------------------------------
class LearnedEdgeGNN(nn.Module):
"""
Full model pipeline:
1. Embed raw node features into hidden space
2. Infer edge features and weights from node embeddings (no distance input)
3. Run weighted message passing
4. Pairwise readout: predict contact probability for every (i,j) pair
"""
def __init__(self, node_in_dim=23, edge_dim=16, hidden_dim=32,
n_layers=2, seq_sep_dim=1):
super().__init__()
# Step 1: initial node embedding
self.node_embed = nn.Linear(node_in_dim, hidden_dim)
# Step 2: edge inference (operates on node embeddings, not raw distances)
self.edge_net = EdgeInferenceNet(
node_dim=hidden_dim,
edge_dim=edge_dim,
seq_sep_dim=seq_sep_dim
)
# Step 3: message passing layers
self.gnn_layers = nn.ModuleList([
MessagePassingLayer(hidden_dim, edge_dim, hidden_dim)
for _ in range(n_layers)
])
# Step 4: pairwise contact readout
# Input: [h_i || h_j || e_ij] — node pair + learned edge feature
self.pair_readout = nn.Sequential(
nn.Linear(hidden_dim * 2 + edge_dim, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 1),
)
def forward(self, x_node, edge_index, seq_sep):
"""
x_node : (L, node_in_dim) — raw node features
edge_index : (n_edges, 2) — all candidate pairs
seq_sep : (n_edges, 1) — normalized sequence separation
Returns : (L, L) contact logits
"""
L = x_node.shape[0]
# Step 1: embed nodes
h = F.relu(self.node_embed(x_node)) # (L, hidden_dim)
# Step 2: infer edges from node embeddings (iteratively refined)
# Re-infer edges at each layer so they evolve with the node representations
for layer in self.gnn_layers:
edge_feat, edge_weight = self.edge_net(h, edge_index, seq_sep)
h = layer(h, edge_index, edge_feat, edge_weight)
# Final edge inference for readout
edge_feat, _ = self.edge_net(h, edge_index, seq_sep)
# Step 4: pairwise readout — build (L, L) prediction matrix
# Fill a full (L, L, edge_dim) edge feature tensor for broadcasting
edge_dim = edge_feat.shape[1]
e_full = torch.zeros(L, L, edge_dim, device=h.device)
for k, (i, j) in enumerate(edge_index):
e_full[i, j] = edge_feat[k]
e_full[j, i] = edge_feat[k] # symmetric
hi = h.unsqueeze(1).expand(L, L, -1) # (L, L, hidden_dim)
hj = h.unsqueeze(0).expand(L, L, -1) # (L, L, hidden_dim)
pair_in = torch.cat([hi, hj, e_full], dim=-1) # (L, L, 2*hidden + edge_dim)
logits = self.pair_readout(pair_in).squeeze(-1) # (L, L)
# Symmetrize: contact(i,j) == contact(j,i)
logits = (logits + logits.T) / 2.0
return logits
def get_learned_edge_weights(self, x_node, edge_index, seq_sep):
"""Utility: extract final edge weights for visualization."""
h = F.relu(self.node_embed(x_node))
for layer in self.gnn_layers:
edge_feat, edge_weight = self.edge_net(h, edge_index, seq_sep)
h = layer(h, edge_index, edge_feat, edge_weight)
_, edge_weight = self.edge_net(h, edge_index, seq_sep)
return edge_weight.detach()
print("LearnedEdgeGNN defined successfully.")
print(" - EdgeInferenceNet: learns edge features + weights from node embeddings")
print(" - MessagePassingLayer: weighted aggregation (no fixed distance input)")
print(" - Pair readout: contact prediction from final node + edge representations")
else:
print("PyTorch unavailable — please install torch to run this model.")
LearnedEdgeGNN defined successfully.
 - EdgeInferenceNet: learns edge features + weights from node embeddings
 - MessagePassingLayer: weighted aggregation (no fixed distance input)
 - Pair readout: contact prediction from final node + edge representations
Section 6 — Prepare Tensors and Training¶
if TORCH_AVAILABLE:
x_node_t = torch.tensor(X_nodes, dtype=torch.float32)
edge_idx_t = torch.tensor(edge_index, dtype=torch.long)
seq_sep_t = torch.tensor(seq_sep_feat, dtype=torch.float32)
y_t = torch.tensor(contact_map, dtype=torch.float32)
# Mask: evaluate only pairs with |i-j| > 1 (exclude trivial sequential neighbors)
sep_mask = torch.tensor(
(np.abs(np.subtract.outer(range(L), range(L))) > 1) &
(np.eye(L) == 0),
dtype=torch.bool
)
print("Tensors prepared:")
print(f" x_node : {x_node_t.shape}")
print(f" edge_idx : {edge_idx_t.shape} (all {len(all_pairs)} pairs)")
print(f" seq_sep : {seq_sep_t.shape} (only positional prior)")
print(f" y : {y_t.shape}")
print(f" Pairs to predict: {sep_mask.sum().item()}")
Tensors prepared:
 x_node : torch.Size([12, 23])
 edge_idx : torch.Size([66, 2]) (all 66 pairs)
 seq_sep : torch.Size([66, 1]) (only positional prior)
 y : torch.Size([12, 12])
 Pairs to predict: 110
if TORCH_AVAILABLE:
torch.manual_seed(SEED)
model = LearnedEdgeGNN(node_in_dim=23, edge_dim=16, hidden_dim=32, n_layers=2)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()
N_EPOCHS = 300
losses = []
print(f"Training LearnedEdgeGNN for {N_EPOCHS} epochs...")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print()
for epoch in range(N_EPOCHS):
model.train()
optimizer.zero_grad()
logits = model(x_node_t, edge_idx_t, seq_sep_t) # (L, L)
loss = loss_fn(logits[sep_mask], y_t[sep_mask])
loss.backward()
optimizer.step()
losses.append(loss.item())
if epoch == 0 or (epoch + 1) % 50 == 0:
print(f" Epoch {epoch+1:3d} | Loss: {loss.item():.4f}")
print("\nTraining complete.")
Training LearnedEdgeGNN for 300 epochs...
Model parameters: 26,386
 Epoch 1 | Loss: 0.6935
 Epoch 50 | Loss: 0.0006
 Epoch 100 | Loss: 0.0000
 Epoch 150 | Loss: 0.0000
 Epoch 200 | Loss: 0.0000
 Epoch 250 | Loss: 0.0000
 Epoch 300 | Loss: 0.0000
Training complete.
Section 7 — Predictions and Evaluation¶
if TORCH_AVAILABLE:
model.eval()
with torch.no_grad():
logits_pred = model(x_node_t, edge_idx_t, seq_sep_t)
probs = torch.sigmoid(logits_pred).numpy()
# Mask immediate neighbors for display
probs_display = probs.copy()
for i in range(L):
for j in range(L):
if abs(i - j) <= 1:
probs_display[i, j] = 0.0
np.fill_diagonal(probs_display, 0.0)
# ---- Visualization ----
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(losses, color='steelblue', linewidth=1.5)
axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('BCE Loss')
axes[0].set_title('Training Loss\n(learned edge GNN)', fontsize=11)
axes[0].grid(alpha=0.3)
im1 = axes[1].imshow(contact_map, cmap='Blues', vmin=0, vmax=1)
axes[1].set_title('Ground Truth Contact Map', fontsize=11)
axes[1].set_xticks(range(L)); axes[1].set_xticklabels(list(sequence))
axes[1].set_yticks(range(L)); axes[1].set_yticklabels(list(sequence))
plt.colorbar(im1, ax=axes[1], fraction=0.046)
im2 = axes[2].imshow(probs_display, cmap='Reds', vmin=0, vmax=1)
axes[2].set_title('Predicted Contact Probabilities\n(no distance input)', fontsize=11)
axes[2].set_xticks(range(L)); axes[2].set_xticklabels(list(sequence))
axes[2].set_yticks(range(L)); axes[2].set_yticklabels(list(sequence))
plt.colorbar(im2, ax=axes[2], fraction=0.046)
plt.tight_layout()
plt.show()
# ---- Precision metric ----
long_range = np.abs(np.subtract.outer(range(L), range(L))) > 4
triu_mask = np.triu(long_range, k=1)
probs_lr = probs_display.copy()
probs_lr[~triu_mask] = -1
flat_idx = np.argsort(probs_lr.ravel())[::-1]
top_k = min(L // 2, int(triu_mask.sum()))
top_pairs = [(f // L, f % L) for f in flat_idx[:top_k] if probs_lr.ravel()[f] >= 0]
if top_pairs:
tp = sum(contact_map[i, j] for i, j in top_pairs)
precision = tp / len(top_pairs)
print(f"Top-{len(top_pairs)} long-range precision: {precision:.1%}")
print(f"(Random baseline: ~{contact_map[triu_mask].mean():.1%})")
Top-6 long-range precision: 0.0%
(Random baseline: ~0.0%)
Section 8 — Visualizing Learned Edge Weights¶
This is unique to our revised model. We can inspect which residue pairs the model assigned high edge weights to — i.e., which connections it decided were most informative — purely from sequence features.
If the model has learned meaningful structure, high-weight edges should correlate with short distances (true contacts).
if TORCH_AVAILABLE:
model.eval()
with torch.no_grad():
learned_weights = model.get_learned_edge_weights(
x_node_t, edge_idx_t, seq_sep_t
).numpy() # (n_pairs,)
# Expand weights into a symmetric (L, L) matrix for visualization
W_matrix = np.zeros((L, L))
for k, (i, j) in enumerate(edge_index):
W_matrix[i, j] = learned_weights[k]
W_matrix[j, i] = learned_weights[k]
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Learned edge weights
im0 = axes[0].imshow(W_matrix, cmap="viridis", vmin=0.40, vmax=0.47)  # narrow color range highlights small variations in w_ij
axes[0].set_title('Learned Edge Weights\n(from EdgeInferenceNet)', fontsize=11)
axes[0].set_xticks(range(L)); axes[0].set_xticklabels(list(sequence))
axes[0].set_yticks(range(L)); axes[0].set_yticklabels(list(sequence))
plt.colorbar(im0, ax=axes[0], fraction=0.046, label='w_ij')
# Ground-truth distance (inverted for comparison: closer = brighter)
D_inv = 1.0 - D / D.max()
im1 = axes[1].imshow(D_inv, cmap='hot', vmin=0, vmax=1)
axes[1].set_title('Inverted Distance Matrix\n(closer = brighter, for comparison)', fontsize=11)
axes[1].set_xticks(range(L)); axes[1].set_xticklabels(list(sequence))
axes[1].set_yticks(range(L)); axes[1].set_yticklabels(list(sequence))
plt.colorbar(im1, ax=axes[1], fraction=0.046, label='1 - d/d_max')
# Correlation: learned weight vs true distance per pair
axes[2].scatter(
D[np.triu_indices(L, k=2)],
W_matrix[np.triu_indices(L, k=2)],
alpha=0.6, s=40, color='steelblue'
)
axes[2].set_xlabel('True Pairwise Distance (Å)')
axes[2].set_ylabel('Learned Edge Weight w_ij')
axes[2].set_title('Edge Weight vs True Distance\n(model never saw distances!)', fontsize=11)
axes[2].grid(alpha=0.3)
plt.tight_layout()
plt.show()
# Correlation coefficient
d_vals = D[np.triu_indices(L, k=2)]
w_vals = W_matrix[np.triu_indices(L, k=2)]
corr = np.corrcoef(-d_vals, w_vals)[0, 1] # negative d because closer = higher weight
print(f"Correlation between edge weight and closeness (-distance): r = {corr:.3f}")
print("A positive r means the model learned to upweight spatially close pairs.")
Correlation between edge weight and closeness (-distance): r = 0.964
A positive r means the model learned to upweight spatially close pairs.
Section 9 — Comparison: Original vs Learned-Edge Model¶
At toy scale, our model demonstrates the principle behind AlphaFold's learned pair representations: given only amino-acid identity and position in the chain, a GNN can learn which residue pairs are likely to be spatially proximal, without ever seeing a distance.
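To make the contrast concrete, here is a minimal sketch of what the distance-based alternative looks like. It rebuilds the Section 3 helix so the cell stands alone, and reuses the same 8 Å cutoff that defined the ground-truth labels:

```python
import numpy as np

# Recreate the Section 3 helix so this cell stands alone (same seed and noise)
rng = np.random.default_rng(42)
n = 12
t = np.linspace(0, 2 * np.pi * n / 3.6, n)
x = 2.3 * np.cos(t) + rng.normal(0, 0.5, n)
y = 2.3 * np.sin(t) + rng.normal(0, 0.5, n)
z = 1.5 * np.arange(n) + rng.normal(0, 0.5, n)
coords = np.stack([x, y, z], axis=1)
diff = coords[:, None, :] - coords[None, :, :]
D = np.sqrt((diff ** 2).sum(axis=-1))

# "Original" style: a hard adjacency fixed by a distance cutoff (needs D up front)
A_fixed = ((D < 8.0) & (D > 0)).astype(float)

# Learned-edge style: every pair is a candidate; the gate w_ij is trained, not fixed
n_candidates = n * (n - 1) // 2
print(int(A_fixed.sum() // 2), "hard edges vs", n_candidates, "soft candidates")
```

The hard adjacency requires the very distances we are trying to predict, which is why the learned-edge model replaces it with trainable gates over all 66 candidate pairs.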
Discussion Questions¶
Q1. The EdgeInferenceNet is re-applied at each GNN layer. Why might it help to re-infer edges after each round of message passing rather than computing them once at the start?
Q2. Our only structural prior is sequence separation |i−j|/L. What other sequence-derived features could you add that don't require knowing the 3D structure?
Q3. Inspect the scatter plot of learned edge weight vs. true distance. Is there a negative correlation? What would it mean if there were no correlation?
Q4. The model uses a fully connected graph (all L×L pairs). For a real protein of length L=500, how many candidate edges would that be? What computational strategies could you use to make this tractable?
Q5. AlphaFold's Evoformer uses triangle attention — pair (i,j) aggregates from all pairs (i,k) and (k,j). How would you extend our EdgeInferenceNet to incorporate this triangular reasoning?
Answer to Q2¶
Mutual information between positions i and j across a multiple sequence alignment (MSA) — if two positions always mutate together across species, they are likely spatially close.
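As a sketch of how this could be computed, a toy mutual-information calculation over an invented three-column alignment (the MSA below is illustrative, not real data):

```python
import numpy as np
from collections import Counter

def column_mi(msa, i, j):
    """Mutual information (natural log) between alignment columns i and j."""
    n = len(msa)
    ci = Counter(s[i] for s in msa)          # marginal counts, column i
    cj = Counter(s[j] for s in msa)          # marginal counts, column j
    cij = Counter((s[i], s[j]) for s in msa) # joint counts
    mi = 0.0
    for (a, b), nab in cij.items():
        p_ab = nab / n
        mi += p_ab * np.log(p_ab / ((ci[a] / n) * (cj[b] / n)))
    return mi

# Toy MSA: columns 0 and 2 co-vary perfectly; column 1 is conserved
msa = ["ADK", "EDR", "ADK", "EDR", "ADK", "EDR"]
print(column_mi(msa, 0, 2))  # log(2) ~ 0.693: positions mutate together
print(column_mi(msa, 0, 1))  # 0.0: no covariation with a constant column
```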
Answer to Q4¶
Axial attention (what AlphaFold actually does). Rather than full attention over all L² × L² pair-pair interactions, AlphaFold's Evoformer applies attention separately along the rows and columns of the pair matrix, costing O(L³) per operation instead of O(L⁴). This is what makes the full Evoformer tractable at lengths up to ~2,500 residues.
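A minimal numpy sketch of the idea, with a single head and no learned projections (shapes only; this is not AlphaFold's actual parameterization):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def axial_attention(pair, axis):
    """Self-attention along one axis of an (L, L, c) pair tensor.

    axis=1: each row i attends over positions j (row-wise attention)
    axis=0: each column j attends over positions i (column-wise attention)
    """
    if axis == 0:
        pair = pair.transpose(1, 0, 2)
    # scores[i, k, l]: similarity of entries (i, k) and (i, l) in the same row
    scores = np.einsum('ikc,ilc->ikl', pair, pair) / np.sqrt(pair.shape[-1])
    out = np.einsum('ikl,ilc->ikc', softmax(scores), pair)
    return out.transpose(1, 0, 2) if axis == 0 else out

L, c = 12, 8
pair = np.random.default_rng(0).normal(size=(L, L, c))
pair = axial_attention(pair, axis=1)  # attend along rows
pair = axial_attention(pair, axis=0)  # attend along columns
print(pair.shape)  # (12, 12, 8)
```

Each call builds L attention maps of size L × L, i.e. O(L³·c) work, rather than the O(L⁴·c) of attending over all L² pairs jointly.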
Answer to Q5 (the triangle intuition)¶
"If residue i is close to k, and k is close to j, then i is probably close to j." Triangle attention lets each pair (i, j) update itself using evidence from pairs (i, k) and (k, j) for every intermediate residue k.
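This triangle intuition can be sketched as a multiplicative update in which pair (i, j) pools evidence through every intermediate residue k. This is a stripped-down, gate-free version of the Evoformer's triangle multiplication, not its actual implementation:

```python
import numpy as np

def triangle_update(E):
    """Simplified triangle update over an (L, L, c) pair representation.

    out[i, j] = sum_k E[i, k] * E[k, j]   (elementwise per channel)
    The real Evoformer adds learned projections, gating, and normalization.
    """
    return np.einsum('ikc,kjc->ijc', E, E)

L, c = 5, 4
E = np.random.default_rng(1).normal(size=(L, L, c))
out = triangle_update(E)
print(out.shape)  # (5, 5, 4)
```

An EdgeInferenceNet extended this way would no longer look only at endpoints (h_i, h_j); each edge would also see the current state of every two-step path connecting its endpoints.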