"""Ten Species Dataset.

Load dataset from HF; tokenize 'on-the-fly'
"""

import random

import datasets
import torch
import transformers

STRING_COMPLEMENT_MAP = {
  "A": "T", "C": "G", "G": "C", "T": "A",
  "a": "t", "c": "g", "g": "c", "t": "a",
  "N": "N", "n": "n",
}


def coin_flip(p=0.5):
    """Flip a (potentially weighted) coin.

    Returns True with probability 1 - p; the default p=0.5 gives a fair flip.
    """
    return random.random() > p


def string_reverse_complement(seq):
    """Return the reverse complement of a DNA sequence."""
    rev_comp = ""
    for base in seq[::-1]:
        if base in STRING_COMPLEMENT_MAP:
            rev_comp += STRING_COMPLEMENT_MAP[base]
        else:
            # Base is not in the complement map; keep it unchanged.
            rev_comp += base
    return rev_comp


class TenSpeciesDataset(torch.utils.data.Dataset):
  """Ten Species Dataset.

  Tokenization happens on the fly.
  """
  def __init__(
      self,
      split: str,
      tokenizer: transformers.PreTrainedTokenizer,
      max_length: int = 1024,
      rc_aug: bool = False,
      add_special_tokens: bool = False,
      dataset=None):
    if dataset is None:
      dataset = datasets.load_dataset(
        'yairschiff/ten_species',
        split='train',  # original dataset only has `train` split
        chunk_length=max_length,
        overlap=0,
        trust_remote_code=True)
      self.dataset = dataset.train_test_split(
        test_size=0.05, seed=42)[split]  # hard-coded seed & size
    else:
      self.dataset = dataset
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.rc_aug = rc_aug
    self.add_special_tokens = add_special_tokens

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    """Returns a sequence and species label."""
    seq = self.dataset[idx]['sequence']
    if self.rc_aug and coin_flip():
      seq = string_reverse_complement(seq)
    seq = self.tokenizer(
      seq,
      max_length=self.max_length,
      padding="max_length",
      truncation=True,
      add_special_tokens=self.add_special_tokens,
      return_attention_mask=True)

    input_ids = seq['input_ids']
    attention_mask = seq['attention_mask']
    input_ids = torch.LongTensor(input_ids)
    attention_mask = torch.LongTensor(attention_mask)

    return {
      'input_ids': input_ids,
      'attention_mask': attention_mask,
      'species_label': torch.LongTensor([
        self.dataset[idx]['species_label']]).squeeze(),
    }
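

# Minimal usage sketch (illustrative, not part of the original module).
# It assumes a character-level DNA tokenizer; the HyenaDNA checkpoint name
# below is an assumption and can be swapped for whichever genomic tokenizer
# the project actually uses. Note that constructing the dataset triggers a
# download of the ten_species data from the Hugging Face Hub.
if __name__ == '__main__':
  tokenizer = transformers.AutoTokenizer.from_pretrained(
    'LongSafari/hyenadna-tiny-1k-seqlen-hf',  # assumed tokenizer checkpoint
    trust_remote_code=True)
  train_ds = TenSpeciesDataset(
    split='train',
    tokenizer=tokenizer,
    max_length=1024,
    rc_aug=True)
  example = train_ds[0]
  print(example['input_ids'].shape, example['species_label'])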