Homework 4 Solution Notebook: DeepBind-style 1D CNN for Motif Discovery¶
This notebook contains worked code and concise answers for the six questions in HW4. It keeps the original synthetic-sequence setup, but adds experiments for:
- parameter counting and hyperparameters,
- kernel-size sweeps,
- motif recovery from convolution filters,
- noise / motif-probability robustness,
- multi-motif specialization,
- motif interaction experiments.
Run the notebook top-to-bottom. The exact scores may vary slightly with randomness, but the trends should be stable.
# Display the assignment's reference figures inline. The .jpg files must sit
# next to this notebook; Image() raises FileNotFoundError otherwise.
from IPython.display import Image, display
display(Image(filename="sequence_properties_1.jpg"))
display(Image(filename="sequence_properties_2.jpg"))
display(Image(filename="homotypic_motif_density_localization.jpg"))
display(Image(filename="homotypic_motif_density_localization_task.jpg"))
display(Image(filename="dragonn_and_pssm.jpg"))
1. Setup¶
import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, auc
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses, callbacks
np.random.seed(42)
random.seed(42)
tf.random.set_seed(42)
BASES = ['A','C','G','T']
_base_to_idx = {b:i for i,b in enumerate(BASES)}
def one_hot_encode_seqs(seqs):
    """Encode a list of uppercase DNA sequences (A,C,G,T) to shape (N, L, 4).

    Parameters
    ----------
    seqs : list[str]
        Equal-length uppercase DNA strings. Characters outside ACGT fall
        back to the 'A' channel (index 0), matching the original behavior.

    Returns
    -------
    np.ndarray of shape (N, L, 4), dtype float32. An empty input now
    returns a (0, 0, 4) array instead of raising IndexError on seqs[0].
    """
    if not seqs:
        return np.zeros((0, 0, 4), dtype=np.float32)
    N = len(seqs)
    L = len(seqs[0])
    arr = np.zeros((N, L, 4), dtype=np.float32)
    for i, s in enumerate(seqs):
        for j, ch in enumerate(s):
            # Unknown characters map silently to channel 0 ('A').
            arr[i, j, _base_to_idx.get(ch, 0)] = 1.0
    return arr
def seqs_from_onehot(X):
    """Decode a one-hot array of shape (N, L, 4) back into DNA strings.

    Each position is decoded as the argmax channel, so ties resolve to the
    lowest index (BASES order A, C, G, T).
    """
    decoded = []
    n_seqs, seq_len = X.shape[0], X.shape[1]
    for row in range(n_seqs):
        letters = [BASES[int(np.argmax(X[row, col]))] for col in range(seq_len)]
        decoded.append(''.join(letters))
    return decoded
def reverse_complement(seq):
    """Return the reverse complement of an uppercase DNA string."""
    complement_table = str.maketrans('ACGT', 'TGCA')
    # Reversing first and complementing second is equivalent to the
    # complement-then-reverse order.
    return seq[::-1].translate(complement_table)
def make_pwm_from_motif(motif):
    """Convenience PWM for an exact motif string.

    Returns a (len(motif), 4) float array with probability 1.0 on the
    motif's base at each position and 0 elsewhere.
    """
    pwm = np.zeros((len(motif), 4), dtype=float)
    row_idx = np.arange(len(motif))
    col_idx = [_base_to_idx[base] for base in motif]
    pwm[row_idx, col_idx] = 1.0
    return pwm
def plot_pwm_heatmap(pwm, title='PWM'):
    """Render a PWM as a base-by-position heatmap.

    Parameters
    ----------
    pwm : array of shape (width, 4), columns in A,C,G,T order, values in [0, 1].
    title : str, figure title.

    Displays the figure with plt.show(); returns nothing.
    """
    # Width scales with PWM length so long motifs stay readable.
    fig, ax = plt.subplots(figsize=(max(6, pwm.shape[0] / 2), 3))
    df = pd.DataFrame(pwm, columns=BASES)
    # Transpose so bases run along the y-axis and positions along the x-axis.
    im = ax.imshow(df.T.values, aspect='auto', cmap='viridis', vmin=0, vmax=1)
    ax.set_yticks(range(len(BASES)))
    ax.set_yticklabels(BASES)
    ax.set_xticks(range(pwm.shape[0]))
    ax.set_xlabel('Position')
    ax.set_ylabel('Base')
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.show()
2. Synthetic dataset generation¶
def generate_synthetic_dataset(num_pos=2000, num_neg=2000, seq_len=101,
                               motif='ACGTG', motif_prob=1.0, random_seed=42,
                               add_gaussian_noise=False, noise_std=0.05):
    """Generate synthetic sequences with a single implanted motif.

    Positives get `motif` implanted at a uniform random position with
    probability `motif_prob`; negatives are pure random background.
    Returns (seqs, labels) shuffled together with a fixed-seed RNG, so
    results are reproducible for a given `random_seed`.

    NOTE(review): `add_gaussian_noise` and `noise_std` are accepted but
    never used here — noise is applied separately via
    add_gaussian_noise_to_onehot later in the notebook.
    """
    rng = np.random.RandomState(random_seed)

    def random_seq(L):
        # i.i.d. uniform background over A/C/G/T.
        return ''.join(rng.choice(BASES) for _ in range(L))

    pos_seqs = []
    for _ in range(num_pos):
        s = list(random_seq(seq_len))
        if rng.rand() <= motif_prob:
            k = len(motif)
            # randint's upper bound is exclusive, so the motif always fits.
            start = rng.randint(0, seq_len - k + 1)
            s[start:start + k] = list(motif)
        pos_seqs.append(''.join(s))
    neg_seqs = [random_seq(seq_len) for _ in range(num_neg)]
    seqs = pos_seqs + neg_seqs
    labels = np.array([1] * len(pos_seqs) + [0] * len(neg_seqs), dtype=np.int32)
    # Shuffle sequences and labels with the same permutation.
    perm = rng.permutation(len(seqs))
    seqs = [seqs[i] for i in perm]
    labels = labels[perm]
    return seqs, labels
def generate_multi_motif_dataset(num_pos=2000, num_neg=2000, seq_len=101,
                                 motif_a='ACGTG', motif_b='TGCAT',
                                 motif_prob=1.0, cooccur_prob=0.5,
                                 random_seed=42):
    """Generate positives containing motif_a, motif_b, or both.

    A positive gets both motifs with probability `cooccur_prob`; otherwise
    exactly one, chosen 50/50. Each selected motif is then actually
    implanted only with probability `motif_prob`. Negatives are random
    background. Returns (seqs, labels) shuffled together.

    NOTE(review): implant positions are drawn independently, so when both
    motifs are placed the second can overwrite part of the first.
    """
    rng = np.random.RandomState(random_seed)

    def random_seq(L):
        return ''.join(rng.choice(BASES) for _ in range(L))

    pos_seqs = []
    for _ in range(num_pos):
        s = list(random_seq(seq_len))
        r = rng.rand()
        # Single draw decides the motif combination for this sequence.
        if r <= cooccur_prob:
            motifs = [motif_a, motif_b]
        elif r <= (cooccur_prob + (1 - cooccur_prob) / 2):
            motifs = [motif_a]
        else:
            motifs = [motif_b]
        for motif in motifs:
            if rng.rand() <= motif_prob:
                k = len(motif)
                start = rng.randint(0, seq_len - k + 1)
                s[start:start + k] = list(motif)
        pos_seqs.append(''.join(s))
    neg_seqs = [random_seq(seq_len) for _ in range(num_neg)]
    seqs = pos_seqs + neg_seqs
    labels = np.array([1] * len(pos_seqs) + [0] * len(neg_seqs), dtype=np.int32)
    # Shuffle sequences and labels with the same permutation.
    perm = rng.permutation(len(seqs))
    seqs = [seqs[i] for i in perm]
    labels = labels[perm]
    return seqs, labels
