AlphaFold as a Graph Neural Network — Learned Edge Representations¶

GNN Infers Edge Structure from Node Features¶


Architecture overview¶

Input: node features only (one-hot + physicochemical)
          |
          v
  [Node Embedding]  →  h_i, h_j  for all pairs (i,j)
          |
          v
  [EdgeInferenceNet]  →  learned edge features e_ij  +  edge weight w_ij
  (takes [h_i, h_j, seq_sep] — NO distance input)
          |
          v
  [Message Passing Layers]  (weighted by w_ij)
          |
          v
  [Pairwise Readout MLP]  →  contact logits (L, L)

Section 1 — Imports and Setup¶

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
    print(f"PyTorch {torch.__version__} available.")
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch not found.")

SEED = 42
np.random.seed(SEED)
if TORCH_AVAILABLE:
    torch.manual_seed(SEED)
PyTorch 2.11.0 available.

Section 2 — Toy Protein Sequence and Node Features¶

Node features: one-hot amino acid identity + 3 physicochemical properties.

In [3]:
sequence = "ACDEFGHIKLMN"
L = len(sequence)

# Physicochemical property lookup (simplified, illustrative values)
# Columns: [hydrophobicity, charge, relative_size]

ALL_AAS = "ACDEFGHIKLMNPQRSTVWY"
aa_to_idx = {aa: i for i, aa in enumerate(ALL_AAS)}

AA_PROPS = {
    'A': [1.8,  0.0, 0.3],   'C': [2.5,  0.0, 0.5],
    'D': [-3.5,-1.0, 0.5],   'E': [-3.5,-1.0, 0.6],
    'F': [2.8,  0.0, 0.9],   'G': [-0.4, 0.0, 0.1],
    'H': [-3.2, 0.5, 0.8],   'I': [4.5,  0.0, 0.7],
    'K': [-3.9, 1.0, 0.8],   'L': [3.8,  0.0, 0.7],
    'M': [1.9,  0.0, 0.8],   'N': [-3.5, 0.0, 0.6],
    'P': [-1.6, 0.0, 0.5],   'Q': [-3.5, 0.0, 0.7],
    'R': [-4.5, 1.0, 1.0],   'S': [-0.8, 0.0, 0.4],
    'T': [-0.7, 0.0, 0.5],   'V': [4.2,  0.0, 0.6],
    'W': [-0.9, 0.0, 1.0],   'Y': [-1.3, 0.0, 0.9],
}

def build_node_features(sequence, aa_to_idx, aa_props):
    """Node features: one-hot (20) + physicochemical (3) = 23 dims."""
    n, n_aa = len(sequence), len(aa_to_idx)
    X = np.zeros((n, n_aa + 3), dtype=np.float32)
    for i, aa in enumerate(sequence):
        X[i, aa_to_idx[aa]] = 1.0
        props = aa_props.get(aa, [0.0, 0.0, 0.5])
        X[i, n_aa + 0] = (props[0] + 4.5) / 9.0   # hydrophobicity → [0,1]
        X[i, n_aa + 1] = (props[1] + 1.0) / 2.0   # charge → [0,1]
        X[i, n_aa + 2] = props[2]                  # size
    return X

X_nodes = build_node_features(sequence, aa_to_idx, AA_PROPS)
print(f"Node feature matrix: {X_nodes.shape}  (L={L} residues × 23 features)")
print("No distance information will be provided to the edge inference module.")
Node feature matrix: (12, 23)  (L=12 residues × 23 features)
No distance information will be provided to the edge inference module.

Section 3 — Ground-Truth Coordinates and Contact Map¶

We generate synthetic helix coordinates and the ground-truth contact map for supervision and evaluation only.
The model never sees these distances during the forward pass.

In [5]:
def make_helix_coords(n_residues, noise=0.5, seed=42):
    """Synthetic alpha-helix Cα coordinates — used ONLY to derive ground-truth labels."""
    rng = np.random.default_rng(seed)
    t   = np.linspace(0, 2 * np.pi * n_residues / 3.6, n_residues)
    r, rise = 2.3, 1.5
    x = r * np.cos(t) + rng.normal(0, noise, n_residues)
    y = r * np.sin(t) + rng.normal(0, noise, n_residues)
    z = rise * np.arange(n_residues) + rng.normal(0, noise, n_residues)
    return np.stack([x, y, z], axis=1)

coords = make_helix_coords(L)

# Pairwise distance matrix — ground truth only, NOT fed to GNN
diff = coords[:, None, :] - coords[None, :, :]
D    = np.sqrt((diff ** 2).sum(axis=-1))

DISTANCE_THRESHOLD = 8.0  # Angstroms

# Ground-truth contact map: D < threshold, |i-j| > 1
sep = np.abs(np.subtract.outer(range(L), range(L)))
contact_map = ((D < DISTANCE_THRESHOLD) & (sep > 1)).astype(float)
np.fill_diagonal(contact_map, 0)

print(f"Ground-truth contacts (|i-j|>1): {int(contact_map.sum()//2)} pairs")
print("These labels supervise training but distances are hidden from the model.")

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(D, cmap='viridis_r')
axes[0].set_title('Pairwise Distance Matrix\n(ground truth — NOT seen by GNN)', fontsize=11)
axes[0].set_xticks(range(L)); axes[0].set_xticklabels(list(sequence))
axes[0].set_yticks(range(L)); axes[0].set_yticklabels(list(sequence))

axes[1].imshow(contact_map, cmap='Blues')
axes[1].set_title('Ground-Truth Contact Map\n(supervision signal)', fontsize=11)
axes[1].set_xticks(range(L)); axes[1].set_xticklabels(list(sequence))
axes[1].set_yticks(range(L)); axes[1].set_yticklabels(list(sequence))

plt.tight_layout()
plt.show()
Ground-truth contacts (|i-j|>1): 27 pairs
These labels supervise training but distances are hidden from the model.
In [27]:
import numpy as np
import matplotlib.pyplot as plt

# If you're in Jupyter, this helps 3D plots render inline
# %matplotlib inline

def plot_helix_3d(coords, title="Synthetic alpha-helix coordinates"):
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection="3d")

    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    # Residue trace
    ax.plot(x, y, z, linewidth=2, alpha=0.8)

    # Residue points
    sc = ax.scatter(x, y, z, s=50, c=np.arange(len(coords)), cmap="viridis", depthshade=True)

    # Label residues
    for i, (xi, yi, zi) in enumerate(coords):
        ax.text(xi, yi, zi, f"{i+1}", fontsize=9, ha="center", va="bottom")

    # Connect consecutive residues lightly
    for i in range(len(coords) - 1):
        ax.plot([x[i], x[i+1]], [y[i], y[i+1]], [z[i], z[i+1]], color="gray", alpha=0.35)

    ax.set_title(title)
    ax.set_xlabel("X (Å)")
    ax.set_ylabel("Y (Å)")
    ax.set_zlabel("Z (Å)")

    # Make aspect ratio more balanced
    ranges = np.array([x.max() - x.min(), y.max() - y.min(), z.max() - z.min()])
    max_range = ranges.max()
    mid_x = (x.max() + x.min()) / 2
    mid_y = (y.max() + y.min()) / 2
    mid_z = (z.max() + z.min()) / 2

    ax.set_xlim(mid_x - max_range / 2, mid_x + max_range / 2)
    ax.set_ylim(mid_y - max_range / 2, mid_y + max_range / 2)
    ax.set_zlim(mid_z - max_range / 2, mid_z + max_range / 2)

    cbar = plt.colorbar(sc, ax=ax, pad=0.1)
    cbar.set_label("Residue index")

    plt.tight_layout()
    plt.show()

plot_helix_3d(coords)

Section 4 — Fully Connected Graph (No Distance Filtering)¶

Because we no longer filter edges by a distance threshold, every pair (i, j) with |i−j| > 0 is a candidate edge.
The model itself will learn which pairs are structurally informative by assigning soft edge weights.

The only positional prior we keep is the sequence-separation scalar |i−j|/L. This is permissible because it is derived from the sequence alone, not from the 3D structure.

In [4]:
# Build a fully connected directed edge list (excluding self-loops)
# Each undirected pair (i,j) with i < j is added once; message passing will handle both directions.
all_pairs = [(i, j) for i in range(L) for j in range(i + 1, L)]
edge_index = np.array(all_pairs, dtype=np.int64)   # shape: (n_pairs, 2)

# Sequence-separation feature: the ONLY positional information given to the edge module
seq_sep_feat = np.array(
    [[abs(i - j) / L] for (i, j) in all_pairs],
    dtype=np.float32
)  # shape: (n_pairs, 1)

print(f"Total candidate edges : {len(all_pairs)}  (all pairs, no distance filter)")
print(f"Edge index shape      : {edge_index.shape}")
print(f"Seq-sep feature shape : {seq_sep_feat.shape}  — only positional prior allowed")
print()
print("The GNN will learn WHICH of these edges matter and HOW MUCH,")
print("purely from the amino-acid node features.")
Total candidate edges : 66  (all pairs, no distance filter)
Edge index shape      : (66, 2)
Seq-sep feature shape : (66, 1)  — only positional prior allowed

The GNN will learn WHICH of these edges matter and HOW MUCH,
purely from the amino-acid node features.

Section 5 — Model Architecture¶

Key component: EdgeInferenceNet¶

This module takes the current node embeddings of residues i and j (plus their sequence separation) and outputs:

  1. A learned edge feature vector e_ij (replaces hand-crafted distance features)
  2. A scalar edge weight w_ij ∈ (0, 1) — a soft "should this edge be active?" gate

During message passing, neighbor messages are scaled by w_ij, so the model can effectively suppress uninformative edges.

This is conceptually similar to the attention mechanism in AlphaFold's Evoformer, where pair weights are learned rather than fixed.
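For intuition, the contrast between the two gating schemes can be sketched in a few lines (a toy illustration with random embeddings, not AlphaFold's actual attention):

```python
import torch

torch.manual_seed(0)
L, d = 5, 8
h = torch.randn(L, d)  # toy node embeddings

# Attention-style weights (Evoformer flavour): normalized per row via softmax,
# so edges into the same node compete for a fixed weight budget
scores = h @ h.T / d ** 0.5             # (L, L) pair scores
attn_w = torch.softmax(scores, dim=-1)  # each row sums to 1

# Sigmoid gate (this notebook's EdgeInferenceNet flavour): each edge is
# gated independently in (0, 1), with no competition between edges
gate_w = torch.sigmoid(scores)

print(attn_w.sum(dim=-1))  # rows all ≈ 1.0
```

The independent sigmoid gate lets the model keep many (or few) edges active at once, whereas softmax attention forces a normalized distribution over neighbors.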

In [5]:
if TORCH_AVAILABLE:

    # ------------------------------------------------------------------
    # EdgeInferenceNet: learns edge features and weights from node embeddings
    # ------------------------------------------------------------------
    class EdgeInferenceNet(nn.Module):
        """
        Infers edge representations from pairs of node embeddings.

        Input  : [h_i || h_j || seq_sep]  — concatenated node embeddings + sequence separation
        Output : (edge_feat, edge_weight)
            edge_feat   : (n_edges, edge_dim)  — learned edge feature vector
            edge_weight : (n_edges,)            — soft edge gate in (0, 1)

        """
        def __init__(self, node_dim, edge_dim=16, seq_sep_dim=1):
            super().__init__()
            in_dim = node_dim * 2 + seq_sep_dim  # [h_i, h_j, seq_sep]

            # Shared MLP that maps pair → edge representation
            self.edge_mlp = nn.Sequential(
                nn.Linear(in_dim, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
            )
            # Edge feature head: produces the feature vector
            self.feat_head   = nn.Linear(32, edge_dim)
            # Edge weight head: produces a scalar gate per edge
            self.weight_head = nn.Linear(32, 1)

        def forward(self, h, edge_index, seq_sep):
            """
            h         : (L, node_dim)     — current node embeddings
            edge_index: (n_edges, 2)      — pairs (i, j)
            seq_sep   : (n_edges, 1)      — normalized |i-j|/L
            """
            hi = h[edge_index[:, 0]]   # (n_edges, node_dim)
            hj = h[edge_index[:, 1]]   # (n_edges, node_dim)

            # Make symmetric: average forward and reverse pair representations
            pair_fwd = torch.cat([hi, hj, seq_sep], dim=-1)
            pair_rev = torch.cat([hj, hi, seq_sep], dim=-1)
            shared   = (self.edge_mlp(pair_fwd) + self.edge_mlp(pair_rev)) / 2.0

            edge_feat   = self.feat_head(shared)              # (n_edges, edge_dim)
            edge_weight = torch.sigmoid(self.weight_head(shared)).squeeze(-1)  # (n_edges,)
            return edge_feat, edge_weight


    # ------------------------------------------------------------------
    # MessagePassingLayer: weighted message passing using learned edges
    # ------------------------------------------------------------------
    class MessagePassingLayer(nn.Module):
        """
        One message-passing step where edge weights (from EdgeInferenceNet)
        gate the contribution of each neighbor.
        """
        def __init__(self, node_dim, edge_dim, hidden_dim):
            super().__init__()
            self.msg_net = nn.Sequential(
                nn.Linear(node_dim * 2 + edge_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.update_net = nn.Sequential(
                nn.Linear(node_dim + hidden_dim, hidden_dim),
                nn.ReLU(),
            )

        def forward(self, h, edge_index, edge_feat, edge_weight):
            """
            h           : (L, node_dim)
            edge_index  : (n_edges, 2)
            edge_feat   : (n_edges, edge_dim)  — from EdgeInferenceNet
            edge_weight : (n_edges,)            — soft gate from EdgeInferenceNet
            """
            L      = h.shape[0]
            msg_dim = self.msg_net[-1].out_features
            agg    = torch.zeros(L, msg_dim, device=h.device)

            for k, (i, j) in enumerate(edge_index):
                msg_in  = torch.cat([h[i], h[j], edge_feat[k]], dim=-1)
                msg     = self.msg_net(msg_in) * edge_weight[k]  # weight gates the message
                agg[i] += msg
                agg[j] += msg   # undirected: accumulate for both endpoints

            h_new = self.update_net(torch.cat([h, agg], dim=-1))
            return h_new


    # ------------------------------------------------------------------
    # Full model: AlphaFold-inspired GNN with learned edge inference
    # ------------------------------------------------------------------
    class LearnedEdgeGNN(nn.Module):
        """
        Full model pipeline:
          1. Embed raw node features into hidden space
          2. Infer edge features and weights from node embeddings (no distance input)
          3. Run weighted message passing
          4. Pairwise readout: predict contact probability for every (i,j) pair
        """
        def __init__(self, node_in_dim=23, edge_dim=16, hidden_dim=32,
                     n_layers=2, seq_sep_dim=1):
            super().__init__()

            # Step 1: initial node embedding
            self.node_embed = nn.Linear(node_in_dim, hidden_dim)

            # Step 2: edge inference (operates on node embeddings, not raw distances)
            self.edge_net = EdgeInferenceNet(
                node_dim=hidden_dim,
                edge_dim=edge_dim,
                seq_sep_dim=seq_sep_dim
            )

            # Step 3: message passing layers
            self.gnn_layers = nn.ModuleList([
                MessagePassingLayer(hidden_dim, edge_dim, hidden_dim)
                for _ in range(n_layers)
            ])

            # Step 4: pairwise contact readout
            # Input: [h_i || h_j || e_ij] — node pair + learned edge feature
            self.pair_readout = nn.Sequential(
                nn.Linear(hidden_dim * 2 + edge_dim, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
            )

        def forward(self, x_node, edge_index, seq_sep):
            """
            x_node     : (L, node_in_dim)  — raw node features
            edge_index : (n_edges, 2)      — all candidate pairs
            seq_sep    : (n_edges, 1)      — normalized sequence separation
            Returns    : (L, L) contact logits
            """
            L = x_node.shape[0]

            # Step 1: embed nodes
            h = F.relu(self.node_embed(x_node))  # (L, hidden_dim)

            # Step 2: infer edges from node embeddings (iteratively refined)
            # Re-infer edges at each layer so they evolve with the node representations
            for layer in self.gnn_layers:
                edge_feat, edge_weight = self.edge_net(h, edge_index, seq_sep)
                h = layer(h, edge_index, edge_feat, edge_weight)

            # Final edge inference for readout
            edge_feat, _ = self.edge_net(h, edge_index, seq_sep)

            # Step 4: pairwise readout — build (L, L) prediction matrix
            # Fill a full (L, L, edge_dim) edge feature tensor for broadcasting
            edge_dim  = edge_feat.shape[1]
            e_full    = torch.zeros(L, L, edge_dim, device=h.device)
            for k, (i, j) in enumerate(edge_index):
                e_full[i, j] = edge_feat[k]
                e_full[j, i] = edge_feat[k]  # symmetric

            hi = h.unsqueeze(1).expand(L, L, -1)  # (L, L, hidden_dim)
            hj = h.unsqueeze(0).expand(L, L, -1)  # (L, L, hidden_dim)
            pair_in = torch.cat([hi, hj, e_full], dim=-1)  # (L, L, 2*hidden + edge_dim)
            logits  = self.pair_readout(pair_in).squeeze(-1)  # (L, L)

            # Symmetrize: contact(i,j) == contact(j,i)
            logits = (logits + logits.T) / 2.0
            return logits

        def get_learned_edge_weights(self, x_node, edge_index, seq_sep):
            """Utility: extract final edge weights for visualization."""
            h = F.relu(self.node_embed(x_node))
            for layer in self.gnn_layers:
                edge_feat, edge_weight = self.edge_net(h, edge_index, seq_sep)
                h = layer(h, edge_index, edge_feat, edge_weight)
            _, edge_weight = self.edge_net(h, edge_index, seq_sep)
            return edge_weight.detach()

    print("LearnedEdgeGNN defined successfully.")
    print("  - EdgeInferenceNet: learns edge features + weights from node embeddings")
    print("  - MessagePassingLayer: weighted aggregation (no fixed distance input)")
    print("  - Pair readout: contact prediction from final node + edge representations")

else:
    print("PyTorch unavailable — please install torch to run this model.")
LearnedEdgeGNN defined successfully.
  - EdgeInferenceNet: learns edge features + weights from node embeddings
  - MessagePassingLayer: weighted aggregation (no fixed distance input)
  - Pair readout: contact prediction from final node + edge representations

Section 6 — Prepare Tensors and Training¶

In [6]:
if TORCH_AVAILABLE:
    x_node_t   = torch.tensor(X_nodes,     dtype=torch.float32)
    edge_idx_t = torch.tensor(edge_index,  dtype=torch.long)
    seq_sep_t  = torch.tensor(seq_sep_feat, dtype=torch.float32)
    y_t        = torch.tensor(contact_map, dtype=torch.float32)

    # Mask: evaluate only pairs with |i-j| > 1 (exclude trivial sequential neighbors)
    sep_mask = torch.tensor(
        (np.abs(np.subtract.outer(range(L), range(L))) > 1) &
        (np.eye(L) == 0),
        dtype=torch.bool
    )

    print("Tensors prepared:")
    print(f"  x_node   : {x_node_t.shape}")
    print(f"  edge_idx : {edge_idx_t.shape}  (all {len(all_pairs)} pairs)")
    print(f"  seq_sep  : {seq_sep_t.shape}   (only structural prior)")
    print(f"  y        : {y_t.shape}")
    print(f"  Pairs to predict: {sep_mask.sum().item()}")
Tensors prepared:
  x_node   : torch.Size([12, 23])
  edge_idx : torch.Size([66, 2])  (all 66 pairs)
  seq_sep  : torch.Size([66, 1])   (only structural prior)
  y        : torch.Size([12, 12])
  Pairs to predict: 110
In [7]:
if TORCH_AVAILABLE:
    torch.manual_seed(SEED)

    model     = LearnedEdgeGNN(node_in_dim=23, edge_dim=16, hidden_dim=32, n_layers=2)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    loss_fn   = nn.BCEWithLogitsLoss()

    N_EPOCHS = 300
    losses   = []

    print(f"Training LearnedEdgeGNN for {N_EPOCHS} epochs...")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    for epoch in range(N_EPOCHS):
        model.train()
        optimizer.zero_grad()

        logits = model(x_node_t, edge_idx_t, seq_sep_t)   # (L, L)
        loss   = loss_fn(logits[sep_mask], y_t[sep_mask])

        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if epoch == 0 or (epoch + 1) % 50 == 0:
            print(f"  Epoch {epoch+1:3d} | Loss: {loss.item():.4f}")

    print("\nTraining complete.")
Training LearnedEdgeGNN for 300 epochs...
Model parameters: 26,386

  Epoch   1 | Loss: 0.6935
  Epoch  50 | Loss: 0.0006
  Epoch 100 | Loss: 0.0000
  Epoch 150 | Loss: 0.0000
  Epoch 200 | Loss: 0.0000
  Epoch 250 | Loss: 0.0000
  Epoch 300 | Loss: 0.0000

Training complete.

Section 7 — Predictions and Evaluation¶

In [8]:
if TORCH_AVAILABLE:
    model.eval()
    with torch.no_grad():
        logits_pred = model(x_node_t, edge_idx_t, seq_sep_t)
        probs       = torch.sigmoid(logits_pred).numpy()

    # Mask immediate neighbors for display
    probs_display = probs.copy()
    for i in range(L):
        for j in range(L):
            if abs(i - j) <= 1:
                probs_display[i, j] = 0.0
    np.fill_diagonal(probs_display, 0.0)

    # ---- Visualization ----
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(losses, color='steelblue', linewidth=1.5)
    axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('BCE Loss')
    axes[0].set_title('Training Loss\n(learned edge GNN)', fontsize=11)
    axes[0].grid(alpha=0.3)

    im1 = axes[1].imshow(contact_map, cmap='Blues', vmin=0, vmax=1)
    axes[1].set_title('Ground Truth Contact Map', fontsize=11)
    axes[1].set_xticks(range(L)); axes[1].set_xticklabels(list(sequence))
    axes[1].set_yticks(range(L)); axes[1].set_yticklabels(list(sequence))
    plt.colorbar(im1, ax=axes[1], fraction=0.046)

    im2 = axes[2].imshow(probs_display, cmap='Reds', vmin=0, vmax=1)
    axes[2].set_title('Predicted Contact Probabilities\n(no distance input)', fontsize=11)
    axes[2].set_xticks(range(L)); axes[2].set_xticklabels(list(sequence))
    axes[2].set_yticks(range(L)); axes[2].set_yticklabels(list(sequence))
    plt.colorbar(im2, ax=axes[2], fraction=0.046)

    plt.tight_layout()
    plt.show()

    # ---- Precision metric ----
    long_range = np.abs(np.subtract.outer(range(L), range(L))) > 4
    triu_mask  = np.triu(long_range, k=1)
    probs_lr   = probs_display.copy()
    probs_lr[~triu_mask] = -1
    flat_idx   = np.argsort(probs_lr.ravel())[::-1]
    top_k      = min(L // 2, int(triu_mask.sum()))
    top_pairs  = [(f // L, f % L) for f in flat_idx[:top_k] if probs_lr.ravel()[f] >= 0]

    if top_pairs:
        tp        = sum(contact_map[i, j] for i, j in top_pairs)
        precision = tp / len(top_pairs)
        print(f"Top-{len(top_pairs)} long-range precision: {precision:.1%}")
        print(f"(Random baseline: ~{contact_map[triu_mask].mean():.1%})")
Top-6 long-range precision: 0.0%
(Random baseline: ~0.0%)

Section 8 — Visualizing Learned Edge Weights¶

This is unique to our revised model. We can inspect which residue pairs the model assigned high edge weights to — i.e., which connections it decided were most informative — purely from sequence features.

If the model has learned meaningful structure, high-weight edges should correlate with short distances (true contacts).

In [23]:
if TORCH_AVAILABLE:
    model.eval()
    with torch.no_grad():
        learned_weights = model.get_learned_edge_weights(
            x_node_t, edge_idx_t, seq_sep_t
        ).numpy()  # (n_pairs,)

    # Expand weights into a symmetric (L, L) matrix for visualization
    W_matrix = np.zeros((L, L))
    for k, (i, j) in enumerate(edge_index):
        W_matrix[i, j] = learned_weights[k]
        W_matrix[j, i] = learned_weights[k]

    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Learned edge weights
    im0 = axes[0].imshow(W_matrix, cmap="viridis", vmin=0.40, vmax=0.47)
    axes[0].set_title('Learned Edge Weights\n(from EdgeInferenceNet)', fontsize=11)
    axes[0].set_xticks(range(L)); axes[0].set_xticklabels(list(sequence))
    axes[0].set_yticks(range(L)); axes[0].set_yticklabels(list(sequence))
    plt.colorbar(im0, ax=axes[0], fraction=0.046, label='w_ij')

    # Ground-truth distance (inverted for comparison: closer = brighter)
    D_inv = 1.0 - D / D.max()
    im1 = axes[1].imshow(D_inv, cmap='hot', vmin=0, vmax=1)
    axes[1].set_title('Inverted Distance Matrix\n(closer = brighter, for comparison)', fontsize=11)
    axes[1].set_xticks(range(L)); axes[1].set_xticklabels(list(sequence))
    axes[1].set_yticks(range(L)); axes[1].set_yticklabels(list(sequence))
    plt.colorbar(im1, ax=axes[1], fraction=0.046, label='1 - d/d_max')

    # Correlation: learned weight vs true distance per pair
    axes[2].scatter(
        D[np.triu_indices(L, k=2)],
        W_matrix[np.triu_indices(L, k=2)],
        alpha=0.6, s=40, color='steelblue'
    )
    axes[2].set_xlabel('True Pairwise Distance (Å)')
    axes[2].set_ylabel('Learned Edge Weight w_ij')
    axes[2].set_title('Edge Weight vs True Distance\n(model never saw distances!)', fontsize=11)
    axes[2].grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Correlation coefficient
    d_vals = D[np.triu_indices(L, k=2)]
    w_vals = W_matrix[np.triu_indices(L, k=2)]
    corr   = np.corrcoef(-d_vals, w_vals)[0, 1]  # negative d because closer = higher weight
    print(f"Correlation between edge weight and closeness (-distance): r = {corr:.3f}")
    print("A positive r means the model learned to upweight spatially close pairs.")
Correlation between edge weight and closeness (-distance): r = 0.964
A positive r means the model learned to upweight spatially close pairs.

Section 9 — Comparison: Original vs Learned-Edge Model¶

AlphaFold's central lesson is that spatial relationships can be inferred from sequence-derived features alone, without distances as input. Our model shows the same principle at toy scale: given only amino-acid identity and position in the chain, a GNN can learn which residue pairs are likely to be spatially proximal.


Discussion Questions¶

Q1. The EdgeInferenceNet is re-applied at each GNN layer. Why might it help to re-infer edges after each round of message passing rather than computing them once at the start?

Q2. Our only structural prior is sequence separation |i−j|/L. What other sequence-derived features could you add that don't require knowing the 3D structure?

Q3. Inspect the scatter plot of learned edge weight vs. true distance. Is there a negative correlation? What would it mean if there were no correlation?

Q4. The model uses a fully connected graph (all L×L pairs). For a real protein of length L=500, how many candidate edges would that be? What computational strategies could you use to make this tractable?

Q5. AlphaFold's Evoformer uses triangle attention — pair (i,j) aggregates from all pairs (i,k) and (k,j). How would you extend our EdgeInferenceNet to incorporate this triangular reasoning?

Answer to Q2¶

Mutual information between positions i and j across a multiple sequence alignment (MSA) — if two positions always mutate together across species, they are likely spatially close.
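As a concrete toy illustration of this idea, column-wise mutual information can be computed over a tiny hypothetical alignment. The MSA below and the helper function are illustrative inventions, not real data:

```python
import numpy as np
from collections import Counter

def column_mutual_information(msa, i, j):
    """Mutual information (in bits) between alignment columns i and j."""
    col_i = [seq[i] for seq in msa]
    col_j = [seq[j] for seq in msa]
    n = len(msa)
    p_i  = Counter(col_i)
    p_j  = Counter(col_j)
    p_ij = Counter(zip(col_i, col_j))
    mi = 0.0
    for (a, b), c in p_ij.items():
        pab = c / n
        mi += pab * np.log2(pab / ((p_i[a] / n) * (p_j[b] / n)))
    return mi

# Hypothetical toy MSA: columns 0 and 2 co-vary perfectly; column 1 varies independently
msa = ["AKD", "VKE", "ARD", "VRE", "AKD", "VKE"]
print(column_mutual_information(msa, 0, 2))  # → 1.0 (perfect co-variation)
print(column_mutual_information(msa, 0, 1))  # → 0.0 (independent)
```

High-MI column pairs are candidate spatial contacts; such co-evolution features could be appended to the seq_sep input of EdgeInferenceNet.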

Answer to Q4¶

Axial attention (what AlphaFold actually does). Rather than full attention over all L² pairs against all L² pairs (which would cost O(L⁴)), AlphaFold's Evoformer applies attention separately along the rows and columns of the pair matrix, costing roughly O(L³) per operation. This is what makes the Evoformer tractable for sequences of hundreds to a few thousand residues.
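The edge count from Q4 and the cost gap between full pair attention and axial attention can be checked directly (a back-of-envelope sketch counting score entries, ignoring constants and hidden dimensions):

```python
L = 500

# Fully connected candidate edges: undirected pairs, no self-loops
n_pairs = L * (L - 1) // 2
print(n_pairs)  # → 124750 candidate edges for L = 500

# Rough attention cost in number of attention scores computed:
full_pair_attention = (L ** 2) ** 2  # every pair attends to every pair: O(L^4)
axial_attention     = 2 * L ** 3     # row-wise + column-wise passes: O(L^3)

print(f"full : {full_pair_attention:.2e}")  # → 6.25e+10
print(f"axial: {axial_attention:.2e}")      # → 2.50e+08, ~250x cheaper at L=500
```

Other tractability strategies include sparsifying the candidate graph (e.g. top-k edges by learned weight) or restricting pairs to a sequence-separation window during early layers.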

"If residue i is close to k, and k is close to j, then i is probably close to j."¶

In [ ]: