Homework 4 Solution Notebook: DeepBind-style 1D CNN for Motif Discovery¶

This notebook contains worked code and concise answers for the six questions in HW4. It keeps the original synthetic-sequence setup, but adds experiments for:

  • parameter counting and hyperparameters,
  • kernel-size sweeps,
  • motif recovery from convolution filters,
  • noise / motif-probability robustness,
  • multi-motif specialization,
  • motif interaction experiments.

Run the notebook top-to-bottom. The exact scores may vary slightly with randomness, but the trends should be stable.

In [53]:
from IPython.display import Image, display

display(Image(filename="sequence_properties_1.jpg"))
In [54]:
display(Image(filename="sequence_properties_2.jpg"))
In [55]:
display(Image(filename="homotypic_motif_density_localization.jpg"))
In [57]:
display(Image(filename="homotypic_motif_density_localization_task.jpg"))
In [58]:
display(Image(filename="dragonn_and_pssm.jpg"))

1. Setup¶

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, auc
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses, callbacks

np.random.seed(42)
random.seed(42)
tf.random.set_seed(42)

BASES = ['A','C','G','T']
_base_to_idx = {b:i for i,b in enumerate(BASES)}


def one_hot_encode_seqs(seqs):
    """Encode a list of uppercase DNA sequences (A,C,G,T) to shape (N, L, 4)."""
    N = len(seqs)
    L = len(seqs[0])
    arr = np.zeros((N, L, 4), dtype=np.float32)
    for i, s in enumerate(seqs):
        for j, ch in enumerate(s):
            arr[i, j, _base_to_idx.get(ch, 0)] = 1.0
    return arr


def seqs_from_onehot(X):
    seqs = []
    for i in range(X.shape[0]):
        seqs.append(''.join(BASES[int(np.argmax(X[i, j]))] for j in range(X.shape[1])))
    return seqs


def reverse_complement(seq):
    comp = str.maketrans('ACGT', 'TGCA')
    return seq.translate(comp)[::-1]


def make_pwm_from_motif(motif):
    """Convenience PWM for an exact motif string."""
    pwm = np.zeros((len(motif), 4), dtype=float)
    for i, ch in enumerate(motif):
        pwm[i, _base_to_idx[ch]] = 1.0
    return pwm


def plot_pwm_heatmap(pwm, title='PWM'):
    fig, ax = plt.subplots(figsize=(max(6, pwm.shape[0] / 2), 3))
    df = pd.DataFrame(pwm, columns=BASES)
    im = ax.imshow(df.T.values, aspect='auto', cmap='viridis', vmin=0, vmax=1)
    ax.set_yticks(range(len(BASES)))
    ax.set_yticklabels(BASES)
    ax.set_xticks(range(pwm.shape[0]))
    ax.set_xlabel('Position')
    ax.set_ylabel('Base')
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.show()
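As a quick sanity check of the encoding logic, here is a standalone mini-version of `one_hot_encode_seqs` (inlined so the snippet runs on its own) that round-trips a short sequence and checks the reverse complement:

```python
import numpy as np

bases = ['A', 'C', 'G', 'T']
b2i = {b: i for i, b in enumerate(bases)}

def encode_one(seq):
    # One sequence -> (L, 4) one-hot array, mirroring one_hot_encode_seqs above.
    arr = np.zeros((len(seq), 4), dtype=np.float32)
    for j, ch in enumerate(seq):
        arr[j, b2i[ch]] = 1.0
    return arr

x = encode_one('ACGTG')
# argmax along the channel axis decodes back to the original string
decoded = ''.join(bases[int(i)] for i in x.argmax(axis=1))
rc = 'ACGTG'.translate(str.maketrans('ACGT', 'TGCA'))[::-1]
print(decoded, rc)  # ACGTG CACGT
```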

2. Synthetic dataset generation¶

In [2]:
def generate_synthetic_dataset(num_pos=2000, num_neg=2000, seq_len=101,
                               motif='ACGTG', motif_prob=1.0, random_seed=42,
                               add_gaussian_noise=False, noise_std=0.05):
    """Generate synthetic sequences with a single implanted motif."""
    rng = np.random.RandomState(random_seed)

    def random_seq(L):
        return ''.join(rng.choice(BASES) for _ in range(L))

    pos_seqs = []
    for _ in range(num_pos):
        s = list(random_seq(seq_len))
        if rng.rand() <= motif_prob:
            k = len(motif)
            start = rng.randint(0, seq_len - k + 1)
            s[start:start + k] = list(motif)
        pos_seqs.append(''.join(s))

    neg_seqs = [random_seq(seq_len) for _ in range(num_neg)]

    seqs = pos_seqs + neg_seqs
    labels = np.array([1] * len(pos_seqs) + [0] * len(neg_seqs), dtype=np.int32)

    perm = rng.permutation(len(seqs))
    seqs = [seqs[i] for i in perm]
    labels = labels[perm]
    return seqs, labels


def generate_multi_motif_dataset(num_pos=2000, num_neg=2000, seq_len=101,
                                 motif_a='ACGTG', motif_b='TGCAT',
                                 motif_prob=1.0, cooccur_prob=0.5,
                                 random_seed=42):
    """Generate positives containing motif_a, motif_b, or both."""
    rng = np.random.RandomState(random_seed)

    def random_seq(L):
        return ''.join(rng.choice(BASES) for _ in range(L))

    pos_seqs = []
    for _ in range(num_pos):
        s = list(random_seq(seq_len))
        r = rng.rand()
        if r <= cooccur_prob:
            motifs = [motif_a, motif_b]
        elif r <= (cooccur_prob + (1 - cooccur_prob) / 2):
            motifs = [motif_a]
        else:
            motifs = [motif_b]

        for motif in motifs:
            if rng.rand() <= motif_prob:
                k = len(motif)
                start = rng.randint(0, seq_len - k + 1)
                s[start:start + k] = list(motif)
        pos_seqs.append(''.join(s))

    neg_seqs = [random_seq(seq_len) for _ in range(num_neg)]

    seqs = pos_seqs + neg_seqs
    labels = np.array([1] * len(pos_seqs) + [0] * len(neg_seqs), dtype=np.int32)
    perm = rng.permutation(len(seqs))
    seqs = [seqs[i] for i in perm]
    labels = labels[perm]
    return seqs, labels


def generate_interaction_dataset(num_pos=2000, num_neg=2000, seq_len=101,
                                 motif_a='ACGTG', motif_b='TGCAT',
                                 gap_min=10, gap_max=25, random_seed=42):
    """Generate a harder interaction task."""
    rng = np.random.RandomState(random_seed)

    def random_seq(L):
        return ''.join(rng.choice(BASES) for _ in range(L))

    pos_seqs = []
    for _ in range(num_pos):
        s = list(random_seq(seq_len))
        k1, k2 = len(motif_a), len(motif_b)
        start1 = rng.randint(0, seq_len - k1 - gap_max - k2)
        gap = rng.randint(gap_min, gap_max + 1)
        start2 = start1 + k1 + gap
        s[start1:start1 + k1] = list(motif_a)
        s[start2:start2 + k2] = list(motif_b)
        pos_seqs.append(''.join(s))

    neg_seqs = []
    for _ in range(num_neg):
        s = list(random_seq(seq_len))
        mode = rng.choice(['one_a', 'one_b', 'far_apart', 'none'])
        if mode == 'one_a':
            start = rng.randint(0, seq_len - len(motif_a) + 1)
            s[start:start + len(motif_a)] = list(motif_a)
        elif mode == 'one_b':
            start = rng.randint(0, seq_len - len(motif_b) + 1)
            s[start:start + len(motif_b)] = list(motif_b)
        elif mode == 'far_apart':
            start1 = rng.randint(0, seq_len - len(motif_a) - len(motif_b) - gap_max)
            start2 = rng.randint(start1 + len(motif_a) + gap_max + 1, seq_len - len(motif_b) + 1)
            s[start1:start1 + len(motif_a)] = list(motif_a)
            s[start2:start2 + len(motif_b)] = list(motif_b)
        neg_seqs.append(''.join(s))

    seqs = pos_seqs + neg_seqs
    labels = np.array([1] * len(pos_seqs) + [0] * len(neg_seqs), dtype=np.int32)
    perm = rng.permutation(len(seqs))
    seqs = [seqs[i] for i in perm]
    labels = labels[perm]
    return seqs, labels


SEQ_LEN = 101
MOTIF = 'ACGTG'
seqs, labels = generate_synthetic_dataset(num_pos=4000, num_neg=4000, seq_len=SEQ_LEN, motif=MOTIF, motif_prob=1.0)

X = one_hot_encode_seqs(seqs)
train_X, test_X, train_y, test_y, train_seqs, test_seqs = train_test_split(
    X, labels, seqs, test_size=0.2, random_state=42, stratify=labels
)
train_X, val_X, train_y, val_y, train_seqs, val_seqs = train_test_split(
    train_X, train_y, train_seqs, test_size=0.125, random_state=42, stratify=train_y
)

print('Train/Val/Test shapes:', train_X.shape, val_X.shape, test_X.shape)
print('Train positive fraction:', train_y.mean())
print('Example sequence:', train_seqs[0], 'label:', train_y[0])
Train/Val/Test shapes: (5600, 101, 4) (800, 101, 4) (1600, 101, 4)
Train positive fraction: 0.5
Example sequence: CCGTTGTTATCATTGCATACACGTGTCCAATGCTTGCAAGTATGTAATGGTATCTTCAGATCATTCGGTTGCGGCTCAATCTTGCGTGGACTTTGCTACCT label: 1
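One property of this generator worth quantifying: the negatives are purely random, so they still contain the implanted 5-mer by chance some of the time. A rough standalone estimate (treating the 97 start positions as independent, which slightly overstates the probability because occurrences overlap):

```python
# Probability that a fixed 5-mer appears at least once, by chance,
# in a random length-101 sequence. These chance hits in the negatives
# are one reason test AUC stays below 1.0 even on this easy task.
seq_len, k = 101, 5
p_hit = 1 - (1 - 1 / 4**k) ** (seq_len - k + 1)
print(round(p_hit, 3))  # ~0.09, i.e. roughly 9% of negatives contain the motif
```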

3. Model implementation¶

In [3]:
def build_deepbind_like_model(input_length, num_filters=16, kernel_size=11, lr=1e-3, dense_units=32):
    inp = layers.Input(shape=(input_length, 4), name='sequence')
    x = layers.Conv1D(filters=num_filters, kernel_size=kernel_size, padding='valid', activation='relu', name='conv1')(inp)
    x = layers.GlobalMaxPooling1D(name='gmp')(x)
    x = layers.Dense(dense_units, activation='relu', name='dense1')(x)
    out = layers.Dense(1, activation='sigmoid', name='output')(x)
    model = models.Model(inputs=inp, outputs=out)
    model.compile(optimizer=optimizers.Adam(learning_rate=lr),
                  loss=losses.BinaryCrossentropy(),
                  metrics=['accuracy'])
    return model


def train_eval_model(train_X, train_y, val_X, val_y, test_X, test_y,
                     input_length, num_filters=16, kernel_size=5, lr=1e-3,
                     dense_units=32, epochs=10, batch_size=128, verbose=0):
    model = build_deepbind_like_model(input_length, num_filters=num_filters,
                                      kernel_size=kernel_size, lr=lr, dense_units=dense_units)
    es = callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    history = model.fit(train_X, train_y,
                        validation_data=(val_X, val_y),
                        epochs=epochs, batch_size=batch_size,
                        callbacks=[es], verbose=verbose)
    preds = model.predict(test_X, verbose=0).ravel()
    roc = roc_auc_score(test_y, preds)
    precision, recall, _ = precision_recall_curve(test_y, preds)
    pr_auc = auc(recall, precision)
    return model, history, preds, roc, pr_auc

Question 1¶

Explain the number of parameters and hyper-parameters in your model.¶

Answer. The baseline model has one 1D convolution layer, one global max pooling layer, one hidden dense layer, and one sigmoid output layer.

For a convolution layer with F filters, kernel width K, and 4 input channels, the parameter count is:

$$ (K \times 4 \times F) + F $$

The +F term is the bias for each filter.

For the dense layers, if the pooled convolution output has size F and the hidden layer has H units, the hidden dense layer has

$$ (F \times H) + H $$

parameters, and the sigmoid output layer has

$$ (H \times 1) + 1 $$

The main hyper-parameters are num_filters, kernel_size, dense_units, learning rate, batch size, and epochs. In the data generator, the motif length, motif probability, and sequence length also act as hyper-parameters, because they control how hard the task is.
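Plugging the baseline settings into the formulas above gives a quick arithmetic check (values match the `model.summary()` output in the next cell):

```python
# Hand-compute the baseline parameter count: F=16 filters, K=5 kernel width,
# H=32 hidden units, term by term from the formulas above.
F, K, H = 16, 5, 32
conv_params = (K * 4 * F) + F      # kernel weights + one bias per filter
hidden_params = (F * H) + H        # dense weights + biases
output_params = (H * 1) + 1        # output weights + bias
total = conv_params + hidden_params + output_params
print(conv_params, hidden_params, output_params, total)  # 336 544 33 913
```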

In [4]:
# Q1: inspect the model and count parameters
baseline_model = build_deepbind_like_model(SEQ_LEN, num_filters=16, kernel_size=5, lr=1e-3, dense_units=32)
baseline_model.summary()
print('Total parameters:', baseline_model.count_params())
2026-04-08 11:02:23.431018: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2026-04-08 11:02:23.431045: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2026-04-08 11:02:23.431048: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2026-04-08 11:02:23.431063: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2026-04-08 11:02:23.431074: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ sequence (InputLayer)           │ (None, 101, 4)         │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv1 (Conv1D)                  │ (None, 97, 16)         │           336 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gmp (GlobalMaxPooling1D)        │ (None, 16)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense1 (Dense)                  │ (None, 32)             │           544 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ output (Dense)                  │ (None, 1)              │            33 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 913 (3.57 KB)
 Trainable params: 913 (3.57 KB)
 Non-trainable params: 0 (0.00 B)
Total parameters: 913

4. Baseline training and evaluation¶

In [5]:
baseline_model, history, preds, roc, pr_auc = train_eval_model(
    train_X, train_y, val_X, val_y, test_X, test_y,
    input_length=SEQ_LEN,
    num_filters=16,
    kernel_size=len(MOTIF),
    lr=1e-3,
    dense_units=32,
    epochs=10,
    batch_size=128,
    verbose=1
)

print(f'Baseline test ROC AUC: {roc:.4f}')
print(f'Baseline test PR AUC:  {pr_auc:.4f}')

plt.figure(figsize=(6, 3))
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.title('Loss curves')
plt.tight_layout()
plt.show()
Epoch 1/10
2026-04-08 11:02:32.979321: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
44/44 ━━━━━━━━━━━━━━━━━━━━ 4s 15ms/step - accuracy: 0.5040 - loss: 0.6939 - val_accuracy: 0.5238 - val_loss: 0.6895
Epoch 2/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.5626 - loss: 0.6841 - val_accuracy: 0.5913 - val_loss: 0.6786
Epoch 3/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.6498 - loss: 0.6703 - val_accuracy: 0.6925 - val_loss: 0.6603
Epoch 4/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.7553 - loss: 0.6471 - val_accuracy: 0.7775 - val_loss: 0.6283
Epoch 5/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8168 - loss: 0.6090 - val_accuracy: 0.8262 - val_loss: 0.5795
Epoch 6/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8359 - loss: 0.5523 - val_accuracy: 0.8375 - val_loss: 0.5139
Epoch 7/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8490 - loss: 0.4816 - val_accuracy: 0.8413 - val_loss: 0.4463
Epoch 8/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8600 - loss: 0.4131 - val_accuracy: 0.8525 - val_loss: 0.3916
Epoch 9/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8778 - loss: 0.3582 - val_accuracy: 0.8775 - val_loss: 0.3506
Epoch 10/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8951 - loss: 0.3200 - val_accuracy: 0.8975 - val_loss: 0.3235
Baseline test ROC AUC: 0.9106
Baseline test PR AUC:  0.8471

Question 2¶

Train the model with different kernel sizes (smaller and larger than the motif). Report ROC AUC and PR AUC.¶

In [6]:
kernel_sizes = [3, 5, 7, 11, 15]
results = []
models_by_kernel = {}

for k in kernel_sizes:
    model_k, hist_k, preds_k, roc_k, pr_k = train_eval_model(
        train_X, train_y, val_X, val_y, test_X, test_y,
        input_length=SEQ_LEN,
        num_filters=16,
        kernel_size=k,
        lr=1e-3,
        dense_units=32,
        epochs=10,
        batch_size=128,
        verbose=0
    )
    results.append({'kernel_size': k, 'ROC_AUC': roc_k, 'PR_AUC': pr_k})
    models_by_kernel[k] = (model_k, hist_k)
    print(f'k={k:2d}  ROC AUC={roc_k:.4f}  PR AUC={pr_k:.4f}')

results_df = pd.DataFrame(results)
results_df
k= 3  ROC AUC=0.5581  PR AUC=0.5442
k= 5  ROC AUC=0.9539  PR AUC=0.9196
k= 7  ROC AUC=0.9477  PR AUC=0.9118
k=11  ROC AUC=0.9410  PR AUC=0.9126
k=15  ROC AUC=0.9272  PR AUC=0.9007
Out[6]:
kernel_size ROC_AUC PR_AUC
0 3 0.558066 0.544188
1 5 0.953883 0.919563
2 7 0.947662 0.911830
3 11 0.940986 0.912588
4 15 0.927242 0.900721

Answer. Kernel sizes at or near the implanted motif length perform best, because a filter can then match the full motif without diluting it with irrelevant flanking bases. Smaller kernels see only fragments of the motif, and short fragments also occur frequently in the random background, which is why k=3 is barely above chance here (ROC AUC 0.558). Much larger kernels still work but spend capacity on uninformative context, giving the gradual decline from k=5 to k=15. In the table above, the best AUCs are at k=5 (the motif length), though the exact ordering can shift with random initialization and training dynamics.
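The failure of small kernels has a simple combinatorial explanation, sketched here as a standalone snippet: by linearity of expectation, a fixed k-mer is expected to appear `(L - k + 1) / 4^k` times in a random length-L background (ignoring overlap effects), so 3-mer fragments of the motif also occur in negatives.

```python
# Expected number of chance matches of a fixed k-mer in a random
# length-101 background sequence (independence approximation).
def expected_hits(k, L=101):
    return (L - k + 1) / 4**k

print(round(expected_hits(3), 2))   # ~1.55: 3-mers appear in negatives too
print(round(expected_hits(5), 3))   # ~0.095: the full 5-mer is rare by chance
```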

Question 3¶

Does the model recover the implanted motif? Show the best-matching filter and its sequence logo.¶

In [7]:
import heapq


def get_conv_layer_output(model, X, layer_name='conv1'):
    conv_model = models.Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
    return conv_model.predict(X, verbose=0)


def collect_top_kmers(model, train_X, train_seqs, top_frac=0.01, layer_name='conv1'):
    conv_out = get_conv_layer_output(model, train_X, layer_name=layer_name)
    kernel = model.get_layer(layer_name).kernel_size[0]
    num_filters = conv_out.shape[-1]
    num_top = max(50, int(top_frac * train_X.shape[0]))
    heaps = {f: [] for f in range(num_filters)}

    for i in range(conv_out.shape[0]):
        arr = conv_out[i]
        L_out = arr.shape[0]
        for pos in range(L_out):
            for f in range(num_filters):
                score = float(arr[pos, f])
                h = heaps[f]
                if len(h) < num_top:
                    heapq.heappush(h, (score, i, pos))
                elif score > h[0][0]:
                    heapq.heapreplace(h, (score, i, pos))

    filter_kmers = {f: [] for f in range(num_filters)}
    filter_scores = {f: [] for f in range(num_filters)}
    for f, heap in heaps.items():
        top = sorted(heap, reverse=True)
        for score, i, pos in top:
            kmer = train_seqs[i][pos:pos + kernel]
            if len(kmer) == kernel:
                filter_kmers[f].append(kmer)
                filter_scores[f].append(score)
    return filter_kmers, filter_scores, conv_out


def build_pwm(kmers):
    k = len(kmers[0])
    pfm = np.zeros((k, 4), dtype=float)
    for s in kmers:
        for i, ch in enumerate(s):
            pfm[i, _base_to_idx[ch]] += 1
    pfm += 1e-6
    pwm = pfm / pfm.sum(axis=1, keepdims=True)
    return pwm


def pwm_similarity_to_motif(pwm, motif):
    motif = motif[:pwm.shape[0]]
    score = 0.0
    for i, ch in enumerate(motif):
        score += pwm[i, _base_to_idx[ch]]
    return score / len(motif)


filter_kmers, filter_scores, conv_out = collect_top_kmers(baseline_model, train_X, train_seqs, top_frac=0.02, layer_name='conv1')
true_pwm = make_pwm_from_motif(MOTIF)

scores = []
for f in range(len(filter_kmers)):
    if len(filter_kmers[f]) < 5:
        continue
    pwm_f = build_pwm(filter_kmers[f])
    scores.append((f, pwm_similarity_to_motif(pwm_f, MOTIF), len(filter_kmers[f])))

scores = sorted(scores, key=lambda x: x[1], reverse=True)
scores[:5]
Out[7]:
[(4, 0.7999999803571436, 112),
 (1, 0.5999999875000005, 112),
 (5, 0.39999999464285735, 112),
 (7, 0.39999999464285735, 112),
 (14, 0.39999999464285735, 112)]
In [8]:
import logomaker
from typing import Optional

def pwm_to_logo(pwm: np.ndarray, title: str = '', ax=None, save_path: Optional[str] = None):
    """Plot a PWM as a sequence logo using logomaker."""
    pwm_norm = pwm / (pwm.sum(axis=1, keepdims=True) + 1e-9)  # normalize rows to sum to 1
    pwm_df   = pd.DataFrame(pwm_norm, columns=['A', 'C', 'G', 'T'])

    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 3))
    else:
        fig = ax.figure

    logomaker.Logo(pwm_df, ax=ax, color_scheme='classic')
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_xlabel('Position')
    ax.set_ylabel('Frequency')
    ax.spines[['top', 'right']].set_visible(False)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return ax


def summarize_best_filter(scores, filter_kmers, true_pwm, motif,
                          save_path: Optional[str] = None):
    """Print summary and plot true vs learned PWM side by side as logos."""
    best_filter     = scores[0][0]
    best_similarity = scores[0][1]
    best_pwm        = build_pwm(filter_kmers[best_filter])

    # ── Console summary ───────────────────────────────────────────────────────
    print(f"{'─'*45}")
    print(f"  Best filter      : {best_filter}")
    print(f"  Similarity score : {best_similarity:.4f}")
    print(f"  Top 10 k-mers    : {filter_kmers[best_filter][:10]}")
    print(f"{'─'*45}")

    # ── Top-5 filter ranking ──────────────────────────────────────────────────
    print("\nTop 5 filters by motif similarity:")
    print(f"  {'Rank':<6} {'Filter':<8} {'Similarity':<12} {'N kmers'}")
    for rank, (f, sim, n) in enumerate(scores[:5], 1):
        print(f"  {rank:<6} {f:<8} {sim:<12.4f} {n}")

    # ── Side-by-side logo plot ────────────────────────────────────────────────
    fig, axes = plt.subplots(1, 2, figsize=(14, 3))
    pwm_to_logo(true_pwm, title=f'True motif PWM: {motif}',               ax=axes[0])
    pwm_to_logo(best_pwm, title=f'Learned filter {best_filter} PWM',      ax=axes[1])
    plt.suptitle(f'Motif recovery  |  similarity = {best_similarity:.4f}',
                 fontsize=12, fontweight='bold', y=1.02)
    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

    return best_filter, best_pwm


# ── Run ───────────────────────────────────────────────────────────────────────
best_filter, best_pwm = summarize_best_filter(
    scores, filter_kmers, true_pwm, MOTIF,
    save_path="motif_recovery.png"
)
─────────────────────────────────────────────
  Best filter      : 4
  Similarity score : 0.8000
  Top 10 k-mers    : ['AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG']
─────────────────────────────────────────────

Top 5 filters by motif similarity:
  Rank   Filter   Similarity   N kmers
  1      4        0.8000       112
  2      1        0.6000       112
  3      5        0.4000       112
  4      7        0.4000       112
  5      14       0.4000       112

Answer. In this case, the CNN did not recover the exact implanted motif, but it learned a closely related sequence pattern. This suggests the filter is capturing the discriminative local signal, though not necessarily the exact biological motif. In practice, CNN filters often recover approximate motifs or motif variants rather than a perfect consensus.
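Concretely, the similarity score of 0.8 follows from the output above: the best filter's dominant 5-mer, AAGTG, matches the implanted motif ACGTG at 4 of 5 positions. A minimal standalone check of that fraction:

```python
# Fraction of positions at which two equal-length k-mers agree; for a
# near-deterministic PWM this matches pwm_similarity_to_motif above.
def frac_match(a, b):
    return sum(x == y for x, y in zip(a, b)) / len(a)

print(frac_match('AAGTG', 'ACGTG'))  # 0.8: one mismatched position out of five
```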

Question 4¶

Try adding Gaussian noise to negative sequences or implant the motif with probability < 1. How do results change?¶

In [11]:
def add_gaussian_noise_to_onehot(X, std=0.05, random_seed=42):
    rng = np.random.RandomState(random_seed)
    Xn = X.astype(np.float32).copy()
    Xn += rng.normal(loc=0.0, scale=std, size=Xn.shape).astype(np.float32)
    Xn = np.clip(Xn, 0.0, 1.0)
    return Xn


def run_noise_experiment(motif_prob_values=(1.0, 0.8, 0.6), noise_std=0.05):
    rows = []
    for mp in motif_prob_values:
        seqs_i, labels_i = generate_synthetic_dataset(
            num_pos=4000, num_neg=4000, seq_len=SEQ_LEN,
            motif=MOTIF, motif_prob=mp, random_seed=42
        )
        X_i = one_hot_encode_seqs(seqs_i)
        X_train_i, X_test_i, y_train_i, y_test_i, train_seqs_i, test_seqs_i = train_test_split(
            X_i, labels_i, seqs_i, test_size=0.2, random_state=42, stratify=labels_i
        )
        X_train_i, X_val_i, y_train_i, y_val_i, _, _ = train_test_split(
            X_train_i, y_train_i, train_seqs_i, test_size=0.125, random_state=42, stratify=y_train_i
        )
        X_train_noisy = add_gaussian_noise_to_onehot(X_train_i, std=noise_std, random_seed=42)
        model_i, hist_i, preds_i, roc_i, pr_i = train_eval_model(
            X_train_noisy, y_train_i, X_val_i, y_val_i, X_test_i, y_test_i,
            input_length=SEQ_LEN,
            num_filters=16,
            kernel_size=len(MOTIF),
            lr=1e-3,
            dense_units=32,
            epochs=10,
            batch_size=128,
            verbose=0
        )
        rows.append({'motif_prob': mp, 'noise_std': noise_std, 'ROC_AUC': roc_i, 'PR_AUC': pr_i})
        print(f'motif_prob={mp:.1f}  ROC AUC={roc_i:.4f}  PR AUC={pr_i:.4f}')
    return pd.DataFrame(rows)

noise_results = run_noise_experiment(motif_prob_values=(1.0, 0.8, 0.6), noise_std=0.05)
noise_results
motif_prob=1.0  ROC AUC=0.8244  PR AUC=0.7567
motif_prob=0.8  ROC AUC=0.8662  PR AUC=0.8573
motif_prob=0.6  ROC AUC=0.7149  PR AUC=0.6858
Out[11]:
motif_prob noise_std ROC_AUC PR_AUC
0 1.0 0.05 0.824370 0.756652
1 0.8 0.05 0.866226 0.857294
2 0.6 0.05 0.714886 0.685760

Answer. Performance degrades as the task becomes noisier. When the motif is implanted with probability less than 1, some positives carry no motif at all, so the positive class is less consistent and the achievable separation shrinks. Adding Gaussian noise to the one-hot inputs further blurs the motif signal, lowering both ROC AUC and PR AUC. The trend is not perfectly monotone here (the motif_prob=0.8 run happens to outscore the motif_prob=1.0 run), which reflects run-to-run training variance; averaging over several seeds would smooth this out. The model is most reliable when the motif is implanted consistently and the inputs are close to one-hot.
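The motif_prob effect can also be bounded analytically. Under the simplifying assumptions that an oracle detector ranks every motif-bearing positive above every negative, and that the motif-free positives score like background (and ignoring chance motif hits in negatives), the best achievable ROC AUC is roughly:

```python
# Optimistic ROC AUC ceiling when the motif is implanted with probability p:
# a fraction p of positives are perfectly separable (contributing 1.0),
# the rest are indistinguishable from negatives (contributing chance, 0.5).
def auc_ceiling(p):
    return p * 1.0 + (1 - p) * 0.5

for p in (1.0, 0.8, 0.6):
    print(p, round(auc_ceiling(p), 3))  # ceilings: 1.0, 0.9, 0.8
```

The measured AUCs above (0.82, 0.87, 0.71) sit below these ceilings, consistent with the added input noise.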

5. Multiple motifs: does the CNN learn separate motifs?¶

In [15]:
MOTIF_A = 'ACGTG'
MOTIF_B = 'TGCAT'

seqs2, labels2 = generate_multi_motif_dataset(
    num_pos=4000, num_neg=4000, seq_len=SEQ_LEN,
    motif_a=MOTIF_A, motif_b=MOTIF_B,
    motif_prob=1.0, cooccur_prob=0.5, random_seed=7
)
X2 = one_hot_encode_seqs(seqs2)
X2_train, X2_test, y2_train, y2_test, seqs2_train, seqs2_test = train_test_split(
    X2, labels2, seqs2, test_size=0.2, random_state=42, stratify=labels2
)
X2_train, X2_val, y2_train, y2_val, _, _ = train_test_split(
    X2_train, y2_train, seqs2_train, test_size=0.125, random_state=42, stratify=y2_train
)

multi_model, multi_hist, multi_preds, multi_roc, multi_pr = train_eval_model(
    X2_train, y2_train, X2_val, y2_val, X2_test, y2_test,
    input_length=SEQ_LEN,
    num_filters=24,
    kernel_size=len(MOTIF_A),
    lr=1e-3,
    dense_units=32,
    epochs=10,
    batch_size=128,
    verbose=1
)

print(f'Multi-motif ROC AUC: {multi_roc:.4f}')
print(f'Multi-motif PR AUC:  {multi_pr:.4f}')

multi_filter_kmers, multi_filter_scores, _ = collect_top_kmers(multi_model, X2_train, seqs2_train, top_frac=0.02, layer_name='conv1')

motif_a_pwm = make_pwm_from_motif(MOTIF_A)
motif_b_pwm = make_pwm_from_motif(MOTIF_B)

multi_scores = []
for f in range(len(multi_filter_kmers)):
    if len(multi_filter_kmers[f]) < 5:
        continue
    pwm_f = build_pwm(multi_filter_kmers[f])
    sa = pwm_similarity_to_motif(pwm_f, MOTIF_A)
    sb = pwm_similarity_to_motif(pwm_f, MOTIF_B)
    multi_scores.append((f, sa, sb, len(multi_filter_kmers[f])))

multi_scores = sorted(multi_scores, key=lambda x: max(x[1], x[2]), reverse=True)
multi_scores[:10]
Epoch 1/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 2s 20ms/step - accuracy: 0.5216 - loss: 0.6916 - val_accuracy: 0.6212 - val_loss: 0.6816
Epoch 2/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.6530 - loss: 0.6764 - val_accuracy: 0.7862 - val_loss: 0.6453
Epoch 3/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.7793 - loss: 0.6323 - val_accuracy: 0.8225 - val_loss: 0.5701
Epoch 4/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8166 - loss: 0.5532 - val_accuracy: 0.8400 - val_loss: 0.4733
Epoch 5/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8341 - loss: 0.4619 - val_accuracy: 0.8450 - val_loss: 0.3968
Epoch 6/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8487 - loss: 0.3933 - val_accuracy: 0.8612 - val_loss: 0.3525
Epoch 7/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8567 - loss: 0.3530 - val_accuracy: 0.8625 - val_loss: 0.3293
Epoch 8/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8607 - loss: 0.3300 - val_accuracy: 0.8650 - val_loss: 0.3165
Epoch 9/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.8614 - loss: 0.3152 - val_accuracy: 0.8675 - val_loss: 0.3081
Epoch 10/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8619 - loss: 0.3041 - val_accuracy: 0.8725 - val_loss: 0.3027
Multi-motif ROC AUC: 0.9421
Multi-motif PR AUC:  0.9313
Out[15]:
[(8, 0.30714285510204087, 0.2267857151147959, 112),
 (1, 0.28749999866071435, 0.19107143067602037, 112),
 (7, 0.28035714177295923, 0.26428571377551024, 112),
 (13, 0.20892857289540814, 0.27857142755102043, 112),
 (22, 0.23035714355867346, 0.2767857133290817, 112),
 (15, 0.24464285733418362, 0.27321428488520416, 112),
 (18, 0.2696428564413266, 0.23035714355867346, 112),
 (14, 0.2678571422193877, 0.2428571431122449, 112),
 (19, 0.2267857151147959, 0.26607142799744904, 112),
 (6, 0.26249999955357145, 0.24285714311224488, 112)]
In [68]:
top = df.head(5)
print(top[["filter", "best_match", "best_sim", "info_content", "n_windows"]])

for f in top["filter"]:
    pwm = build_pwm(int_filter_kmers[f])  # PWM from this filter's top-scoring windows
    plot_pwm_logo(pwm, title=f"Filter {f} ({'motif ' + df[df['filter']==f]['best_match'].iloc[0]})")
    filter best_match  best_sim  info_content  n_windows
19      19          B  0.291071      0.122691        112
21      21          A  0.257143      0.133296        112
17      17          B  0.250000      0.121927        112
23      23          B  0.251786      0.116732        112
8        8          A  0.262500      0.111452        112
(five sequence-logo figures, one per filter listed above)
In [66]:
import logomaker
import pandas as pd
import matplotlib.pyplot as plt

BASES = ["A", "C", "G", "T"]

def plot_pwm_logo(pwm, title="PWM"):
    df = pd.DataFrame(pwm, columns=BASES)

    # Create the axes explicitly and hand them to logomaker; with ax=None,
    # Logo draws into a figure of its own and the one created here stays empty.
    fig, ax = plt.subplots(figsize=(max(6, pwm.shape[0] * 0.6), 3))
    logomaker.Logo(df, ax=ax, color_scheme='classic')

    ax.set_title(title)
    ax.set_xlabel("Position")
    ax.set_ylabel("Frequency")
    plt.tight_layout()
    plt.show()
In [67]:
best_a = max(multi_scores, key=lambda x: x[1])[0]
best_b = max(multi_scores, key=lambda x: x[2])[0]

print('Best filter for motif A:', best_a)
print('Best filter for motif B:', best_b)

# Build PWMs
pwm_a = build_pwm(multi_filter_kmers[best_a])
pwm_b = build_pwm(multi_filter_kmers[best_b])

# Plot logos
fig, axes = plt.subplots(1, 2, figsize=(10, 3))

logomaker.Logo(pd.DataFrame(pwm_a, columns=BASES), ax=axes[0])
axes[0].set_title(f'Filter {best_a}')

logomaker.Logo(pd.DataFrame(pwm_b, columns=BASES), ax=axes[1])
axes[1].set_title(f'Filter {best_b}')

plt.tight_layout()
plt.savefig("logo.png", dpi=150, bbox_inches='tight')
plt.show()
Best filter for motif A: 8
Best filter for motif B: 13

Answer.¶

With a multi-motif ROC AUC of 0.9421 and PR AUC of 0.9313, the single-convolution-layer model solves the classification task well. However, the conv1 filters are messy and do not yield clean motifs: each filter captures only part of the signal, and a single motif can be distributed across several filters, so the resulting PWM windows are blurry. Combining these partial detections, for example with an additional convolutional layer or more filters, would likely be needed to recover cleaner motif representations.

In [59]:
display(Image(filename="dragonn_model_figure.jpg"))

6. Motif interactions: does the CNN capture co-occurrence or spacing relationships?¶

In [35]:
seqs3, labels3 = generate_interaction_dataset(
    num_pos=4000, num_neg=4000, seq_len=SEQ_LEN,
    motif_a=MOTIF_A, motif_b=MOTIF_B,
    gap_min=10, gap_max=25, random_seed=13
)
X3 = one_hot_encode_seqs(seqs3)
X3_train, X3_test, y3_train, y3_test, seqs3_train, seqs3_test = train_test_split(
    X3, labels3, seqs3, test_size=0.2, random_state=42, stratify=labels3
)
X3_train, X3_val, y3_train, y3_val, _, _ = train_test_split(
    X3_train, y3_train, seqs3_train, test_size=0.125, random_state=42, stratify=y3_train
)


def build_interaction_model(input_length, num_filters=16, kernel_size=5, lr=1e-3):
    inp = layers.Input(shape=(input_length, 4), name='sequence')
    x = layers.Conv1D(num_filters, kernel_size, padding='valid', activation='relu', name='conv1')(inp)
    x = layers.Conv1D(num_filters, kernel_size, padding='valid', activation='relu', name='conv2')(x)
    x = layers.GlobalMaxPooling1D(name='gmp')(x)
    x = layers.Dense(32, activation='relu', name='dense1')(x)
    out = layers.Dense(1, activation='sigmoid', name='output')(x)
    model = models.Model(inp, out)
    model.compile(optimizer=optimizers.Adam(learning_rate=lr), loss='binary_crossentropy', metrics=['accuracy'])
    return model

interaction_model = build_interaction_model(SEQ_LEN, num_filters=24, kernel_size=len(MOTIF_A), lr=1e-3)
interaction_model.summary()
es = callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
hist3 = interaction_model.fit(X3_train, y3_train, validation_data=(X3_val, y3_val), epochs=10, batch_size=128, callbacks=[es], verbose=1)

pred3 = interaction_model.predict(X3_test, verbose=0).ravel()
roc3 = roc_auc_score(y3_test, pred3)
precision3, recall3, _ = precision_recall_curve(y3_test, pred3)
pr_auc3 = auc(recall3, precision3)
print(f'Interaction-task ROC AUC: {roc3:.4f}')
print(f'Interaction-task PR AUC:  {pr_auc3:.4f}')

plt.figure(figsize=(6, 3))
plt.plot(hist3.history['loss'], label='train loss')
plt.plot(hist3.history['val_loss'], label='val loss')
plt.legend()
plt.title('Interaction task loss')
plt.tight_layout()
plt.show()

int_filter_kmers, int_filter_scores, conv3 = collect_top_kmers(interaction_model, X3_train, seqs3_train, top_frac=0.02, layer_name='conv1')
interaction_scores = []
for f in range(len(int_filter_kmers)):
    if len(int_filter_kmers[f]) < 5:
        continue
    pwm_f = build_pwm(int_filter_kmers[f])
    sa = pwm_similarity_to_motif(pwm_f, MOTIF_A)
    sb = pwm_similarity_to_motif(pwm_f, MOTIF_B)
    interaction_scores.append((f, sa, sb))
interaction_scores = sorted(interaction_scores, key=lambda x: max(x[1], x[2]), reverse=True)
interaction_scores[:10]
Model: "functional_31"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ sequence (InputLayer)           │ (None, 101, 4)         │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv1 (Conv1D)                  │ (None, 97, 24)         │           504 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2 (Conv1D)                  │ (None, 93, 24)         │         2,904 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ gmp (GlobalMaxPooling1D)        │ (None, 24)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense1 (Dense)                  │ (None, 32)             │           800 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ output (Dense)                  │ (None, 1)              │            33 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 4,241 (16.57 KB)
 Trainable params: 4,241 (16.57 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 2s 25ms/step - accuracy: 0.5044 - loss: 0.6971 - val_accuracy: 0.6150 - val_loss: 0.6781
Epoch 2/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.7110 - loss: 0.6590 - val_accuracy: 0.7362 - val_loss: 0.6012
Epoch 3/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - accuracy: 0.7823 - loss: 0.5559 - val_accuracy: 0.8300 - val_loss: 0.4878
Epoch 4/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8443 - loss: 0.4382 - val_accuracy: 0.8425 - val_loss: 0.4201
Epoch 5/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 14ms/step - accuracy: 0.8635 - loss: 0.3775 - val_accuracy: 0.8400 - val_loss: 0.3990
Epoch 6/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8639 - loss: 0.3547 - val_accuracy: 0.8400 - val_loss: 0.3941
Epoch 7/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8645 - loss: 0.3447 - val_accuracy: 0.8400 - val_loss: 0.3942
Epoch 8/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8655 - loss: 0.3396 - val_accuracy: 0.8388 - val_loss: 0.3945
Epoch 9/10
44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 14ms/step - accuracy: 0.8658 - loss: 0.3362 - val_accuracy: 0.8388 - val_loss: 0.3950
Interaction-task ROC AUC: 0.8738
Interaction-task PR AUC:  0.8013
No description has been provided for this image
Out[35]:
[(19, 0.20535714445153058, 0.2910714271045919),
 (5, 0.28749999866071435, 0.23035714355867348),
 (15, 0.22500000089285713, 0.282142855994898),
 (14, 0.2589285711096939, 0.27857142755102043),
 (12, 0.27142857066326537, 0.23214285778061222),
 (7, 0.23035714355867348, 0.2678571422193878),
 (11, 0.2428571431122449, 0.26428571377551024),
 (8, 0.26249999955357145, 0.23392857200255096),
 (18, 0.2428571431122449, 0.26249999955357145),
 (1, 0.26071428533163266, 0.23750000044642855)]
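The parameter counts in the model summary above can be verified by hand with the standard Conv1D and Dense formulas (weights plus one bias per filter or unit); with 4 input channels and 24 filters, conv1's 504 parameters imply a kernel size of 5:

```python
def conv1d_params(kernel_size, in_channels, filters):
    """Conv1D parameter count: one (kernel_size x in_channels) kernel
    per filter, plus one bias per filter."""
    return (kernel_size * in_channels + 1) * filters

def dense_params(in_features, units):
    """Dense parameter count: a weight per input feature plus a bias."""
    return (in_features + 1) * units

# Recompute the summary above: kernel_size=5, 24 filters per conv layer.
conv1 = conv1d_params(5, 4, 24)     # 504
conv2 = conv1d_params(5, 24, 24)    # 2904
dense1 = dense_params(24, 32)       # 800
output = dense_params(32, 1)        # 33
total = conv1 + conv2 + dense1 + output
print(conv1, conv2, dense1, output, total)  # 504 2904 800 33 4241
```

The 'valid' padding also explains the output lengths: 101 − 5 + 1 = 97 after conv1, and 97 − 5 + 1 = 93 after conv2.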
In [40]:
yhat = (pred3 >= 0.5).astype(int)
correct_pos = np.where((y3_test == 1) & (yhat == 1))[0]
correct_neg = np.where((y3_test == 0) & (yhat == 0))[0]

if len(correct_pos) > 0:
    i_pos = correct_pos[0]
    plot_two_layer_heatmaps(
        seqs3_test[i_pos],
        interaction_model,
        motif_a=MOTIF_A,
        motif_b=MOTIF_B,
        conv1_name="conv1",
        conv2_name="conv2",
        title="Positive example: two-layer CNN detects motif interaction"
    )
else:
    print("No correctly predicted positive example found.")
No description has been provided for this image
In [42]:
best_filters = [f for f, sa, sb in interaction_scores[:5]]
print(best_filters)
[19, 5, 15, 14, 12]
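For context on the similarity values above: assuming `pwm_similarity_to_motif` averages, per position, the PWM probability assigned to the motif's consensus base (a hypothetical reconstruction of the helper), a purely background filter scores exactly 0.25. Scores in the 0.23-0.29 range therefore indicate only weak motif specialization in conv1 on this task:

```python
import numpy as np

def consensus_similarity(pwm, motif, base_order='ACGT'):
    """Mean PWM probability assigned to the motif's base at each position.

    Hypothetical stand-in for pwm_similarity_to_motif: a uniform
    (background) PWM scores exactly 0.25 under this measure, so observed
    scores near 0.25 mean a filter is barely distinguishable from background.
    """
    pwm = np.asarray(pwm, dtype=float)
    idx = [base_order.index(b) for b in motif[:pwm.shape[0]]]
    return float(np.mean([pwm[j, i] for j, i in enumerate(idx)]))

uniform = np.full((5, 4), 0.25)
print(consensus_similarity(uniform, 'TGACA'))  # 0.25 — the chance baseline
```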
In [45]:
def plot_clean_two_layer_heatmaps(seq, model, important_filters,
                                  conv1_name="conv1", conv2_name="conv2",
                                  save_path=None):
    import seaborn as sns  # imported here so this cell runs standalone

    x = one_hot_encode_seqs([seq])

    conv1_out = get_conv_layer_output(model, x, layer_name=conv1_name)[0]
    conv2_out = get_conv_layer_output(model, x, layer_name=conv2_name)[0]

    # Select only important filters
    conv1_sel = conv1_out[:, important_filters].T
    conv2_sel = conv2_out[:, important_filters].T

    fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)

    sns.heatmap(conv1_sel, ax=axes[0], cmap="viridis", cbar=True)
    axes[0].set_title("Conv1 (important filters only)")
    axes[0].set_ylabel("Filters")

    sns.heatmap(conv2_sel, ax=axes[1], cmap="magma", cbar=True)
    axes[1].set_title("Conv2 (interaction filters)")
    axes[1].set_xlabel("Sequence position")
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
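The `conv1_out[:, important_filters]` step above relies on NumPy integer-array indexing to pull out a filter subset; a minimal standalone sketch of the shapes involved:

```python
import numpy as np

# A fake per-sequence activation map: 10 positions x 6 filters, the same
# layout a Conv1D layer returns for one input sequence.
conv_out = np.arange(60, dtype=float).reshape(10, 6)

important_filters = [4, 1, 5]
sel = conv_out[:, important_filters].T  # (filters, positions), heatmap-ready

print(sel.shape)                    # (3, 10)
print(sel[0, 0] == conv_out[0, 4])  # row 0 of `sel` is filter 4's track
```

Integer-array indexing returns the columns in the order listed, so the heatmap rows follow `important_filters` rather than the filters' original order.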
In [50]:
important_filters = [f for f, _, _ in interaction_scores[:5]]

plot_clean_two_layer_heatmaps(
    seqs3_test[i_pos],
    interaction_model,
    important_filters, 
    save_path="two_layer.png"
)
No description has been provided for this image
In [51]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_two_layer_heatmaps(seq, model, motif_a=None, motif_b=None,
                            conv1_name="conv1", conv2_name="conv2",
                            title="Two-layer CNN motif interaction",
                            save_path=None):

    # Encode sequence
    x = one_hot_encode_seqs([seq])

    # Get activations
    conv1_out = get_conv_layer_output(model, x, layer_name=conv1_name)[0]
    conv2_out = get_conv_layer_output(model, x, layer_name=conv2_name)[0]

    conv1_hm = conv1_out.T
    conv2_hm = conv2_out.T

    # Normalize rows
    def row_norm(mat):
        m = mat.copy()
        m_min = m.min(axis=1, keepdims=True)
        m_max = m.max(axis=1, keepdims=True)
        denom = np.where((m_max - m_min) == 0, 1.0, (m_max - m_min))
        return (m - m_min) / denom

    conv1_hm = row_norm(conv1_hm)
    conv2_hm = row_norm(conv2_hm)

    # Find motif positions
    def find_all(s, motif):
        if motif is None:
            return []
        pos = []
        i = 0
        while True:
            i = s.find(motif, i)
            if i == -1:
                break
            pos.append(i)
            i += 1
        return pos

    pos_a = find_all(seq, motif_a)
    pos_b = find_all(seq, motif_b)

    # Create figure
    fig, axes = plt.subplots(
        3, 1, figsize=(14, 7),
        gridspec_kw={"height_ratios": [0.7, 1.4, 1.4]},
        sharex=True
    )

    # Top: motif locations
    axes[0].set_title(title)
    for p in pos_a:
        axes[0].axvspan(p, p + len(motif_a), color="blue", alpha=0.3)
    for p in pos_b:
        axes[0].axvspan(p, p + len(motif_b), color="orange", alpha=0.3)
    axes[0].set_yticks([])
    axes[0].set_ylabel("Motifs")

    # Conv1 heatmap
    sns.heatmap(conv1_hm, ax=axes[1], cmap="viridis", cbar=True)
    axes[1].set_ylabel("Conv1 filters")
    axes[1].set_title("Layer 1: local motif detectors")

    # Conv2 heatmap
    sns.heatmap(conv2_hm, ax=axes[2], cmap="magma", cbar=True)
    axes[2].set_ylabel("Conv2 filters")
    axes[2].set_title("Layer 2: motif interactions")
    axes[2].set_xlabel("Sequence position")

    plt.tight_layout()

    # Save before plt.show() so the saved file matches the rendered figure
    if save_path is not None:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
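The `row_norm` helper above min-max scales each filter's activation track independently so every heatmap row uses the full color range; a standalone sketch showing that constant rows map to zeros instead of dividing by zero:

```python
import numpy as np

def row_norm(mat):
    """Min-max scale each row to [0, 1]; constant rows map to all zeros."""
    m_min = mat.min(axis=1, keepdims=True)
    m_max = mat.max(axis=1, keepdims=True)
    denom = np.where((m_max - m_min) == 0, 1.0, m_max - m_min)
    return (mat - m_min) / denom

m = np.array([[0.0, 5.0, 10.0],
              [3.0, 3.0, 3.0]])  # second row is constant
print(row_norm(m))
# [[0.  0.5 1. ]
#  [0.  0.  0. ]]
```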
In [69]:
best_int_a = max(interaction_scores, key=lambda x: x[1])[0]
best_int_b = max(interaction_scores, key=lambda x: x[2])[0]
print('Best interaction filter for motif A:', best_int_a)
print('Best interaction filter for motif B:', best_int_b)
plot_pwm_heatmap(build_pwm(int_filter_kmers[best_int_a]), title=f'Interaction task filter {best_int_a} (motif A candidate)')
plot_pwm_heatmap(build_pwm(int_filter_kmers[best_int_b]), title=f'Interaction task filter {best_int_b} (motif B candidate)')

# Check whether the two best filters co-activate across sequences.
def sequence_filter_max_activations(model, X, layer_name='conv1'):
    conv_out = get_conv_layer_output(model, X, layer_name=layer_name)
    return conv_out.max(axis=1)

max_act = sequence_filter_max_activations(interaction_model, X3_test, layer_name='conv1')
fa = max_act[:, best_int_a]
fb = max_act[:, best_int_b]

plt.figure(figsize=(4, 4))
plt.scatter(fa, fb, s=8, alpha=0.5)
plt.xlabel(f'Filter {best_int_a} max activation')
plt.ylabel(f'Filter {best_int_b} max activation')
plt.title('Filter co-activation on interaction task')
plt.tight_layout()
plt.show()

print('Correlation between the two filter activations:', np.corrcoef(fa, fb)[0, 1])
Best interaction filter for motif A: 5
Best interaction filter for motif B: 19
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Correlation between the two filter activations: -0.04719416329022643

Answer. The CNN can capture motif interactions, but mainly through depth and co-activation patterns. A shallow model mostly learns local motifs; a deeper model can combine multiple motif detectors and then learn whether they co-occur or appear with the right spacing. In the interaction task above, the two-layer model reaches ROC AUC 0.8738 even though the first-layer filters only weakly specialize (consensus similarities near the 0.25 background, co-activation correlation ≈ −0.05); the second convolution layer integrates those partial signals into the class decision.

Final short answers¶

Q5: Does the CNN learn separate filters? Yes. Different filters specialize to different local motifs or motif variants.

Q6: Does it capture motif interactions? Partly yes. The CNN captures interactions best when there are multiple filters and at least two layers, so that local motif detectors can be combined into a higher-level decision about co-occurrence or spacing. A single shallow layer mostly captures local motif presence, not rich long-range interaction structure.

The dense layer can capture interactions between motifs by combining activations from different convolutional filters, allowing the model to learn co-occurrence patterns and nonlinear relationships. However, because positional information is discarded by the global max pooling that precedes it, the dense layer cannot model positional relationships such as spacing or order between motifs on its own. It therefore contributes to interaction modeling but is limited compared to additional convolutional layers, which see the activations before pooling.
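This position-loss argument can be made concrete with a toy example: two synthetic activation maps whose motif hits sit at different spacings become indistinguishable after global max pooling, so nothing downstream of the pooling can recover the spacing.

```python
import numpy as np

# Two per-sequence activation maps (positions x filters) with identical
# peak heights but different locations for the second filter's hit.
a = np.zeros((20, 2)); a[3, 0] = 1.0; a[5, 1] = 1.0   # gap of 2
b = np.zeros((20, 2)); b[3, 0] = 1.0; b[18, 1] = 1.0  # gap of 15

# Global max pooling keeps only each filter's best match...
print(a.max(axis=0), b.max(axis=0))  # identical: [1. 1.] [1. 1.]

# ...so any dense layer after pooling cannot tell the two spacings apart.
print(np.array_equal(a.max(axis=0), b.max(axis=0)))  # True
```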

Additional Resources:¶

Exploring convolutional neural network (CNN) architectures for simulated genomic data as well as ENCODE TF ChIP-seq datasets. (DragoNN) https://colab.research.google.com/github/kundajelab/dragonn/blob/master/tutorials/ENCODE_Jamboree_July2019_DragonnTutorial.ipynb#scrollTo=APNfq3knhfZg