def generate_interaction_dataset(num_pos=2000, num_neg=2000, seq_len=101,
                                 motif_a='ACGTG', motif_b='TGCAT',
                                 gap_min=10, gap_max=25, random_seed=42):
    """Generate a harder interaction task.

    Positives always contain motif_a followed by motif_b separated by a
    uniform gap in [gap_min, gap_max]. Negatives are one of four modes
    chosen uniformly: only motif_a, only motif_b, both motifs spaced
    strictly farther than gap_max, or no motif at all.
    Returns (seqs, labels) shuffled together.
    """
    rng = np.random.RandomState(random_seed)

    def random_seq(L):
        return ''.join(rng.choice(BASES) for _ in range(L))

    pos_seqs = []
    for _ in range(num_pos):
        s = list(random_seq(seq_len))
        k1, k2 = len(motif_a), len(motif_b)
        # randint's upper bound is exclusive, so the a-gap-b block always
        # fits inside the sequence even at the maximum gap.
        start1 = rng.randint(0, seq_len - k1 - gap_max - k2)
        gap = rng.randint(gap_min, gap_max + 1)
        start2 = start1 + k1 + gap
        s[start1:start1 + k1] = list(motif_a)
        s[start2:start2 + k2] = list(motif_b)
        pos_seqs.append(''.join(s))
    neg_seqs = []
    for _ in range(num_neg):
        s = list(random_seq(seq_len))
        mode = rng.choice(['one_a', 'one_b', 'far_apart', 'none'])
        if mode == 'one_a':
            start = rng.randint(0, seq_len - len(motif_a) + 1)
            s[start:start + len(motif_a)] = list(motif_a)
        elif mode == 'one_b':
            start = rng.randint(0, seq_len - len(motif_b) + 1)
            s[start:start + len(motif_b)] = list(motif_b)
        elif mode == 'far_apart':
            # Force a spacing strictly greater than gap_max so the pair
            # violates the positive-class spacing constraint.
            start1 = rng.randint(0, seq_len - len(motif_a) - len(motif_b) - gap_max)
            start2 = rng.randint(start1 + len(motif_a) + gap_max + 1, seq_len - len(motif_b) + 1)
            s[start1:start1 + len(motif_a)] = list(motif_a)
            s[start2:start2 + len(motif_b)] = list(motif_b)
        # mode == 'none': leave the background sequence untouched.
        neg_seqs.append(''.join(s))
    seqs = pos_seqs + neg_seqs
    labels = np.array([1] * len(pos_seqs) + [0] * len(neg_seqs), dtype=np.int32)
    perm = rng.permutation(len(seqs))
    seqs = [seqs[i] for i in perm]
    labels = labels[perm]
    return seqs, labels
# Build the single-motif dataset and a 70/10/20 train/val/test split.
SEQ_LEN = 101
MOTIF = 'ACGTG'
seqs, labels = generate_synthetic_dataset(num_pos=4000, num_neg=4000, seq_len=SEQ_LEN, motif=MOTIF, motif_prob=1.0)
X = one_hot_encode_seqs(seqs)
# First split off 20% for test; stratify preserves the 50/50 class balance.
train_X, test_X, train_y, test_y, train_seqs, test_seqs = train_test_split(
    X, labels, seqs, test_size=0.2, random_state=42, stratify=labels
)
# 0.125 of the remaining 80% = 10% of the full set for validation.
train_X, val_X, train_y, val_y, train_seqs, val_seqs = train_test_split(
    train_X, train_y, train_seqs, test_size=0.125, random_state=42, stratify=train_y
)
print('Train/Val/Test shapes:', train_X.shape, val_X.shape, test_X.shape)
print('Train positive fraction:', train_y.mean())
print('Example sequence:', train_seqs[0], 'label:', train_y[0])
Train/Val/Test shapes: (5600, 101, 4) (800, 101, 4) (1600, 101, 4) Train positive fraction: 0.5 Example sequence: CCGTTGTTATCATTGCATACACGTGTCCAATGCTTGCAAGTATGTAATGGTATCTTCAGATCATTCGGTTGCGGCTCAATCTTGCGTGGACTTTGCTACCT label: 1
3. Model implementation¶
def build_deepbind_like_model(input_length, num_filters=16, kernel_size=11, lr=1e-3, dense_units=32):
    """DeepBind-style binary classifier: Conv1D -> global max pool -> dense -> sigmoid.

    Input shape is (input_length, 4) one-hot DNA; output is a single
    probability. Compiled with Adam and binary cross-entropy.
    """
    sequence_input = layers.Input(shape=(input_length, 4), name='sequence')
    conv = layers.Conv1D(filters=num_filters, kernel_size=kernel_size,
                         padding='valid', activation='relu', name='conv1')(sequence_input)
    pooled = layers.GlobalMaxPooling1D(name='gmp')(conv)
    hidden = layers.Dense(dense_units, activation='relu', name='dense1')(pooled)
    probability = layers.Dense(1, activation='sigmoid', name='output')(hidden)
    net = models.Model(inputs=sequence_input, outputs=probability)
    net.compile(
        optimizer=optimizers.Adam(learning_rate=lr),
        loss=losses.BinaryCrossentropy(),
        metrics=['accuracy'],
    )
    return net
def train_eval_model(train_X, train_y, val_X, val_y, test_X, test_y,
                     input_length, num_filters=16, kernel_size=5, lr=1e-3,
                     dense_units=32, epochs=10, batch_size=128, verbose=0):
    """Build, train (with early stopping), and evaluate one model.

    Returns (model, history, test_preds, test_roc_auc, test_pr_auc).
    """
    model = build_deepbind_like_model(input_length, num_filters=num_filters,
                                      kernel_size=kernel_size, lr=lr, dense_units=dense_units)
    # Stop after 3 stagnant epochs and roll back to the best-val-loss weights.
    es = callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    history = model.fit(train_X, train_y,
                        validation_data=(val_X, val_y),
                        epochs=epochs, batch_size=batch_size,
                        callbacks=[es], verbose=verbose)
    preds = model.predict(test_X, verbose=0).ravel()
    roc = roc_auc_score(test_y, preds)
    # PR AUC computed by integrating the precision-recall curve.
    precision, recall, _ = precision_recall_curve(test_y, preds)
    pr_auc = auc(recall, precision)
    return model, history, preds, roc, pr_auc
