02 Dataset and NN Architecture API

Documentation for our dataset and neural network architecture API.

1 Libraries

# Third-party numerical, plotting and data-handling libraries.
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('dark_background')  # every figure in this notebook uses a dark theme
import seaborn as sns
import pandas as pd
import re
import torch
from torch.utils.data import Dataset, DataLoader
# Project modules: the ChIP-nexus dataset wrapper and the BPNet architecture.
# NOTE(review): pd, re, torch, Dataset and DataLoader are not used in the
# visible cells — keep only if later cells need them.
from src.utils import ChIP_Nexus_Dataset
from src.architectures import BPNet

2 Dataset API

import os

# Directory passed as `input_dir` to every ChIP_Nexus_Dataset in this notebook.
# Defaults to the original author's path; set the BPNET_INPUT_DIR environment
# variable to run the notebook elsewhere without editing this cell.
INPUT_DIR = os.environ.get("BPNET_INPUT_DIR", "/home/philipp/BPNet/input/")

2.1 Example 1: Create Train Dataset for all TFs

You must provide the set name, which must be one of “train”, “tune” or “test”, as well as the input directory and the list of TFs you want to model.

# Training set for all four TFs. The bare expression on the last line makes
# the notebook display the dataset summary (set, TFs, size).
whole_dataset = ChIP_Nexus_Dataset(set_name="train", 
                                   input_dir=INPUT_DIR, 
                                   TF_list=['Sox2', 'Oct4', 'Klf4', 'Nanog'])
whole_dataset
ChIP_Nexus_Dataset
Set: train
TFs: ['Sox2', 'Oct4', 'Klf4', 'Nanog']
Size: 93904

Check the shapes via the check_shapes() method.

# Print the shapes of the stored arrays (one-hot sequences, TF counts, control counts).
whole_dataset.check_shapes()
self.tf_list=['Sox2', 'Oct4', 'Klf4', 'Nanog']
self.one_hot_seqs.shape=(93904, 4, 1000) [idx, bases, pwidth]
self.tf_counts.shape=(93904, 4, 2, 1000) [idx, TF, strand, pwidth]
self.ctrl_counts.shape=(93904, 2, 1000) [idx, strand, pwidth]
self.ctrl_counts_smooth.shape=(93904, 2, 1000) [idx, strand, pwidth]

2.2 Example 2: Create Train Dataset for Sox2

If we only want to consider one or a few TFs, we can specify them with the TF_list parameter. The constructor takes care of the rest and keeps only the peaks that are specific to the TFs in TF_list.

# Training set restricted to Sox2: only Sox2-specific peaks are kept, so the
# set is much smaller than the all-TF training set.
small_dataset = ChIP_Nexus_Dataset(set_name="train", 
                                   input_dir=INPUT_DIR, 
                                   TF_list=['Sox2'])
small_dataset
ChIP_Nexus_Dataset
Set: train
TFs: ['Sox2']
Size: 6748
# Same shape report; the TF axis of tf_counts now has length 1.
small_dataset.check_shapes()
self.tf_list=['Sox2']
self.one_hot_seqs.shape=(6748, 4, 1000) [idx, bases, pwidth]
self.tf_counts.shape=(6748, 1, 2, 1000) [idx, TF, strand, pwidth]
self.ctrl_counts.shape=(6748, 2, 1000) [idx, strand, pwidth]
self.ctrl_counts_smooth.shape=(6748, 2, 1000) [idx, strand, pwidth]

2.3 Example 3: Create Train Dataset for Sox2 and High-Confidence Peaks

We might also want to filter peaks based on the qValue.

# Histogram of log2-transformed peak qValues for the Sox2 training set, with
# the chosen log2 cutoff drawn as a vertical red line.
cutoff = 4.5
log2_qvals = np.log2(small_dataset.region_info.qValue)
ax = sns.histplot(log2_qvals)
ax.axvline(cutoff, color="red")
ax.set_title("Distribution of qValues")
ax.set_xlabel("Log2 qValue")
plt.show()

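Looking at the histogram of the log2 qValue, we might decide to keep only peaks with a log2 qValue above 4.5. Because the cutoff is on the log2 scale while qval_thr is compared with the raw qValue, we pass 2**cutoff.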

# Sox2 training set keeping only high-confidence peaks. `cutoff` is a log2
# value (see the histogram), so it is converted back with 2**cutoff —
# qval_thr appears to be on the raw qValue scale.
highconf_dataset = ChIP_Nexus_Dataset(set_name="train", 
                                      input_dir=INPUT_DIR, 
                                      TF_list=["Sox2"],
                                      qval_thr=2**cutoff)
highconf_dataset
ChIP_Nexus_Dataset
Set: train
TFs: ['Sox2']
Size: 4182
# Shape report for the qValue-filtered Sox2 set.
highconf_dataset.check_shapes()
self.tf_list=['Sox2']
self.one_hot_seqs.shape=(4182, 4, 1000) [idx, bases, pwidth]
self.tf_counts.shape=(4182, 1, 2, 1000) [idx, TF, strand, pwidth]
self.ctrl_counts.shape=(4182, 2, 1000) [idx, strand, pwidth]
self.ctrl_counts_smooth.shape=(4182, 2, 1000) [idx, strand, pwidth]

2.4 Example 4: Create Train Dataset for Sox2 but keep all Regions

We might also want to create a training set that contains all regions but only the counts for Sox2. This is done by passing subset=False.

# subset=False keeps every training region (same size as the all-TF set)
# while tf_counts holds only the Sox2 track.
sox2_all_regions = ChIP_Nexus_Dataset(set_name="train", 
                                      input_dir=INPUT_DIR, 
                                      TF_list=["Sox2"], 
                                      subset=False)
sox2_all_regions
ChIP_Nexus_Dataset
Set: train
TFs: ['Sox2']
Size: 93904
# Shape report: full region count, but a single TF along the tf_counts TF axis.
sox2_all_regions.check_shapes()
self.tf_list=['Sox2']
self.one_hot_seqs.shape=(93904, 4, 1000) [idx, bases, pwidth]
self.tf_counts.shape=(93904, 1, 2, 1000) [idx, TF, strand, pwidth]
self.ctrl_counts.shape=(93904, 2, 1000) [idx, strand, pwidth]
self.ctrl_counts_smooth.shape=(93904, 2, 1000) [idx, strand, pwidth]

3 Architecture API

3.1 Example 1: One TF, Shape Prediction, No Bias Track

# One TF, profile-shape head only (pred_total=False, so no count heads), no bias track.
model_1 = BPNet(n_dil_layers=9, TF_list=["Sox2"], pred_total=False, bias_track=False)
model_1
BPNet(
  (base_model): ConvLayers(
    (conv_layers): ModuleList(
      (0): Conv1d(4, 64, kernel_size=(25,), stride=(1,), padding=same)
      (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(8,))
      (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(16,))
      (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(32,))
      (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(64,))
      (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(128,))
      (8): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(256,))
      (9): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(512,))
    )
  )
  (profile_heads): ModuleList(
    (0): ProfileShapeHead(
      (deconv): ConvTranspose1d(64, 2, kernel_size=(25,), stride=(1,), padding=(12,))
    )
  )
)

3.2 Example 2: One TF, Shape & Total Counts Prediction, No Bias Track

# pred_total=True adds a TotalCountHead next to the profile head, as the repr shows.
model_2 = BPNet(n_dil_layers=9, TF_list=["Sox2"], pred_total=True, bias_track=False)
model_2
BPNet(
  (base_model): ConvLayers(
    (conv_layers): ModuleList(
      (0): Conv1d(4, 64, kernel_size=(25,), stride=(1,), padding=same)
      (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(8,))
      (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(16,))
      (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(32,))
      (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(64,))
      (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(128,))
      (8): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(256,))
      (9): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(512,))
    )
  )
  (profile_heads): ModuleList(
    (0): ProfileShapeHead(
      (deconv): ConvTranspose1d(64, 2, kernel_size=(25,), stride=(1,), padding=(12,))
    )
  )
  (count_heads): ModuleList(
    (0): TotalCountHead(
      (fc1): Linear(in_features=64, out_features=32, bias=True)
      (fc2): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)

3.3 Example 3: One TF, Shape & Total Counts Prediction, Bias Track

# bias_track=True: the printed module tree is identical to model_2; the bias
# weights live as a parameter on the profile head (inspected in the next cell).
model_3 = BPNet(n_dil_layers=9, TF_list=["Sox2"], pred_total=True, bias_track=True)
model_3
BPNet(
  (base_model): ConvLayers(
    (conv_layers): ModuleList(
      (0): Conv1d(4, 64, kernel_size=(25,), stride=(1,), padding=same)
      (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(8,))
      (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(16,))
      (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(32,))
      (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(64,))
      (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(128,))
      (8): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(256,))
      (9): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(512,))
    )
  )
  (profile_heads): ModuleList(
    (0): ProfileShapeHead(
      (deconv): ConvTranspose1d(64, 2, kernel_size=(25,), stride=(1,), padding=(12,))
    )
  )
  (count_heads): ModuleList(
    (0): TotalCountHead(
      (fc1): Linear(in_features=64, out_features=32, bias=True)
      (fc2): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)

With bias_track=True, the profile head additionally has learnable bias weights:

# Learnable bias weights of the first (Sox2) profile head: two values,
# presumably one per strand — TODO confirm in ProfileShapeHead.
model_3.profile_heads[0].bias_weights
Parameter containing:
tensor([0.0100, 0.0100], requires_grad=True)

3.4 Example 4: All TFs, Shape & Total Counts Prediction, Bias Track

# All four TFs: one profile head and one count head per TF. The TF order
# matches the all-TF training set (Sox2, Oct4, Klf4, Nanog). Presumably the
# i-th head is paired with tf_counts[:, i] during training, so keeping both
# orders aligned avoids silently mismatched heads — TODO confirm in the
# training code.
model_4 = BPNet(n_dil_layers=9, TF_list=["Sox2", "Oct4", "Klf4", "Nanog"], pred_total=True, bias_track=True)
model_4
BPNet(
  (base_model): ConvLayers(
    (conv_layers): ModuleList(
      (0): Conv1d(4, 64, kernel_size=(25,), stride=(1,), padding=same)
      (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(8,))
      (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(16,))
      (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(32,))
      (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(64,))
      (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(128,))
      (8): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(256,))
      (9): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=same, dilation=(512,))
    )
  )
  (profile_heads): ModuleList(
    (0): ProfileShapeHead(
      (deconv): ConvTranspose1d(64, 2, kernel_size=(25,), stride=(1,), padding=(12,))
    )
    (1): ProfileShapeHead(
      (deconv): ConvTranspose1d(64, 2, kernel_size=(25,), stride=(1,), padding=(12,))
    )
    (2): ProfileShapeHead(
      (deconv): ConvTranspose1d(64, 2, kernel_size=(25,), stride=(1,), padding=(12,))
    )
    (3): ProfileShapeHead(
      (deconv): ConvTranspose1d(64, 2, kernel_size=(25,), stride=(1,), padding=(12,))
    )
  )
  (count_heads): ModuleList(
    (0): TotalCountHead(
      (fc1): Linear(in_features=64, out_features=32, bias=True)
      (fc2): Linear(in_features=32, out_features=2, bias=True)
    )
    (1): TotalCountHead(
      (fc1): Linear(in_features=64, out_features=32, bias=True)
      (fc2): Linear(in_features=32, out_features=2, bias=True)
    )
    (2): TotalCountHead(
      (fc1): Linear(in_features=64, out_features=32, bias=True)
      (fc2): Linear(in_features=32, out_features=2, bias=True)
    )
    (3): TotalCountHead(
      (fc1): Linear(in_features=64, out_features=32, bias=True)
      (fc2): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)

4 Appendix

4.1 Recreate Figure 1e

# Test set for all four TFs. This TF order also determines the panel order
# in the Figure 1e plot below.
test_dataset = ChIP_Nexus_Dataset(set_name="test", 
                                  input_dir=INPUT_DIR, 
                                  TF_list=['Oct4', 'Sox2', 'Nanog', 'Klf4'])
test_dataset
ChIP_Nexus_Dataset
Set: test
TFs: ['Oct4', 'Sox2', 'Nanog', 'Klf4']
Size: 27727
# Peak shown in Figure 1e; the plotted window covers exactly these coordinates.
peak_chrom, peak_start, peak_end = "chr1", 180924752, 180925152
flank = 1000  # search slack around the peak when matching test regions

# After reset_index() the index labels are positions 0..n-1, so `idx` serves
# both as a row label here and as a position into the dataset arrays
# (assumes region_info rows are in the same order as those arrays).
tmp_df = test_dataset.region_info.copy().reset_index()
matches = tmp_df.loc[(tmp_df.seqnames == peak_chrom)
                     & (tmp_df.start > peak_start - flank)
                     & (tmp_df.end < peak_end + flank)].index.to_numpy()
if matches.size == 0:
    raise ValueError(f"No test region found around {peak_chrom}:{peak_start}-{peak_end}")
idx = matches[0]

# Offset of the peak start inside the region window. The +1 presumably
# converts a 1-based region start to a 0-based array offset — the
# reference-genome check in section 4.2 confirms the alignment.
diff = peak_start - tmp_df.start[idx] + 1
w = peak_end - peak_start  # 400 bp window
n_positions = test_dataset.tf_counts.shape[-1]
if diff < 0 or diff + w > n_positions:
    raise ValueError("Peak window does not fit inside the selected region")

# One panel per TF (rather than a hard-coded 4). squeeze=False keeps the
# axes array 2-D even for a single TF, and `axis` stays a 1-D array of Axes.
n_tfs = len(test_dataset.tf_list)
fig, axes_grid = plt.subplots(n_tfs, 1, figsize=(6, 3.5 * n_tfs), squeeze=False)
axis = axes_grid[:, 0]

# tf_counts is [idx, TF, strand, pwidth]. The positive strand is drawn above
# zero and the negative strand is mirrored below it.
for ax, (i, tf) in zip(axis, enumerate(test_dataset.tf_list)):
    ax.plot(test_dataset.tf_counts[idx, i, 0, diff:(diff+w)], label="pos")
    ax.plot(-test_dataset.tf_counts[idx, i, 1, diff:(diff+w)], label="neg")
    ax.legend()
    ax.set_title(tf)
plt.show()

4.2 Check One-Hot Encoding

To check whether the one-hot encoding worked as expected, we compare two encodings of the same sequence:

  1. The one-hot encoded sequence as stored in the test dataset
# Heatmap of the dataset's one-hot window for the Figure 1e peak
# (rows A, C, G, T; columns are positions within the w-bp window).
dataset_window = test_dataset.one_hot_seqs[idx, :, diff:diff + w]
plt.imshow(dataset_window, aspect="auto", interpolation="none")
plt.title("One-Hot Encoding from Test Dataset")
plt.yticks(ticks=[0, 1, 2, 3], labels=["A", "C", "G", "T"])
plt.show()

  2. The one-hot encoded sequence obtained by reading the mm10 reference genome and one-hot encoding the corresponding sequence
from Bio.Seq import Seq
from Bio import SeqIO

# Index the mm10 FASTA lazily instead of parsing the whole genome into memory:
# SeqIO.index records file offsets and only reads a chromosome when accessed.
mm10_ref = SeqIO.index("../ref/mm10.fa", "fasta")

# Same window as the dataset plot: w bases starting at the Figure 1e peak
# start (0-based Python slicing on the chromosome record).
peak_start = 180924752
seq = mm10_ref[tmp_df.iloc[idx]["seqnames"]][peak_start:peak_start + w]

# One-hot encode with rows A, C, G, T by comparing every base against each
# letter at once.
# NOTE(review): only uppercase bases match. If the FASTA is soft-masked,
# lowercase bases become all-zero columns — confirm this matches how the
# dataset's one_hot_seqs were built.
bases = np.array(list(str(seq.seq)))
one_hot_seq = (bases[np.newaxis, :] == np.array(list("ACGT"))[:, np.newaxis]).astype(float)

plt.imshow(one_hot_seq, interpolation="none", aspect="auto")
plt.yticks([0, 1, 2, 3], labels=["A", "C", "G", "T"])
plt.title("One-Hot Encoding based on Reference Sequence")
plt.show()

Both encodings cover the same 400 bp window of the peak shown in Figure 1e. Comparing them:

# Compare the dataset's one-hot window with the encoding rebuilt from mm10.
# np.array_equal also requires identical shapes, so a window-length mismatch
# yields False instead of broadcasting or raising.
np.array_equal(test_dataset.one_hot_seqs[idx, :, diff:(diff+w)], one_hot_seq)
True

We see that both encodings are exactly the same.