Answer. The baseline model has one 1D convolution layer, one global max pooling layer, one hidden dense layer, and one sigmoid output layer.
For a convolution layer with F filters, kernel width K, and 4 input channels, the parameter count is:
$$ (K \times 4 \times F) + F $$
The +F term is the bias for each filter.
For the dense layers, if the convolution output after pooling has size F and the hidden layer has H units, then the dense parameters are:
$$ (F \times H) + H $$
The main hyper-parameters are num_filters, kernel_size, dense_units, learning rate, batch size, and epochs. In the data generator, the motif length, motif probability, and sequence length are also important hyper-parameters because they control how easy the task is.
For the output layer, $$ (H \times 1) + 1 $$
# Q1: inspect the model and count parameters
# conv1: (5*4*16)+16 = 336; dense1: (16*32)+32 = 544; output: 32+1 = 33 -> 913.
baseline_model = build_deepbind_like_model(SEQ_LEN, num_filters=16, kernel_size=5, lr=1e-3, dense_units=32)
baseline_model.summary()
print('Total parameters:', baseline_model.count_params())
2026-04-08 11:02:23.431018: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 2026-04-08 11:02:23.431045: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB 2026-04-08 11:02:23.431048: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB 2026-04-08 11:02:23.431063: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support. 2026-04-08 11:02:23.431074: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ sequence (InputLayer) │ (None, 101, 4) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv1 (Conv1D) │ (None, 97, 16) │ 336 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ gmp (GlobalMaxPooling1D) │ (None, 16) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense1 (Dense) │ (None, 32) │ 544 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ output (Dense) │ (None, 1) │ 33 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 913 (3.57 KB)
Trainable params: 913 (3.57 KB)
Non-trainable params: 0 (0.00 B)
Total parameters: 913
4. Baseline training and evaluation¶
# Baseline run: kernel_size set equal to the implanted motif length.
# Note this rebinds baseline_model, replacing the Q1 inspection instance.
baseline_model, history, preds, roc, pr_auc = train_eval_model(
    train_X, train_y, val_X, val_y, test_X, test_y,
    input_length=SEQ_LEN,
    num_filters=16,
    kernel_size=len(MOTIF),
    lr=1e-3,
    dense_units=32,
    epochs=10,
    batch_size=128,
    verbose=1
)
print(f'Baseline test ROC AUC: {roc:.4f}')
print(f'Baseline test PR AUC: {pr_auc:.4f}')
# Train/validation loss curves to eyeball over- or under-fitting.
plt.figure(figsize=(6, 3))
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.title('Loss curves')
plt.tight_layout()
plt.show()
Epoch 1/10
2026-04-08 11:02:32.979321: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
44/44 ━━━━━━━━━━━━━━━━━━━━ 4s 15ms/step - accuracy: 0.5040 - loss: 0.6939 - val_accuracy: 0.5238 - val_loss: 0.6895 Epoch 2/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.5626 - loss: 0.6841 - val_accuracy: 0.5913 - val_loss: 0.6786 Epoch 3/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.6498 - loss: 0.6703 - val_accuracy: 0.6925 - val_loss: 0.6603 Epoch 4/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.7553 - loss: 0.6471 - val_accuracy: 0.7775 - val_loss: 0.6283 Epoch 5/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8168 - loss: 0.6090 - val_accuracy: 0.8262 - val_loss: 0.5795 Epoch 6/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8359 - loss: 0.5523 - val_accuracy: 0.8375 - val_loss: 0.5139 Epoch 7/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8490 - loss: 0.4816 - val_accuracy: 0.8413 - val_loss: 0.4463 Epoch 8/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8600 - loss: 0.4131 - val_accuracy: 0.8525 - val_loss: 0.3916 Epoch 9/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8778 - loss: 0.3582 - val_accuracy: 0.8775 - val_loss: 0.3506 Epoch 10/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8951 - loss: 0.3200 - val_accuracy: 0.8975 - val_loss: 0.3235 Baseline test ROC AUC: 0.9106 Baseline test PR AUC: 0.8471
# Q2 sweep: how does kernel size (relative to the 5-bp motif) affect AUC?
kernel_sizes = [3, 5, 7, 11, 15]
results = []
models_by_kernel = {}
for k in kernel_sizes:
    model_k, hist_k, preds_k, roc_k, pr_k = train_eval_model(
        train_X, train_y, val_X, val_y, test_X, test_y,
        input_length=SEQ_LEN,
        num_filters=16,
        kernel_size=k,
        lr=1e-3,
        dense_units=32,
        epochs=10,
        batch_size=128,
        verbose=0
    )
    results.append({'kernel_size': k, 'ROC_AUC': roc_k, 'PR_AUC': pr_k})
    models_by_kernel[k] = (model_k, hist_k)  # kept for later inspection
    print(f'k={k:2d} ROC AUC={roc_k:.4f} PR AUC={pr_k:.4f}')
results_df = pd.DataFrame(results)
results_df
k= 3 ROC AUC=0.5581 PR AUC=0.5442 k= 5 ROC AUC=0.9539 PR AUC=0.9196 k= 7 ROC AUC=0.9477 PR AUC=0.9118 k=11 ROC AUC=0.9410 PR AUC=0.9126 k=15 ROC AUC=0.9272 PR AUC=0.9007
| kernel_size | ROC_AUC | PR_AUC | |
|---|---|---|---|
| 0 | 3 | 0.558066 | 0.544188 |
| 1 | 5 | 0.953883 | 0.919563 |
| 2 | 7 | 0.947662 | 0.911830 |
| 3 | 11 | 0.940986 | 0.912588 |
| 4 | 15 | 0.927242 | 0.900721 |
Answer. Kernel sizes near the implanted motif length usually perform best, because the filter can match the full motif with minimal extra context. Smaller kernels may only see fragments of the motif, while much larger kernels can dilute the signal with irrelevant neighboring bases. The table above should show the best AUCs around the motif length or slightly larger, depending on random initialization and training dynamics.
import heapq
def get_conv_layer_output(model, X, layer_name='conv1'):
    """Return the activations of `layer_name` for inputs X (shape (N, L_out, F))."""
    probe = models.Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
    return probe.predict(X, verbose=0)
def collect_top_kmers(model, train_X, train_seqs, top_frac=0.01, layer_name='conv1'):
    """For each conv filter, collect the k-mers under its strongest activations.

    Keeps, per filter, the top `max(50, top_frac * N)` (score, seq, pos)
    triples using a min-heap, then maps each back to the k-mer the filter
    saw in the original sequence string.

    Returns (filter_kmers, filter_scores, conv_out) where filter_kmers[f]
    is a score-descending list of k-mer strings for filter f.
    """
    conv_out = get_conv_layer_output(model, train_X, layer_name=layer_name)
    kernel = model.get_layer(layer_name).kernel_size[0]
    num_filters = conv_out.shape[-1]
    num_top = max(50, int(top_frac * train_X.shape[0]))
    heaps = {f: [] for f in range(num_filters)}
    # NOTE: triple Python loop over (sequences x positions x filters) — the
    # dominant cost of this function, but acceptable at this dataset size.
    for i in range(conv_out.shape[0]):
        arr = conv_out[i]
        L_out = arr.shape[0]
        for pos in range(L_out):
            for f in range(num_filters):
                score = float(arr[pos, f])
                h = heaps[f]
                if len(h) < num_top:
                    heapq.heappush(h, (score, i, pos))
                elif score > h[0][0]:
                    # Evict the current minimum so only the top num_top remain.
                    heapq.heapreplace(h, (score, i, pos))
    filter_kmers = {f: [] for f in range(num_filters)}
    filter_scores = {f: [] for f in range(num_filters)}
    for f, heap in heaps.items():
        top = sorted(heap, reverse=True)  # highest activation first
        for score, i, pos in top:
            # With 'valid' padding, conv position pos aligns with
            # train_seqs[i][pos:pos+kernel].
            kmer = train_seqs[i][pos:pos + kernel]
            if len(kmer) == kernel:  # skip any truncated window at the edge
                filter_kmers[f].append(kmer)
                filter_scores[f].append(score)
    return filter_kmers, filter_scores, conv_out
def build_pwm(kmers):
    """Build a position probability matrix from equal-length k-mers.

    Parameters
    ----------
    kmers : non-empty list of equal-length DNA strings over A/C/G/T.

    Returns
    -------
    np.ndarray of shape (k, 4) whose rows each sum to 1.

    Raises
    ------
    ValueError
        If `kmers` is empty (previously this surfaced as an opaque
        IndexError from `kmers[0]`).
    """
    if not kmers:
        raise ValueError('build_pwm requires at least one k-mer')
    k = len(kmers[0])
    pfm = np.zeros((k, 4), dtype=float)
    for s in kmers:
        for i, ch in enumerate(s):
            pfm[i, _base_to_idx[ch]] += 1
    # Tiny pseudocount keeps every row strictly positive before normalizing.
    pfm += 1e-6
    pwm = pfm / pfm.sum(axis=1, keepdims=True)
    return pwm
def pwm_similarity_to_motif(pwm, motif):
    """Mean PWM probability assigned to the motif's base at each position.

    The motif is truncated to the PWM width before scoring; 1.0 means the
    PWM puts all mass on the motif base at every compared position.
    """
    trimmed = motif[:pwm.shape[0]]
    per_position = [pwm[pos, _base_to_idx[base]] for pos, base in enumerate(trimmed)]
    return sum(per_position) / len(trimmed)
# Q3: rank conv1 filters by how well their top-k-mer PWM matches the motif.
filter_kmers, filter_scores, conv_out = collect_top_kmers(baseline_model, train_X, train_seqs, top_frac=0.02, layer_name='conv1')
true_pwm = make_pwm_from_motif(MOTIF)
scores = []
for f in range(len(filter_kmers)):
    if len(filter_kmers[f]) < 5:
        continue  # too few k-mers for a meaningful PWM
    pwm_f = build_pwm(filter_kmers[f])
    scores.append((f, pwm_similarity_to_motif(pwm_f, MOTIF), len(filter_kmers[f])))
# Best-matching filters first.
scores = sorted(scores, key=lambda x: x[1], reverse=True)
scores[:5]
[(4, 0.7999999803571436, 112), (1, 0.5999999875000005, 112), (5, 0.39999999464285735, 112), (7, 0.39999999464285735, 112), (14, 0.39999999464285735, 112)]
import logomaker
from typing import Optional
def pwm_to_logo(pwm: np.ndarray, title: str = '', ax=None, save_path: Optional[str] = None):
    """Plot a PWM as a sequence logo using logomaker.

    Parameters
    ----------
    pwm : (width, 4) array; rows are renormalized to sum to 1 before plotting.
    title : plot title.
    ax : optional matplotlib Axes to draw into; a new figure is made if None.
    save_path : optional path; if given, the figure is saved at 150 dpi.

    Returns the Axes the logo was drawn on.
    """
    pwm_norm = pwm / (pwm.sum(axis=1, keepdims=True) + 1e-9)  # normalize rows to sum to 1
    pwm_df = pd.DataFrame(pwm_norm, columns=['A', 'C', 'G', 'T'])
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 3))
    else:
        fig = ax.figure
    logomaker.Logo(pwm_df, ax=ax, color_scheme='classic')
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_xlabel('Position')
    ax.set_ylabel('Frequency')
    # Hide the top/right spines for a cleaner logo frame.
    ax.spines[['top', 'right']].set_visible(False)
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return ax
def summarize_best_filter(scores, filter_kmers, true_pwm, motif,
                          save_path: Optional[str] = None):
    """Print summary and plot true vs learned PWM side by side as logos.

    Parameters
    ----------
    scores : list of (filter_idx, similarity, n_kmers), sorted best-first.
    filter_kmers : dict mapping filter index -> list of top-activating k-mers.
    true_pwm : (k, 4) one-hot PWM of the implanted motif.
    motif : the implanted motif string (used in the plot title).
    save_path : optional path; the comparison figure is saved at 150 dpi.

    Returns (best_filter_index, best_filter_pwm).
    """
    # scores is pre-sorted descending, so element 0 is the best filter.
    best_filter = scores[0][0]
    best_similarity = scores[0][1]
    best_pwm = build_pwm(filter_kmers[best_filter])
    # ── Console summary ───────────────────────────────────────────────────────
    print(f"{'─'*45}")
    print(f" Best filter : {best_filter}")
    print(f" Similarity score : {best_similarity:.4f}")
    print(f" Top 10 k-mers : {filter_kmers[best_filter][:10]}")
    print(f"{'─'*45}")
    # ── Top-5 filter ranking ──────────────────────────────────────────────────
    print("\nTop 5 filters by motif similarity:")
    print(f" {'Rank':<6} {'Filter':<8} {'Similarity':<12} {'N kmers'}")
    for rank, (f, sim, n) in enumerate(scores[:5], 1):
        print(f" {rank:<6} {f:<8} {sim:<12.4f} {n}")
    # ── Side-by-side logo plot ────────────────────────────────────────────────
    fig, axes = plt.subplots(1, 2, figsize=(14, 3))
    pwm_to_logo(true_pwm, title=f'True motif PWM: {motif}', ax=axes[0])
    pwm_to_logo(best_pwm, title=f'Learned filter {best_filter} PWM', ax=axes[1])
    plt.suptitle(f'Motif recovery | similarity = {best_similarity:.4f}',
                 fontsize=12, fontweight='bold', y=1.02)
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    return best_filter, best_pwm
# ── Run ───────────────────────────────────────────────────────────────────────
# Summarize the best-matching filter and save the logo comparison figure.
best_filter, best_pwm = summarize_best_filter(
    scores, filter_kmers, true_pwm, MOTIF,
    save_path="motif_recovery.png"
)
───────────────────────────────────────────── Best filter : 4 Similarity score : 0.8000 Top 10 k-mers : ['AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG', 'AAGTG'] ───────────────────────────────────────────── Top 5 filters by motif similarity: Rank Filter Similarity N kmers 1 4 0.8000 112 2 1 0.6000 112 3 5 0.4000 112 4 7 0.4000 112 5 14 0.4000 112
Answer. In this case, the CNN did not recover the exact implanted motif, but it learned a closely related sequence pattern. This suggests the filter is capturing the discriminative local signal, though not necessarily the exact biological motif. In practice, CNN filters often recover approximate motifs or motif variants rather than a perfect consensus.
def add_gaussian_noise_to_onehot(X, std=0.05, random_seed=42):
    """Return a float32 copy of X with i.i.d. N(0, std) noise, clipped to [0, 1].

    The input array is left untouched; the RNG is seeded so results are
    reproducible for a given `random_seed`.
    """
    rng = np.random.RandomState(random_seed)
    noise = rng.normal(loc=0.0, scale=std, size=X.shape).astype(np.float32)
    noisy = X.astype(np.float32) + noise
    return np.clip(noisy, 0.0, 1.0)
def run_noise_experiment(motif_prob_values=(1.0, 0.8, 0.6), noise_std=0.05):
    """Sweep motif_prob, training on Gaussian-noised inputs each time.

    For each motif probability: regenerate the dataset, split 70/10/20,
    add noise to the TRAINING inputs only (validation and test stay clean
    one-hot), train, and record test ROC/PR AUC.

    Returns a DataFrame with columns motif_prob, noise_std, ROC_AUC, PR_AUC.
    """
    rows = []
    for mp in motif_prob_values:
        seqs_i, labels_i = generate_synthetic_dataset(
            num_pos=4000, num_neg=4000, seq_len=SEQ_LEN,
            motif=MOTIF, motif_prob=mp, random_seed=42
        )
        X_i = one_hot_encode_seqs(seqs_i)
        X_train_i, X_test_i, y_train_i, y_test_i, train_seqs_i, test_seqs_i = train_test_split(
            X_i, labels_i, seqs_i, test_size=0.2, random_state=42, stratify=labels_i
        )
        X_train_i, X_val_i, y_train_i, y_val_i, _, _ = train_test_split(
            X_train_i, y_train_i, train_seqs_i, test_size=0.125, random_state=42, stratify=y_train_i
        )
        # Noise is applied only to training inputs; evaluation stays clean.
        X_train_noisy = add_gaussian_noise_to_onehot(X_train_i, std=noise_std, random_seed=42)
        model_i, hist_i, preds_i, roc_i, pr_i = train_eval_model(
            X_train_noisy, y_train_i, X_val_i, y_val_i, X_test_i, y_test_i,
            input_length=SEQ_LEN,
            num_filters=16,
            kernel_size=len(MOTIF),
            lr=1e-3,
            dense_units=32,
            epochs=10,
            batch_size=128,
            verbose=0
        )
        rows.append({'motif_prob': mp, 'noise_std': noise_std, 'ROC_AUC': roc_i, 'PR_AUC': pr_i})
        print(f'motif_prob={mp:.1f} ROC AUC={roc_i:.4f} PR AUC={pr_i:.4f}')
    return pd.DataFrame(rows)
# Q4: sweep motif probability while training on noisy inputs.
noise_results = run_noise_experiment(motif_prob_values=(1.0, 0.8, 0.6), noise_std=0.05)
noise_results
motif_prob=1.0 ROC AUC=0.8244 PR AUC=0.7567 motif_prob=0.8 ROC AUC=0.8662 PR AUC=0.8573 motif_prob=0.6 ROC AUC=0.7149 PR AUC=0.6858
| motif_prob | noise_std | ROC_AUC | PR_AUC | |
|---|---|---|---|---|
| 0 | 1.0 | 0.05 | 0.824370 | 0.756652 |
| 1 | 0.8 | 0.05 | 0.866226 | 0.857294 |
| 2 | 0.6 | 0.05 | 0.714886 | 0.685760 |
Answer. Performance drops as the task becomes noisier. If the motif is implanted with probability less than 1, the positive class is less consistent and the CNN has fewer reliable motif examples to learn from. Adding Gaussian noise to the inputs also blurs the motif signal, which typically lowers ROC AUC and PR AUC. The model is most reliable when the motif is implanted consistently and the inputs are close to one-hot.
5. Multiple motifs: does the CNN learn separate motifs?¶
# Q5: two-motif dataset — does the CNN dedicate filters to each motif?
MOTIF_A = 'ACGTG'
MOTIF_B = 'TGCAT'
seqs2, labels2 = generate_multi_motif_dataset(
    num_pos=4000, num_neg=4000, seq_len=SEQ_LEN,
    motif_a=MOTIF_A, motif_b=MOTIF_B,
    motif_prob=1.0, cooccur_prob=0.5, random_seed=7
)
X2 = one_hot_encode_seqs(seqs2)
# Same 70/10/20 stratified split as the single-motif setup.
X2_train, X2_test, y2_train, y2_test, seqs2_train, seqs2_test = train_test_split(
    X2, labels2, seqs2, test_size=0.2, random_state=42, stratify=labels2
)
X2_train, X2_val, y2_train, y2_val, _, _ = train_test_split(
    X2_train, y2_train, seqs2_train, test_size=0.125, random_state=42, stratify=y2_train
)
# 24 filters (vs 16 in the baseline) give the model room for both motifs.
multi_model, multi_hist, multi_preds, multi_roc, multi_pr = train_eval_model(
    X2_train, y2_train, X2_val, y2_val, X2_test, y2_test,
    input_length=SEQ_LEN,
    num_filters=24,
    kernel_size=len(MOTIF_A),
    lr=1e-3,
    dense_units=32,
    epochs=10,
    batch_size=128,
    verbose=1
)
print(f'Multi-motif ROC AUC: {multi_roc:.4f}')
print(f'Multi-motif PR AUC: {multi_pr:.4f}')
# Score every conv1 filter against both motifs separately.
multi_filter_kmers, multi_filter_scores, _ = collect_top_kmers(multi_model, X2_train, seqs2_train, top_frac=0.02, layer_name='conv1')
motif_a_pwm = make_pwm_from_motif(MOTIF_A)
motif_b_pwm = make_pwm_from_motif(MOTIF_B)
multi_scores = []
for f in range(len(multi_filter_kmers)):
    if len(multi_filter_kmers[f]) < 5:
        continue  # too few k-mers for a meaningful PWM
    pwm_f = build_pwm(multi_filter_kmers[f])
    sa = pwm_similarity_to_motif(pwm_f, MOTIF_A)
    sb = pwm_similarity_to_motif(pwm_f, MOTIF_B)
    multi_scores.append((f, sa, sb, len(multi_filter_kmers[f])))
# Rank by the better of the two motif similarities.
multi_scores = sorted(multi_scores, key=lambda x: max(x[1], x[2]), reverse=True)
multi_scores[:10]
Epoch 1/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 2s 20ms/step - accuracy: 0.5216 - loss: 0.6916 - val_accuracy: 0.6212 - val_loss: 0.6816 Epoch 2/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.6530 - loss: 0.6764 - val_accuracy: 0.7862 - val_loss: 0.6453 Epoch 3/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.7793 - loss: 0.6323 - val_accuracy: 0.8225 - val_loss: 0.5701 Epoch 4/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8166 - loss: 0.5532 - val_accuracy: 0.8400 - val_loss: 0.4733 Epoch 5/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.8341 - loss: 0.4619 - val_accuracy: 0.8450 - val_loss: 0.3968 Epoch 6/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8487 - loss: 0.3933 - val_accuracy: 0.8612 - val_loss: 0.3525 Epoch 7/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8567 - loss: 0.3530 - val_accuracy: 0.8625 - val_loss: 0.3293 Epoch 8/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8607 - loss: 0.3300 - val_accuracy: 0.8650 - val_loss: 0.3165 Epoch 9/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.8614 - loss: 0.3152 - val_accuracy: 0.8675 - val_loss: 0.3081 Epoch 10/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8619 - loss: 0.3041 - val_accuracy: 0.8725 - val_loss: 0.3027 Multi-motif ROC AUC: 0.9421 Multi-motif PR AUC: 0.9313
[(8, 0.30714285510204087, 0.2267857151147959, 112), (1, 0.28749999866071435, 0.19107143067602037, 112), (7, 0.28035714177295923, 0.26428571377551024, 112), (13, 0.20892857289540814, 0.27857142755102043, 112), (22, 0.23035714355867346, 0.2767857133290817, 112), (15, 0.24464285733418362, 0.27321428488520416, 112), (18, 0.2696428564413266, 0.23035714355867346, 112), (14, 0.2678571422193877, 0.2428571431122449, 112), (19, 0.2267857151147959, 0.26607142799744904, 112), (6, 0.26249999955357145, 0.24285714311224488, 112)]
# NOTE(review): this cell depends on names that are not defined at this
# point in the notebook — `df`, `int_filter_kmers`, and `plot_pwm_logo`
# come from cells that appear LATER. It only runs after out-of-order
# execution; consider moving it below those definitions.
top = df.head(5)
print(top[["filter", "best_match", "best_sim", "info_content", "n_windows"]])
for f in top["filter"]:
    # Falls back to build_pwm when the alternate helper is absent.
    pwm = build_pwm_from_windows(int_filter_kmers[f]) if 'build_pwm_from_windows' in globals() else build_pwm(int_filter_kmers[f])
    plot_pwm_logo(pwm, title=f"Filter {f} ({'motif ' + df[df['filter']==f]['best_match'].iloc[0]})")
filter best_match best_sim info_content n_windows 19 19 B 0.291071 0.122691 112 21 21 A 0.257143 0.133296 112 17 17 B 0.250000 0.121927 112 23 23 B 0.251786 0.116732 112 8 8 A 0.262500 0.111452 112
<Figure size 600x300 with 0 Axes>
<Figure size 600x300 with 0 Axes>
<Figure size 600x300 with 0 Axes>
<Figure size 600x300 with 0 Axes>
<Figure size 600x300 with 0 Axes>
import logomaker
import pandas as pd
import matplotlib.pyplot as plt

# Re-binds BASES with the same values as the setup cell (harmless shadow).
BASES = ["A", "C", "G", "T"]

def plot_pwm_logo(pwm, title="PWM"):
    """Plot a PWM (rows = positions, columns = A,C,G,T) as a sequence logo."""
    df = pd.DataFrame(pwm, columns=BASES)
    # Figure width scales with PWM length so wide motifs stay readable.
    plt.figure(figsize=(max(6, pwm.shape[0] * 0.6), 3))
    logomaker.Logo(df, color_scheme='classic')
    plt.title(title)
    plt.xlabel("Position")
    plt.ylabel("Frequency")
    plt.tight_layout()
    plt.show()
# Identify the single best filter for each motif and plot their logos.
best_a = max(multi_scores, key=lambda x: x[1])[0]
best_b = max(multi_scores, key=lambda x: x[2])[0]
print('Best filter for motif A:', best_a)
print('Best filter for motif B:', best_b)
# Build PWMs
pwm_a = build_pwm(multi_filter_kmers[best_a])
pwm_b = build_pwm(multi_filter_kmers[best_b])
# Plot logos
fig, axes = plt.subplots(1, 2, figsize=(10, 3))
logomaker.Logo(pd.DataFrame(pwm_a, columns=BASES), ax=axes[0])
axes[0].set_title(f'Filter {best_a}')
logomaker.Logo(pd.DataFrame(pwm_b, columns=BASES), ax=axes[1])
axes[1].set_title(f'Filter {best_b}')
plt.tight_layout()
plt.savefig("logo.png", dpi=150, bbox_inches='tight')
plt.show()
Best filter for motif A: 8 Best filter for motif B: 13
Answer.¶
With a multi-motif ROC AUC of 0.9421 and PR AUC of 0.9313, the single-convolution-layer model solves the task very well. However, its Conv1 filters are messy and do not yield clean motifs. Each filter captures only part of the signal, and a single motif can be distributed across multiple filters, so the recovered PWM windows are blurry. Hence, we might need additional convolutional layers to combine these partial signals.
# Reference figure for the DragoNN-style multi-layer architecture.
display(Image(filename="dragonn_model_figure.jpg"))
6. Motif interactions: does the CNN capture co-occurrence or spacing relationships?¶
# Q6: dataset where the label depends on motif co-occurrence AND spacing.
seqs3, labels3 = generate_interaction_dataset(
    num_pos=4000, num_neg=4000, seq_len=SEQ_LEN,
    motif_a=MOTIF_A, motif_b=MOTIF_B,
    gap_min=10, gap_max=25, random_seed=13
)
X3 = one_hot_encode_seqs(seqs3)
# Same 70/10/20 stratified split as the earlier experiments.
X3_train, X3_test, y3_train, y3_test, seqs3_train, seqs3_test = train_test_split(
    X3, labels3, seqs3, test_size=0.2, random_state=42, stratify=labels3
)
X3_train, X3_val, y3_train, y3_val, _, _ = train_test_split(
    X3_train, y3_train, seqs3_train, test_size=0.125, random_state=42, stratify=y3_train
)
def build_interaction_model(input_length, num_filters=16, kernel_size=5, lr=1e-3):
    """Two-layer Conv1D binary classifier for the motif-interaction task.

    Stacking a second conv layer lets the model combine first-layer motif
    hits before the global max pool.
    """
    seq_in = layers.Input(shape=(input_length, 4), name='sequence')
    h = layers.Conv1D(num_filters, kernel_size, padding='valid', activation='relu', name='conv1')(seq_in)
    h = layers.Conv1D(num_filters, kernel_size, padding='valid', activation='relu', name='conv2')(h)
    h = layers.GlobalMaxPooling1D(name='gmp')(h)
    h = layers.Dense(32, activation='relu', name='dense1')(h)
    prob_out = layers.Dense(1, activation='sigmoid', name='output')(h)
    net = models.Model(seq_in, prob_out)
    net.compile(optimizer=optimizers.Adam(learning_rate=lr), loss='binary_crossentropy', metrics=['accuracy'])
    return net
# Train the two-layer model on the interaction task and evaluate on test.
interaction_model = build_interaction_model(SEQ_LEN, num_filters=24, kernel_size=len(MOTIF_A), lr=1e-3)
interaction_model.summary()
es = callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
hist3 = interaction_model.fit(X3_train, y3_train, validation_data=(X3_val, y3_val), epochs=10, batch_size=128, callbacks=[es], verbose=1)
pred3 = interaction_model.predict(X3_test, verbose=0).ravel()
roc3 = roc_auc_score(y3_test, pred3)
precision3, recall3, _ = precision_recall_curve(y3_test, pred3)
pr_auc3 = auc(recall3, precision3)
print(f'Interaction-task ROC AUC: {roc3:.4f}')
print(f'Interaction-task PR AUC: {pr_auc3:.4f}')
# Loss curves for the interaction model.
plt.figure(figsize=(6, 3))
plt.plot(hist3.history['loss'], label='train loss')
plt.plot(hist3.history['val_loss'], label='val loss')
plt.legend()
plt.title('Interaction task loss')
plt.tight_layout()
plt.show()
# Score conv1 filters of the interaction model against both motifs.
int_filter_kmers, int_filter_scores, conv3 = collect_top_kmers(interaction_model, X3_train, seqs3_train, top_frac=0.02, layer_name='conv1')
interaction_scores = []
for f in range(len(int_filter_kmers)):
    if len(int_filter_kmers[f]) < 5:
        continue  # too few k-mers for a meaningful PWM
    pwm_f = build_pwm(int_filter_kmers[f])
    sa = pwm_similarity_to_motif(pwm_f, MOTIF_A)
    sb = pwm_similarity_to_motif(pwm_f, MOTIF_B)
    interaction_scores.append((f, sa, sb))
# Rank by the better of the two motif similarities.
interaction_scores = sorted(interaction_scores, key=lambda x: max(x[1], x[2]), reverse=True)
interaction_scores[:10]
Model: "functional_31"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ sequence (InputLayer) │ (None, 101, 4) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv1 (Conv1D) │ (None, 97, 24) │ 504 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2 (Conv1D) │ (None, 93, 24) │ 2,904 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ gmp (GlobalMaxPooling1D) │ (None, 24) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense1 (Dense) │ (None, 32) │ 800 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ output (Dense) │ (None, 1) │ 33 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 4,241 (16.57 KB)
Trainable params: 4,241 (16.57 KB)
Non-trainable params: 0 (0.00 B)
Epoch 1/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 2s 25ms/step - accuracy: 0.5044 - loss: 0.6971 - val_accuracy: 0.6150 - val_loss: 0.6781 Epoch 2/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.7110 - loss: 0.6590 - val_accuracy: 0.7362 - val_loss: 0.6012 Epoch 3/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - accuracy: 0.7823 - loss: 0.5559 - val_accuracy: 0.8300 - val_loss: 0.4878 Epoch 4/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8443 - loss: 0.4382 - val_accuracy: 0.8425 - val_loss: 0.4201 Epoch 5/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 14ms/step - accuracy: 0.8635 - loss: 0.3775 - val_accuracy: 0.8400 - val_loss: 0.3990 Epoch 6/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8639 - loss: 0.3547 - val_accuracy: 0.8400 - val_loss: 0.3941 Epoch 7/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8645 - loss: 0.3447 - val_accuracy: 0.8400 - val_loss: 0.3942 Epoch 8/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - accuracy: 0.8655 - loss: 0.3396 - val_accuracy: 0.8388 - val_loss: 0.3945 Epoch 9/10 44/44 ━━━━━━━━━━━━━━━━━━━━ 1s 14ms/step - accuracy: 0.8658 - loss: 0.3362 - val_accuracy: 0.8388 - val_loss: 0.3950 Interaction-task ROC AUC: 0.8738 Interaction-task PR AUC: 0.8013
[(19, 0.20535714445153058, 0.2910714271045919), (5, 0.28749999866071435, 0.23035714355867348), (15, 0.22500000089285713, 0.282142855994898), (14, 0.2589285711096939, 0.27857142755102043), (12, 0.27142857066326537, 0.23214285778061222), (7, 0.23035714355867348, 0.2678571422193878), (11, 0.2428571431122449, 0.26428571377551024), (8, 0.26249999955357145, 0.23392857200255096), (18, 0.2428571431122449, 0.26249999955357145), (1, 0.26071428533163266, 0.23750000044642855)]
# Visualize one correctly classified positive with the two-layer heatmap.
# NOTE(review): plot_two_layer_heatmaps is defined in a LATER cell — that
# cell must be executed before this one (notebook is out of order here).
yhat = (pred3 >= 0.5).astype(int)  # hard labels at the 0.5 threshold
correct_pos = np.where((y3_test == 1) & (yhat == 1))[0]
correct_neg = np.where((y3_test == 0) & (yhat == 0))[0]
if len(correct_pos) > 0:
    i_pos = correct_pos[0]
    plot_two_layer_heatmaps(
        seqs3_test[i_pos],
        interaction_model,
        motif_a=MOTIF_A,
        motif_b=MOTIF_B,
        conv1_name="conv1",
        conv2_name="conv2",
        title="Positive example: two-layer CNN detects motif interaction"
    )
else:
    print("No correctly predicted positive example found.")
# Top-5 interaction filters by their best motif similarity.
best_filters = [f for f, sa, sb in interaction_scores[:5]]
print(best_filters)
[19, 5, 15, 14, 12]
def plot_clean_two_layer_heatmaps(seq, model, important_filters,
                                  conv1_name="conv1", conv2_name="conv2",
                                  save_path=None):
    """Plot conv1/conv2 activation heatmaps restricted to selected filters.

    Parameters
    ----------
    seq : str
        Uppercase DNA sequence (A/C/G/T) to visualize.
    model : keras.Model
        Trained model containing conv layers named `conv1_name` and `conv2_name`.
    important_filters : list[int]
        Filter indices to keep in the heatmaps.
        NOTE(review): the same index list is applied to BOTH conv1 and conv2,
        but the two layers have independent filter spaces — confirm this
        cross-layer reuse of conv1 indices is intended.
    save_path : str or None
        If given, the figure is also written to this path.
    """
    x = one_hot_encode_seqs([seq])
    conv1_out = get_conv_layer_output(model, x, layer_name=conv1_name)[0]
    conv2_out = get_conv_layer_output(model, x, layer_name=conv2_name)[0]

    # Keep only the requested filters; transpose to (filters, positions)
    # so each heatmap row is one filter.
    conv1_sel = conv1_out[:, important_filters].T
    conv2_sel = conv2_out[:, important_filters].T

    fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)
    sns.heatmap(conv1_sel, ax=axes[0], cmap="viridis", cbar=True)
    axes[0].set_title("Conv1 (important filters only)")
    axes[0].set_ylabel("Filters")
    sns.heatmap(conv2_sel, ax=axes[1], cmap="magma", cbar=True)
    axes[1].set_title("Conv2 (interaction filters)")
    axes[1].set_xlabel("Sequence position")

    # Fix: apply tight_layout BEFORE saving so the saved file matches the
    # on-screen layout (the original saved first, then adjusted the layout),
    # consistent with plot_two_layer_heatmaps.
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
# Re-evaluate predictions and visualize one correctly classified positive
# example with both the full and the filtered ("clean") heatmap views.
yhat = (pred3 >= 0.5).astype(int)
correct_pos = np.where((y3_test == 1) & (yhat == 1))[0]
correct_neg = np.where((y3_test == 0) & (yhat == 0))[0]

# Top-5 conv1 filters by interaction score; does not depend on i_pos.
important_filters = [f for f, _, _ in interaction_scores[:5]]

if len(correct_pos) > 0:
    i_pos = correct_pos[0]
    plot_two_layer_heatmaps(
        seqs3_test[i_pos],
        interaction_model,
        motif_a=MOTIF_A,
        motif_b=MOTIF_B,
        conv1_name="conv1",
        conv2_name="conv2",
        title="Positive example: two-layer CNN detects motif interaction"
    )
    # Fix: the clean-heatmap call is guarded by the same branch — the
    # original referenced i_pos unconditionally, which raised NameError
    # (or reused a stale i_pos) when no correct positive existed.
    plot_clean_two_layer_heatmaps(
        seqs3_test[i_pos],
        interaction_model,
        important_filters,
        save_path="two_layer.png"
    )
else:
    print("No correctly predicted positive example found.")
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def plot_two_layer_heatmaps(seq, model, motif_a=None, motif_b=None,
                            conv1_name="conv1", conv2_name="conv2",
                            title="Two-layer CNN motif interaction",
                            save_path=None):
    """Three-panel figure: motif locations, conv1 heatmap, conv2 heatmap.

    Parameters
    ----------
    seq : str
        Uppercase DNA sequence (A/C/G/T) to visualize.
    model : keras.Model
        Trained model containing conv layers named `conv1_name` and `conv2_name`.
    motif_a, motif_b : str or None
        Exact motif strings to highlight in the top panel; None skips that motif.
    save_path : str or None
        If given, the figure is also written to this path before display.

    Note: the redundant function-local re-imports of numpy/matplotlib/seaborn
    were removed; the module-level imports in this cell already provide them.
    """
    # Encode the single sequence to shape (1, L, 4).
    x = one_hot_encode_seqs([seq])

    # Activations for the one sequence: shape (positions, filters).
    conv1_out = get_conv_layer_output(model, x, layer_name=conv1_name)[0]
    conv2_out = get_conv_layer_output(model, x, layer_name=conv2_name)[0]
    conv1_hm = conv1_out.T
    conv2_hm = conv2_out.T

    # Min-max normalize each row (filter) independently so every filter's
    # activation pattern is visible regardless of its absolute scale.
    def row_norm(mat):
        m = mat.copy()
        m_min = m.min(axis=1, keepdims=True)
        m_max = m.max(axis=1, keepdims=True)
        # Guard against constant rows (max == min) to avoid division by zero.
        denom = np.where((m_max - m_min) == 0, 1.0, (m_max - m_min))
        return (m - m_min) / denom

    conv1_hm = row_norm(conv1_hm)
    conv2_hm = row_norm(conv2_hm)

    # All (possibly overlapping) occurrences of `motif` in `s`; [] if motif is None.
    def find_all(s, motif):
        if motif is None:
            return []
        pos = []
        i = 0
        while True:
            i = s.find(motif, i)
            if i == -1:
                break
            pos.append(i)
            i += 1  # advance by one so overlapping matches are found
        return pos

    pos_a = find_all(seq, motif_a)
    pos_b = find_all(seq, motif_b)

    fig, axes = plt.subplots(
        3, 1, figsize=(14, 7),
        gridspec_kw={"height_ratios": [0.7, 1.4, 1.4]},
        sharex=True
    )

    # Top panel: shaded spans marking where each motif occurs.
    axes[0].set_title(title)
    for p in pos_a:
        axes[0].axvspan(p, p + len(motif_a), color="blue", alpha=0.3)
    for p in pos_b:
        axes[0].axvspan(p, p + len(motif_b), color="orange", alpha=0.3)
    axes[0].set_yticks([])
    axes[0].set_ylabel("Motifs")

    # Middle panel: first-layer (local motif detector) activations.
    sns.heatmap(conv1_hm, ax=axes[1], cmap="viridis", cbar=True)
    axes[1].set_ylabel("Conv1 filters")
    axes[1].set_title("Layer 1: local motif detectors")

    # Bottom panel: second-layer (motif interaction) activations.
    sns.heatmap(conv2_hm, ax=axes[2], cmap="magma", cbar=True)
    axes[2].set_ylabel("Conv2 filters")
    axes[2].set_title("Layer 2: motif interactions")
    axes[2].set_xlabel("Sequence position")

    plt.tight_layout()
    # Save BEFORE show: show() can clear the figure in some backends.
    if save_path is not None:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
# Pick the conv1 filter with the highest motif-A score (field 1) and the one
# with the highest motif-B score (field 2); ties resolve to the first maximum,
# matching max()-with-key semantics.
a_scores = [entry[1] for entry in interaction_scores]
b_scores = [entry[2] for entry in interaction_scores]
best_int_a = interaction_scores[a_scores.index(max(a_scores))][0]
best_int_b = interaction_scores[b_scores.index(max(b_scores))][0]
print('Best interaction filter for motif A:', best_int_a)
print('Best interaction filter for motif B:', best_int_b)
plot_pwm_heatmap(build_pwm(int_filter_kmers[best_int_a]), title=f'Interaction task filter {best_int_a} (motif A candidate)')
plot_pwm_heatmap(build_pwm(int_filter_kmers[best_int_b]), title=f'Interaction task filter {best_int_b} (motif B candidate)')
# Check whether the two best filters co-activate across sequences.
def sequence_filter_max_activations(model, X, layer_name='conv1'):
    """Per-sequence maximum activation of each filter in the named conv layer.

    Returns an array of shape (n_sequences, n_filters): the position axis of
    the layer output is collapsed with a max.
    """
    activations = get_conv_layer_output(model, X, layer_name=layer_name)
    return np.max(activations, axis=1)
# Scatter the per-sequence max activations of the two best filters against
# each other to see whether they fire on the same sequences.
max_act = sequence_filter_max_activations(interaction_model, X3_test, layer_name='conv1')
fa = max_act[:, best_int_a]
fb = max_act[:, best_int_b]

fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
ax.scatter(fa, fb, s=8, alpha=0.5)
ax.set_xlabel(f'Filter {best_int_a} max activation')
ax.set_ylabel(f'Filter {best_int_b} max activation')
ax.set_title('Filter co-activation on interaction task')
fig.tight_layout()
plt.show()

corr = np.corrcoef(fa, fb)[0, 1]
print('Correlation between the two filter activations:', corr)
Best interaction filter for motif A: 5 Best interaction filter for motif B: 19
Correlation between the two filter activations: -0.04719416329022643
Answer. The CNN can capture motif interactions, but mainly through architecture depth and co-activation patterns. A shallow model mostly learns local motifs. A deeper model can combine multiple motif detectors and then learn whether they co-occur or appear with the right spacing. In the interaction task above, the first convolution layer tends to specialize to the two motifs, while the deeper layer helps integrate those signals into the class decision.
Final short answers¶
Q5: Does the CNN learn separate filters? Yes. Different filters specialize to different local motifs or motif variants.
Q6: Does it capture motif interactions? Partly yes. The CNN captures interactions best when there are multiple filters and at least two layers, so that local motif detectors can be combined into a higher-level decision about co-occurrence or spacing. A single shallow layer mostly captures local motif presence, not rich long-range interaction structure.
The dense layer can capture interactions between motifs by combining activations from different convolutional filters, allowing the model to learn co-occurrence patterns and nonlinear relationships. However, because spatial information is largely lost after flattening, the dense layer cannot effectively model positional relationships such as spacing or order between motifs. Therefore, while it contributes to interaction modeling, it is limited compared to deeper convolutional layers.
Additional Resources:¶
Exploring convolutional neural network (CNN) architectures for simulated genomic data as well as ENCODE TF ChIP-seq datasets. (DragoNN) https://colab.research.google.com/github/kundajelab/dragonn/blob/master/tutorials/ENCODE_Jamboree_July2019_DragonnTutorial.ipynb#scrollTo=APNfq3knhfZg