saketh11 committed on
Commit 6e9b5dc · 1 Parent(s): 2d634e1

Add local CodonTransformer modules for custom ColiFormer functionality


- Removed CodonTransformer PyPI package dependency
- Added local CodonTransformer/ directory with custom modifications
- Includes the enhanced ColiFormer-specific functionality
- App now uses the custom CodonTransformer implementation instead of the standard package
- Fixes ModuleNotFoundError: No module named 'CodonTransformer'

CodonTransformer/CodonData.py ADDED
@@ -0,0 +1,682 @@
"""
File: CodonData.py
---------------------
Includes helper functions for preprocessing NCBI or Kazusa databases and
preparing the data for training and inference of the CodonTransformer model.
"""

import json
import os
import random
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
import python_codon_tables as pct
from Bio import SeqIO
from Bio.Seq import Seq
from sklearn.utils import shuffle as sk_shuffle
from tqdm import tqdm

from CodonTransformer.CodonUtils import (
    AMBIGUOUS_AMINOACID_MAP,
    AMINO2CODON_TYPE,
    AMINO_ACIDS,
    ORGANISM2ID,
    START_CODONS,
    STOP_CODONS,
    STOP_SYMBOL,
    STOP_SYMBOLS,
    ProteinConfig,
    find_pattern_in_fasta,
    get_taxonomy_id,
    sort_amino2codon_skeleton,
)


def prepare_training_data(
    dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True
) -> None:
    """
    Prepare a JSON dataset for training the CodonTransformer model.

    The input dataset must have the following columns:
        - dna: str (DNA sequence)
        - protein: str (Protein sequence)
        - organism: Union[int, str] (ID or Name of the organism)

    The output JSON dataset will have the following format:
        {"idx": 0, "codons": "M_ATG R_AGG L_TTG L_CTA R_CGA __TAG", "organism": 51}
        {"idx": 1, "codons": "M_ATG K_AAG C_TGC F_TTT F_TTC __TAA", "organism": 59}

    Args:
        dataset (Union[str, pd.DataFrame]): Input dataset in CSV or DataFrame format.
        output_file (str): Path to save the output JSON dataset.
        shuffle (bool, optional): Whether to shuffle the dataset before saving.
            Defaults to True.

    Returns:
        None
    """
    if isinstance(dataset, str):
        dataset = pd.read_csv(dataset)

    required_columns = {"dna", "protein", "organism"}
    if not required_columns.issubset(dataset.columns):
        raise ValueError(f"Input dataset must have columns: {required_columns}")

    # Prepare the dataset for finetuning
    dataset["codons"] = dataset.apply(
        lambda row: get_merged_seq(row["protein"], row["dna"], separator="_"), axis=1
    )

    # Replace organism names with organism IDs using ORGANISM2ID
    dataset["organism"] = dataset["organism"].apply(
        lambda org: process_organism(org, ORGANISM2ID)
    )

    # Save the dataset to a JSON file
    dataframe_to_json(dataset[["codons", "organism"]], output_file, shuffle=shuffle)
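
# Usage sketch: build a tiny in-memory dataset and write it out as JSON Lines.
# Assumes the organism names below are keys of ORGANISM2ID; the output path is
# hypothetical.
#
#     toy = pd.DataFrame(
#         {
#             "dna": ["ATGGCTGTGTAA", "ATGAAGTAA"],
#             "protein": ["MAV_", "MK_"],
#             "organism": ["Escherichia coli general", "Homo sapiens"],
#         }
#     )
#     prepare_training_data(toy, "train_data.json", shuffle=False)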


def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None:
    """
    Convert a pandas DataFrame to the JSON file format used for training CodonTransformer.

    This function takes a preprocessed DataFrame and writes it to a JSON file
    where each line is a JSON object representing a single record.

    Args:
        df (pd.DataFrame): The input DataFrame with 'codons' and 'organism' columns.
        output_file (str): Path to the output JSON file.
        shuffle (bool, optional): Whether to shuffle the dataset before saving.
            Defaults to True.

    Returns:
        None

    Raises:
        ValueError: If the required columns are not present in the DataFrame.
    """
    required_columns = {"codons", "organism"}
    if not required_columns.issubset(df.columns):
        raise ValueError(f"DataFrame must contain columns: {required_columns}")

    print(f"\nStarted writing to {output_file}...")

    # Shuffle the DataFrame if requested
    if shuffle:
        df = sk_shuffle(df)

    # Write the DataFrame to a JSON file
    with open(output_file, "w") as f:
        for idx, row in tqdm(
            df.iterrows(), total=len(df), desc="Writing JSON...", unit=" records"
        ):
            doc = {"idx": idx, "codons": row["codons"], "organism": row["organism"]}
            f.write(json.dumps(doc) + "\n")

    print(f"\nTotal Entries Saved: {len(df)}, JSON data saved to {output_file}")


def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) -> int:
    """
    Process and validate the organism input, converting it to a valid organism ID.

    This function handles both string (organism name) and integer (organism ID) inputs.
    It validates the input against a provided mapping of organism names to IDs.

    Args:
        organism (Union[str, int]): Input organism, either as a name (str) or ID (int).
        organism_to_id (Dict[str, int]): Dictionary mapping organism names to their
            corresponding IDs.

    Returns:
        int: The validated organism ID.

    Raises:
        ValueError: If the input is an invalid organism name or ID.
        TypeError: If the input is neither a string nor an integer.
    """
    if isinstance(organism, str):
        if organism not in organism_to_id:
            raise ValueError(f"Invalid organism name: {organism}")
        return organism_to_id[organism]

    elif isinstance(organism, int):
        if organism not in organism_to_id.values():
            raise ValueError(f"Invalid organism ID: {organism}")
        return organism

    raise TypeError(
        f"Organism must be a string or integer, not {type(organism).__name__}"
    )


def preprocess_protein_sequence(protein: str) -> str:
    """
    Preprocess a protein sequence by cleaning, standardizing, and handling
    ambiguous amino acids.

    Args:
        protein (str): The input protein sequence.

    Returns:
        str: The preprocessed protein sequence.

    Raises:
        ValueError: If the protein sequence is invalid or if the configuration is invalid.
    """
    if not protein:
        raise ValueError("Protein sequence is empty.")

    # Clean and standardize the protein sequence
    protein = (
        protein.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "")
    )

    # Handle ambiguous amino acids based on the specified behavior
    config = ProteinConfig()
    ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override")
    ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior")
    ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy()

    for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items():
        ambiguous_aminoacid_map[aminoacid] = standard_aminoacids

    if ambiguous_aminoacid_behavior == "raise_error":
        if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein):
            raise ValueError("Ambiguous amino acids found in protein sequence.")
    elif ambiguous_aminoacid_behavior == "standardize_deterministic":
        protein = "".join(
            ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0]
            for aminoacid in protein
        )
    elif ambiguous_aminoacid_behavior == "standardize_random":
        protein = "".join(
            random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid]))
            for aminoacid in protein
        )
    else:
        raise ValueError(
            f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}."
        )

    # Check for sequence validity
    if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein):
        raise ValueError("Invalid characters in protein sequence.")

    if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS:
        raise ValueError(
            "Protein sequence must end with `*`, or `_`, or an amino acid."
        )

    # Replace '*' at the end of the protein with STOP_SYMBOL if present
    if protein[-1] == "*":
        protein = protein[:-1] + STOP_SYMBOL

    # Add a stop symbol to the end of the protein if missing
    if protein[-1] != STOP_SYMBOL:
        protein += STOP_SYMBOL

    return protein


def replace_ambiguous_codons(dna: str) -> str:
    """
    Replace ambiguous codons in a DNA sequence with "UNK".

    Args:
        dna (str): The DNA sequence to process.

    Returns:
        str: The processed DNA sequence with ambiguous codons replaced by "UNK".
    """
    result = []
    dna = dna.upper()

    # Check codons in the DNA sequence
    for i in range(0, len(dna), 3):
        codon = dna[i : i + 3]

        if len(codon) == 3 and all(nucleotide in "ATCG" for nucleotide in codon):
            result.append(codon)
        else:
            result.append("UNK")

    return "".join(result)
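
# Usage sketch: any codon containing a non-ACGT symbol collapses to "UNK".
#
#     replace_ambiguous_codons("ATGGNTTAA")  # -> "ATGUNKTAA"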


def preprocess_dna_sequence(dna: str) -> str:
    """
    Clean and preprocess a DNA sequence by standardizing it and replacing
    ambiguous codons.

    Args:
        dna (str): The DNA sequence to preprocess.

    Returns:
        str: The cleaned and preprocessed DNA sequence.
    """
    if not dna:
        return ""

    # Clean and standardize the DNA sequence
    dna = dna.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "")

    # Replace codons with ambiguous nucleotides with "UNK"
    dna = replace_ambiguous_codons(dna)

    # Append an unknown stop codon ("UNK") to the DNA sequence if no stop codon is present
    if dna[-3:] not in STOP_CODONS:
        dna += "UNK"

    return dna


def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str:
    """
    Return the merged sequence of protein amino acids and DNA codons in the form
    of tokens separated by space, where each token is composed of an amino acid +
    separator + codon.

    Args:
        protein (str): Protein sequence.
        dna (str): DNA sequence.
        separator (str): Separator between amino acid and codon.

    Returns:
        str: Merged sequence.

    Example:
        >>> get_merged_seq(protein="MAV_", dna="ATGGCTGTGTAA", separator="_")
        'M_ATG A_GCT V_GTG __TAA'

        >>> get_merged_seq(protein="QHH_", dna="", separator="_")
        'Q_UNK H_UNK H_UNK __UNK'
    """
    merged_seq = ""

    # Prepare protein and dna sequences
    dna = preprocess_dna_sequence(dna)
    protein = preprocess_protein_sequence(protein)

    # Check that the protein and dna sequences have matching lengths
    if len(dna) > 0 and len(protein) != len(dna) / 3:
        raise ValueError(
            'Length of protein (including stop symbol such as "_") and '
            "the number of codons in DNA sequence (including stop codon) "
            "must be equal."
        )

    # Merge protein and DNA sequences into tokens
    for i, aminoacid in enumerate(protein):
        merged_seq += f'{aminoacid}{separator}{dna[i * 3:i * 3 + 3] if dna else "UNK"} '

    return merged_seq.strip()


def is_correct_seq(dna: str, protein: str, stop_symbol: str = STOP_SYMBOL) -> bool:
    """
    Check if the given DNA and protein pair is correct, that is:
        1. The length of dna is divisible by 3
        2. There is an initiator codon at the beginning of dna
        3. There is only one stop codon in the sequence
        4. The only stop codon is the last codon

    Note: since in Codon Table 3 'TGA' is translated as Tryptophan (W), a
    separate check is needed to make sure those sequences are considered correct.

    Args:
        dna (str): DNA sequence.
        protein (str): Protein sequence.
        stop_symbol (str): Stop symbol.

    Returns:
        bool: True if the sequence is correct, False otherwise.
    """
    return (
        len(dna) % 3 == 0  # Check if DNA length is divisible by 3
        and dna[:3].upper() in START_CODONS  # Check for initiator codon
        and protein[-1] == stop_symbol  # Check if the last protein symbol is the stop symbol
        and protein.count(stop_symbol) == 1  # Check if there is only one stop symbol
        and len(set(dna)) == 4  # Check if DNA consists of 4 unique nucleotides (A, T, C, G)
    )


def get_amino_acid_sequence(
    dna: str,
    stop_symbol: str = "_",
    codon_table: int = 1,
    return_correct_seq: bool = False,
) -> Union[str, Tuple[str, bool]]:
    """
    Return the translated protein sequence given a DNA sequence and codon table.

    Args:
        dna (str): DNA sequence.
        stop_symbol (str): Stop symbol.
        codon_table (int): Codon table number.
        return_correct_seq (bool): Whether to also return whether the sequence is correct.

    Returns:
        Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if
            return_correct_seq is True, otherwise just the protein sequence.
    """
    dna_seq = Seq(dna).strip()

    # Translate the DNA sequence to a protein sequence
    protein_seq = str(
        dna_seq.translate(
            stop_symbol=stop_symbol,  # Symbol to use for stop codons
            to_stop=False,  # Translate the entire sequence, including any stop codons
            cds=False,  # Do not assume the input is a coding sequence
            table=codon_table,  # Codon table to use for translation
        )
    ).strip()

    return (
        protein_seq
        if not return_correct_seq
        else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol))
    )


def read_fasta_file(
    input_file: str,
    save_to_file: Optional[str] = None,
    organism: str = "",
    buffer_size: int = 50000,
) -> pd.DataFrame:
    """
    Read a FASTA file of DNA sequences and convert it to a pandas DataFrame.
    Optionally, save the DataFrame to a CSV file.

    Args:
        input_file (str): Path to the input FASTA file.
        save_to_file (Optional[str]): Path to save the output DataFrame. If None,
            data is only returned.
        organism (str): Name of the organism. If empty, it will be extracted from
            the FASTA description.
        buffer_size (int): Number of records to process before writing to file.

    Returns:
        pd.DataFrame: DataFrame containing the processed DNA sequences.

    Raises:
        FileNotFoundError: If the input file does not exist.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    buffer = []
    columns = [
        "dna",
        "protein",
        "correct_seq",
        "organism",
        "GeneID",
        "description",
        "tokenized",
    ]

    # Initialize DataFrame to accumulate all processed records
    all_data = pd.DataFrame(columns=columns)

    with open(input_file, "r") as fasta_file:
        for record in tqdm(
            SeqIO.parse(fasta_file, "fasta"),
            desc=f"Processing {organism}",
            unit=" Records",
        ):
            dna = str(record.seq).strip().upper()  # Ensure uppercase DNA sequence

            # Determine the organism from the record if not provided
            current_organism = organism or find_pattern_in_fasta(
                "organism", record.description
            )
            gene_id = find_pattern_in_fasta("GeneID", record.description)

            # Get the appropriate codon table for the organism
            codon_table = get_codon_table(current_organism)

            # Translate DNA to protein sequence
            protein, correct_seq = get_amino_acid_sequence(
                dna,
                stop_symbol=STOP_SYMBOL,
                codon_table=codon_table,
                return_correct_seq=True,
            )
            description = record.description.split("[", 1)[0].strip()
            tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL)

            # Create a data row for the current sequence
            data_row = {
                "dna": dna,
                "protein": protein,
                "correct_seq": correct_seq,
                "organism": current_organism,
                "GeneID": gene_id,
                "description": description,
                "tokenized": tokenized,
            }
            buffer.append(data_row)

            # Write buffer to CSV file when buffer size is reached
            if save_to_file and len(buffer) >= buffer_size:
                write_buffer_to_csv(buffer, save_to_file, columns)
                buffer = []

            all_data = pd.concat(
                [all_data, pd.DataFrame([data_row])], ignore_index=True
            )

    # Write the remaining buffer to the CSV file
    if save_to_file and buffer:
        write_buffer_to_csv(buffer, save_to_file, columns)

    return all_data


def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]):
    """Helper function to append a buffer of records to a CSV file."""
    buffer_df = pd.DataFrame(buffer, columns=columns)
    buffer_df.to_csv(
        output_path,
        mode="a",
        header=(not os.path.exists(output_path)),
        index=True,
    )


def download_codon_frequencies_from_kazusa(
    taxonomy_id: Optional[int] = None,
    organism: Optional[str] = None,
    taxonomy_reference: Optional[str] = None,
    return_original_format: bool = False,
) -> AMINO2CODON_TYPE:
    """
    Return the codon table of the given taxonomy ID from the Kazusa Database.

    Args:
        taxonomy_id (Optional[int]): Taxonomy ID.
        organism (Optional[str]): Name of the organism.
        taxonomy_reference (Optional[str]): Taxonomy reference used to look up
            the taxonomy ID via get_taxonomy_id.
        return_original_format (bool): Whether to return the table in its
            original Kazusa format.

    Returns:
        AMINO2CODON_TYPE: Codon table.
    """
    if taxonomy_reference:
        taxonomy_id = get_taxonomy_id(taxonomy_reference, organism=organism)

    kazusa_amino2codon = pct.get_codons_table(table_name=taxonomy_id)

    if return_original_format:
        return kazusa_amino2codon

    # Replace "*" with STOP_SYMBOL in the codon table
    kazusa_amino2codon[STOP_SYMBOL] = kazusa_amino2codon.pop("*")

    # Create the amino2codon dictionary
    amino2codon = {
        aminoacid: (list(codon2freq.keys()), list(codon2freq.values()))
        for aminoacid, codon2freq in kazusa_amino2codon.items()
    }

    return sort_amino2codon_skeleton(amino2codon)


def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE:
    """
    Return the empty skeleton of the amino2codon dictionary, needed for
    get_codon_frequencies.

    Args:
        organism (str): Name of the organism.

    Returns:
        AMINO2CODON_TYPE: Empty amino2codon dictionary.
    """
    amino2codon = {}
    possible_codons = [f"{i}{j}{k}" for i in "ACGT" for j in "ACGT" for k in "ACGT"]
    possible_aminoacids = get_amino_acid_sequence(
        dna="".join(possible_codons),
        codon_table=get_codon_table(organism),
        return_correct_seq=False,
    )

    # Initialize the amino2codon skeleton with all possible codons and set their
    # frequencies to 0
    for codon, amino in zip(possible_codons, possible_aminoacids):
        if amino not in amino2codon:
            amino2codon[amino] = ([], [])

        amino2codon[amino][0].append(codon)
        amino2codon[amino][1].append(0)

    # Sort the dictionary and each list of codon frequencies alphabetically
    amino2codon = sort_amino2codon_skeleton(amino2codon)

    return amino2codon


def get_codon_frequencies(
    dna_sequences: List[str],
    protein_sequences: Optional[List[str]] = None,
    organism: Optional[str] = None,
) -> AMINO2CODON_TYPE:
    """
    Return a dictionary mapping each codon to its respective frequency based on
    the collection of DNA sequences and protein sequences.

    Args:
        dna_sequences (List[str]): List of DNA sequences.
        protein_sequences (Optional[List[str]]): List of protein sequences.
        organism (Optional[str]): Name of the organism.

    Returns:
        AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons
            and frequencies.
    """
    if organism:
        codon_table = get_codon_table(organism)
        protein_sequences = [
            get_amino_acid_sequence(
                dna, codon_table=codon_table, return_correct_seq=False
            )
            for dna in dna_sequences
        ]

    amino2codon = build_amino2codon_skeleton(organism)

    # Count the frequency of each codon for each amino acid
    for dna, protein in zip(dna_sequences, protein_sequences):
        for i, amino in enumerate(protein):
            codon = dna[i * 3 : (i + 1) * 3]
            codon_loc = amino2codon[amino][0].index(codon)
            amino2codon[amino][1][codon_loc] += 1

    # Normalize codon frequencies per amino acid so they sum to 1
    amino2codon = {
        amino: (codons, [freq / (sum(frequencies) + 1e-100) for freq in frequencies])
        for amino, (codons, frequencies) in amino2codon.items()
    }

    return amino2codon
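
# Usage sketch: derive a per-amino-acid usage table from two toy coding
# sequences. Passing an organism name lets the function translate the DNA
# itself; "Escherichia coli general" is an assumed valid organism name.
#
#     freqs = get_codon_frequencies(
#         dna_sequences=["ATGGCTGCTTAA", "ATGGCCTAA"],
#         organism="Escherichia coli general",
#     )
#     freqs["A"]  # -> (['GCA', 'GCC', 'GCG', 'GCT'], [0.0, 0.33..., 0.0, 0.66...])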


def get_organism_to_codon_frequencies(
    dataset: pd.DataFrame, organisms: List[str]
) -> Dict[str, AMINO2CODON_TYPE]:
    """
    Return a dictionary mapping each organism to its codon frequency distribution.

    Args:
        dataset (pd.DataFrame): DataFrame containing DNA sequences.
        organisms (List[str]): List of organisms.

    Returns:
        Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon
            frequency distribution.
    """
    organism2frequencies = {}

    # Calculate codon frequencies for each organism in the dataset
    for organism in tqdm(
        organisms, desc="Calculating Codon Frequencies: ", unit="Organism"
    ):
        organism_data = dataset.loc[dataset["organism"] == organism]

        dna_sequences = organism_data["dna"].to_list()
        protein_sequences = organism_data["protein"].to_list()

        codon_frequencies = get_codon_frequencies(dna_sequences, protein_sequences)
        organism2frequencies[organism] = codon_frequencies

    return organism2frequencies


def get_codon_table(organism: str) -> int:
    """
    Return the appropriate NCBI codon table for a given organism.

    Args:
        organism (str): Name of the organism.

    Returns:
        int: Codon table number.
    """
    # Common codon table (Table 1) for many model organisms
    if organism in [
        "Arabidopsis thaliana",
        "Caenorhabditis elegans",
        "Chlamydomonas reinhardtii",
        "Saccharomyces cerevisiae",
        "Danio rerio",
        "Drosophila melanogaster",
        "Homo sapiens",
        "Mus musculus",
        "Nicotiana tabacum",
        "Solanum tuberosum",
        "Solanum lycopersicum",
        "Oryza sativa",
        "Glycine max",
        "Zea mays",
    ]:
        codon_table = 1

    # Chloroplast codon table (Table 11)
    elif organism in [
        "Chlamydomonas reinhardtii chloroplast",
        "Nicotiana tabacum chloroplast",
    ]:
        codon_table = 11

    # Default to Table 11 for other bacteria and archaea
    else:
        codon_table = 11

    return codon_table
CodonTransformer/CodonEvaluation.py ADDED
@@ -0,0 +1,575 @@
"""
File: CodonEvaluation.py
---------------------------
Includes functions to calculate various evaluation metrics along with helper
functions.
"""

from typing import Dict, List, Tuple, Optional

import pandas as pd
from CAI import CAI, relative_adaptiveness
from tqdm import tqdm
import math
import numpy as np
from collections import Counter
from itertools import chain
from statistics import mean
import sys
import os
from io import StringIO


def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
    """
    Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences.

    Args:
        sequences (List[str]): List of DNA sequences.

    Returns:
        dict: The CSI weights.
    """
    return relative_adaptiveness(sequences=sequences)


def get_CSI_value(dna: str, weights: Dict[str, float]) -> float:
    """
    Calculate the Codon Similarity Index (CSI) for a DNA sequence.

    Args:
        dna (str): The DNA sequence.
        weights (dict): The CSI weights from get_CSI_weights.

    Returns:
        float: The CSI value.
    """
    return CAI(dna, weights)
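
# Usage sketch: derive CSI weights from a small reference set, then score a
# query CDS. The two reference sequences are toy stand-ins for a genome-scale
# collection of host coding sequences.
#
#     reference = ["ATGGCTGCTGAAGCTTAA", "ATGGAAGCTGCTTAA"]
#     weights = get_CSI_weights(reference)
#     get_CSI_value("ATGGCTGAATAA", weights)  # -> float in (0, 1]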


def get_organism_to_CSI_weights(
    dataset: pd.DataFrame, organisms: List[str]
) -> Dict[str, dict]:
    """
    Calculate the Codon Similarity Index (CSI) weights for a list of organisms.

    Args:
        dataset (pd.DataFrame): Dataset containing organism and DNA sequence info.
        organisms (List[str]): List of organism names.

    Returns:
        Dict[str, dict]: A dictionary mapping each organism to its CSI weights.
    """
    organism2weights = {}

    # Iterate through each organism to calculate its CSI weights
    for organism in tqdm(organisms, desc="Calculating CSI Weights: ", unit="Organism"):
        organism_data = dataset.loc[dataset["organism"] == organism]
        sequences = organism_data["dna"].to_list()
        weights = get_CSI_weights(sequences)
        organism2weights[organism] = weights

    return organism2weights


def get_GC_content(dna: str) -> float:
    """
    Calculate the GC content of a DNA sequence.

    Args:
        dna (str): The DNA sequence.

    Returns:
        float: The GC content as a percentage.
    """
    dna = dna.upper()
    if not dna:
        return 0.0
    return (dna.count("G") + dna.count("C")) / len(dna) * 100


def get_cfd(
    dna: str,
    codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
    threshold: float = 0.3,
) -> float:
    """
    Calculate the codon frequency distribution (CFD) metric for a DNA sequence,
    i.e. the percentage of codons whose within-amino-acid relative frequency
    falls below the given threshold.

    Args:
        dna (str): The DNA sequence.
        codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
            frequency distribution per amino acid.
        threshold (float): Frequency threshold for counting rare codons.

    Returns:
        float: The CFD metric as a percentage.
    """
    # Get a dictionary mapping each codon to its normalized frequency
    codon2frequency = {
        codon: freq / max(frequencies)
        for amino, (codons, frequencies) in codon_frequencies.items()
        for codon, freq in zip(codons, frequencies)
    }

    cfd = 0

    # Iterate through the DNA sequence in steps of 3 to process each codon
    for i in range(0, len(dna), 3):
        codon = dna[i : i + 3]
        codon_frequency = codon2frequency[codon]

        if codon_frequency < threshold:
            cfd += 1

    return cfd / (len(dna) / 3) * 100


def get_min_max_percentage(
    dna: str,
    codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
    window_size: int = 18,
) -> List[float]:
    """
    Calculate the %MinMax metric for a DNA sequence.

    Args:
        dna (str): The DNA sequence.
        codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
            frequency distribution per amino acid.
        window_size (int): Size of the window to calculate %MinMax.

    Returns:
        List[float]: List of %MinMax values for the sequence.

    Credit: https://github.com/chowington/minmax
    """
    # Get a dictionary mapping each codon to its respective amino acid
    codon2amino = {
        codon: amino
        for amino, (codons, frequencies) in codon_frequencies.items()
        for codon in codons
    }

    min_max_values = []
    codons = [dna[i : i + 3] for i in range(0, len(dna), 3)]  # Split DNA into codons

    # Iterate through the DNA sequence using the specified window size
    for i in range(len(codons) - window_size + 1):
        codon_window = codons[i : i + window_size]  # Codons in the current window

        Actual = 0.0  # Average of the actual codon frequencies
        Max = 0.0  # Average of the max codon frequencies
        Min = 0.0  # Average of the min codon frequencies
        Avg = 0.0  # Average of the averages of all frequencies for each amino acid

        # Sum the frequencies for codons in the current window
        for codon in codon_window:
            aminoacid = codon2amino[codon]
            frequencies = codon_frequencies[aminoacid][1]
            codon_index = codon_frequencies[aminoacid][0].index(codon)
            codon_frequency = codon_frequencies[aminoacid][1][codon_index]

            Actual += codon_frequency
            Max += max(frequencies)
            Min += min(frequencies)
            Avg += sum(frequencies) / len(frequencies)

        # Divide by the window size to get the averages
        Actual = Actual / window_size
        Max = Max / window_size
        Min = Min / window_size
        Avg = Avg / window_size

        # Calculate %MinMax
        percentMax = ((Actual - Avg) / (Max - Avg)) * 100
        percentMin = ((Avg - Actual) / (Avg - Min)) * 100

        # Append the appropriate %MinMax value
        if percentMax >= 0:
            min_max_values.append(percentMax)
        else:
            min_max_values.append(-percentMin)

    # Pad the last floor(window_size / 2) entries of min_max_values with None
    for i in range(int(window_size / 2)):
        min_max_values.append(None)

    return min_max_values


def get_sequence_complexity(dna: str) -> float:
    """
    Calculate the sequence complexity score of a DNA sequence.

    Args:
        dna (str): The DNA sequence.

    Returns:
        float: The sequence complexity score.
    """

    def sum_up_to(x):
        """Recursively calculate the sum of integers from 1 to x."""
        if x <= 1:
            return 1
        else:
            return x + sum_up_to(x - 1)

    def f(x):
        """Return 4 if x is greater than or equal to 4, else return x."""
        if x >= 4:
            return 4
        elif x < 4:
            return x

    unique_subseq_length = []

    # Calculate the number of unique subsequences of each length
    for i in range(1, len(dna) + 1):
        unique_subseq = set()
        for j in range(len(dna) - (i - 1)):
            unique_subseq.add(dna[j : (j + i)])
        unique_subseq_length.append(len(unique_subseq))

    # Calculate complexity score
    complexity_score = (
        sum(unique_subseq_length) / (sum_up_to(len(dna) - 1) + f(len(dna)))
    ) * 100

    return complexity_score


def get_sequence_similarity(
    original: str, predicted: str, truncate: bool = True, window_length: int = 1
) -> float:
    """
    Calculate the sequence similarity between two sequences.

    Args:
        original (str): The original sequence.
        predicted (str): The predicted sequence.
        truncate (bool): If True, truncate the original sequence to match the length
            of the predicted sequence.
        window_length (int): Length of the window for comparison (1 for amino acids,
            3 for codons).

    Returns:
        float: The sequence similarity as a percentage.

    Preconditions:
        len(predicted) <= len(original).
    """
    if not truncate and len(original) != len(predicted):
        raise ValueError(
            "Set truncate to True if the lengths of the sequences do not match."
        )

    identity = 0.0
    original = original.strip()
    predicted = predicted.strip()

    if truncate:
        original = original[: len(predicted)]

    if window_length == 1:
        # Simple per-character comparison for amino acids
        for i in range(len(predicted)):
            if original[i] == predicted[i]:
                identity += 1
    else:
        # Comparison of substrings based on window_length
        for i in range(0, len(original) - window_length + 1, window_length):
            if original[i : i + window_length] == predicted[i : i + window_length]:
                identity += 1

    return (identity / (len(predicted) / window_length)) * 100
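
# Usage sketch: codon-level identity between a natural and a predicted CDS;
# two of the three codons match.
#
#     get_sequence_similarity("ATGGCTGTG", "ATGGCAGTG", window_length=3)
#     # -> 66.66...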


def scan_for_restriction_sites(seq: str, sites: List[str] = ['GAATTC', 'GGATCC', 'AAGCTT']) -> int:
    """
    Scan for a list of restriction enzyme sites in a DNA sequence and return
    the total number of occurrences (defaults: EcoRI, BamHI, HindIII).
    """
    return sum(seq.upper().count(site.upper()) for site in sites)


def count_negative_cis_elements(seq: str, motifs: List[str] = ['TATAAT', 'TTGACA', 'AGCTAGT']) -> int:
    """
    Count occurrences of negative cis-regulatory elements in a DNA sequence.
    """
    return sum(seq.upper().count(m.upper()) for m in motifs)


def calculate_homopolymer_runs(seq: str, max_len: int = 8) -> int:
    """
    Calculate the number of homopolymer runs longer than a given length.
    """
    import re
    min_len = max_len + 1
    return len(re.findall(r'(A{%d,}|T{%d,}|G{%d,}|C{%d,})' % (min_len, min_len, min_len, min_len), seq.upper()))
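
# Usage sketch: a quick synthesis screen of one toy CDS, which contains an
# EcoRI site, a BamHI site, and a 9-bp poly-A run.
#
#     seq = "ATGGAATTCAAAAAAAAAGGATCCTAA"
#     scan_for_restriction_sites(seq)    # -> 2 (GAATTC + GGATCC)
#     count_negative_cis_elements(seq)   # -> 0
#     calculate_homopolymer_runs(seq)    # -> 1 (one run longer than 8 bases)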


def get_min_max_profile(
    dna: str,
    codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
    window_size: int = 18,
) -> List[float]:
    """
    Calculate the %MinMax profile for a DNA sequence, i.e. the list of
    %MinMax values for sliding windows across the sequence.

    Args:
        dna (str): The DNA sequence.
        codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
            frequency distribution per amino acid.
        window_size (int): Size of the window to calculate %MinMax.

    Returns:
        List[float]: List of %MinMax values for the sequence.
    """
    return get_min_max_percentage(dna, codon_frequencies, window_size)


def calculate_dtw_distance(profile1: List[float], profile2: List[float]) -> float:
    """
    Calculate the Dynamic Time Warping (DTW) distance between two profiles.

    Args:
        profile1 (List[float]): The first profile (e.g., %MinMax of the generated sequence).
        profile2 (List[float]): The second profile (e.g., %MinMax of the natural sequence).

    Returns:
        float: The DTW distance between the two profiles.
    """
    from dtw import dtw
    import numpy as np

    # Ensure profiles are numpy arrays and drop any None and NaN values
    p1 = np.array([v for v in profile1 if v is not None and not np.isnan(v)]).reshape(
        -1, 1
    )
    p2 = np.array([v for v in profile2 if v is not None and not np.isnan(v)]).reshape(
        -1, 1
    )

    if len(p1) == 0 or len(p2) == 0:
        return np.inf  # Return infinity if one of the profiles is empty

    alignment = dtw(p1, p2, keep_internals=True)
    return alignment.distance  # type: ignore


def get_ecoli_tai_weights():
    """
    Return a dictionary of tAI weights for E. coli based on tRNA gene copy numbers.
    These weights are pre-calculated based on the relative adaptiveness of each
    codon; ATG and the stop codons are not included in the table.
    """
    codons = [
        "TTT", "TTC", "TTA", "TTG", "TCT", "TCC", "TCA", "TCG", "TAT", "TAC",
        "TGT", "TGC", "TGG", "CTT", "CTC", "CTA", "CTG", "CCT", "CCC", "CCA",
        "CCG", "CAT", "CAC", "CAA", "CAG", "CGT", "CGC", "CGA", "CGG", "ATT",
        "ATC", "ATA", "ACT", "ACC", "ACA", "ACG", "AAT", "AAC", "AAA", "AAG",
        "AGT", "AGC", "AGA", "AGG", "GTT", "GTC", "GTA", "GTG", "GCT", "GCC",
        "GCA", "GCG", "GAT", "GAC", "GAA", "GAG", "GGT", "GGC", "GGA", "GGG"
    ]
    weights = [
        0.1966667, 0.3333333, 0.1666667, 0.2200000, 0.1966667, 0.3333333,
        0.1666667, 0.2200000, 0.2950000, 0.5000000, 0.09833333, 0.1666667,
        0.2200000, 0.09833333, 0.1666667, 0.1666667, 0.7200000, 0.09833333,
        0.1666667, 0.1666667, 0.2200000, 0.09833333, 0.1666667, 0.3333333,
        0.4400000, 0.6666667, 0.4800000, 0.00006666667, 0.1666667, 0.2950000,
        0.5000000, 0.01833333, 0.1966667, 0.3333333, 0.1666667, 0.3866667,
        0.3933333, 0.6666667, 1.0000000, 0.3200000, 0.09833333, 0.1666667,
        0.1666667, 0.2200000, 0.1966667, 0.3333333, 0.8333333, 0.2666667,
        0.1966667, 0.3333333, 0.5000000, 0.1600000, 0.2950000, 0.5000000,
        0.6666667, 0.2133333, 0.3933333, 0.6666667, 0.1666667, 0.2200000
    ]
    return dict(zip(codons, weights))


def calculate_tAI(sequence: str, tai_weights: Dict[str, float]) -> float:
    """
    Calculate the tRNA Adaptation Index (tAI) for a given DNA sequence.

    Args:
        sequence (str): The DNA sequence to analyze.
        tai_weights (Dict[str, float]): A dictionary of tAI weights for each codon.

    Returns:
        float: The tAI value for the sequence.
    """
    from scipy.stats.mstats import gmean

    codons = [sequence[i:i+3] for i in range(0, len(sequence), 3)]

    # Filter out stop codons and codons not in the weights table
    weights = [tai_weights[codon] for codon in codons if codon in tai_weights and tai_weights[codon] > 0]

    if not weights:
        return 0.0

    return gmean(weights)
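
# Usage sketch: tAI of a toy CDS under the bundled E. coli weights. ATG and
# the stop codons are absent from the weight table, so they are skipped.
#
#     tai_weights = get_ecoli_tai_weights()
#     calculate_tAI("ATGGCTGAAAAACTGTAA", tai_weights)  # geometric mean of 4 weights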


def calculate_ENC(sequence: str) -> float:
    """
    Calculate the Effective Number of Codons (ENC) for a DNA sequence.
    Uses the codonbias library implementation based on Wright (1990).

    Args:
        sequence (str): The DNA sequence.

    Returns:
        float: The ENC value for the sequence.
    """
    try:
        from codonbias.scores import EffectiveNumberOfCodons

        # Initialize the ENC calculator
        enc_calculator = EffectiveNumberOfCodons(
            k_mer=1,  # Standard codon analysis
            bg_correction=True,  # Use background correction
            robust=True,  # Use robust calculation
            genetic_code=1  # Standard genetic code
        )

        # Calculate ENC for the sequence
        enc_value = enc_calculator.get_score(sequence)

        return float(enc_value)

    except ImportError:
        raise ImportError("The codonbias library is required for ENC calculation. Install with: pip install codonbias")
    except Exception as e:
        # Fall back to a simple approximation if the library fails
        print(f"Warning: ENC calculation failed with error: {e}. Using approximation.")
        return 45.0  # Typical E. coli ENC value as fallback


def calculate_CPB(sequence: str, reference_sequences: Optional[List[str]] = None) -> float:
    """
    Calculate the Codon Pair Bias (CPB) for a DNA sequence.
    Uses the codonbias library implementation based on Coleman et al. (2008).

    Args:
        sequence (str): The DNA sequence.
        reference_sequences (List[str]): Reference sequences for calculating expected
            values. If None, the input sequence itself is used as the reference.

    Returns:
        float: The CPB value for the sequence.
    """
    try:
        from codonbias.scores import CodonPairBias

        # Use the provided reference sequences, or fall back to the input sequence
        if reference_sequences is None:
            reference_sequences = [sequence]

        # Initialize the CPB calculator with the reference sequences
        cpb_calculator = CodonPairBias(
            ref_seq=reference_sequences,
            k_mer=2,  # Codon pairs
            genetic_code=1,  # Standard genetic code
            ignore_stop=True,  # Ignore stop codons
            pseudocount=1  # Pseudocount for unseen pairs
        )

        # Calculate CPB for the sequence
        cpb_value = cpb_calculator.get_score(sequence)

        return float(cpb_value)

    except ImportError:
        raise ImportError("The codonbias library is required for CPB calculation. Install with: pip install codonbias")
    except Exception as e:
        # Fallback if the library fails
        print(f"Warning: CPB calculation failed with error: {e}. Using approximation.")
        return 0.0  # Neutral CPB as fallback


def calculate_SCUO(sequence: str) -> float:
    """
    Calculate the Synonymous Codon Usage Order (SCUO) for a DNA sequence,
    using a self-contained implementation (no external GCUA dependency) based
    on the information-theoretic definition of Wan et al. (2004).

    Args:
        sequence (str): The DNA sequence.

    Returns:
        float: The SCUO value (0-1, where 1 indicates maximum bias).
    """
    from math import log2  # local import to avoid global cost

    try:
        # Build the standard genetic code mapping using Biopython if available;
        # fall back to a hard-coded table otherwise.
        try:
            from Bio.Data import CodonTable  # type: ignore
            codon_to_aa = CodonTable.unambiguous_dna_by_id[1].forward_table
        except Exception:
            codon_to_aa = {
                'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
                'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
                'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
                'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
                'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
                'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
                'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
                'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
                'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
                'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
                'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
                'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
                'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
                'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
                'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
                'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
            }

        # Group codons by amino acid, excluding stop codons (they appear only
        # in the fallback table; Biopython's forward_table omits them)
        aa_to_codons = {}
        for codon, aa in codon_to_aa.items():
            if aa == '*':
                continue
            aa_to_codons.setdefault(aa, []).append(codon)

        # Count codon occurrences in the input sequence, skipping stop codons
        # so both code paths count identically
        seq = sequence.upper().replace('U', 'T')
        codon_counts = {}
        for i in range(0, len(seq) - len(seq) % 3, 3):
            codon = seq[i:i+3]
            if codon in codon_to_aa and codon_to_aa[codon] != '*':
                codon_counts[codon] = codon_counts.get(codon, 0) + 1

        total_codons = sum(codon_counts.values())
        if total_codons == 0:
            return 0.0

        scuo_sum = 0.0

        for aa, codons in aa_to_codons.items():
            n_codons = len(codons)
            if n_codons == 1:
                continue  # SCUO is undefined for Met/Trp

            counts = [codon_counts.get(c, 0) for c in codons]
            total_aa = sum(counts)
            if total_aa == 0:
                continue

            probs = [c / total_aa for c in counts if c]
            H_obs = -sum(p * log2(p) for p in probs)
            H_max = log2(n_codons)
            O_i = (H_max - H_obs) / H_max if H_max else 0.0
            F_i = total_aa / total_codons
            scuo_sum += F_i * O_i

        return scuo_sum

    except Exception as exc:
        print(f"Warning: internal SCUO computation failed ({exc}). Returning 0.5.")
        return 0.5
CodonTransformer/CodonJupyter.py ADDED
@@ -0,0 +1,311 @@
"""
File: CodonJupyter.py
---------------------
Includes Jupyter-specific functions for displaying interactive widgets.
"""

from typing import Dict, List, Tuple

import ipywidgets as widgets
from IPython.display import HTML, display

from CodonTransformer.CodonUtils import (
    COMMON_ORGANISMS,
    ID2ORGANISM,
    ORGANISM2ID,
    DNASequencePrediction,
)


class UserContainer:
    """
    A container class to store user inputs for organism and protein sequence.

    Attributes:
        organism (int): The selected organism id.
        protein (str): The input protein sequence.
    """

    def __init__(self) -> None:
        self.organism: int = -1
        self.protein: str = ""


def create_styled_options(
    organisms: list, organism2id: Dict[str, int], is_fine_tuned: bool = False
) -> list:
    """
    Create styled options for the dropdown widget, right-aligning the organism
    IDs and prefixing fine-tuned entries with a zero-width space (used as a
    hook by the CSS styling below).

    Args:
        organisms (list): List of organism names.
        organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs.
        is_fine_tuned (bool): Whether these are fine-tuned organisms.

    Returns:
        list: Styled options for the dropdown widget.
    """
    styled_options = []
    for organism in organisms:
        organism_id = organism2id[organism]
        if is_fine_tuned:
            if organism_id < 10:
                styled_options.append(f"\u200b{organism_id:>6}. {organism}")
            elif organism_id < 100:
                styled_options.append(f"\u200b{organism_id:>5}. {organism}")
            else:
                styled_options.append(f"\u200b{organism_id:>4}. {organism}")
        else:
            if organism_id < 10:
                styled_options.append(f"{organism_id:>6}. {organism}")
            elif organism_id < 100:
                styled_options.append(f"{organism_id:>5}. {organism}")
            else:
                styled_options.append(f"{organism_id:>4}. {organism}")
    return styled_options


def create_dropdown_options(organism2id: Dict[str, int]) -> list:
    """
    Create the full list of dropdown options, including section headers.

    Args:
        organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs.

    Returns:
        list: Full list of dropdown options.
    """
    fine_tuned_organisms = sorted(
        [org for org in organism2id.keys() if org in COMMON_ORGANISMS]
    )
    all_organisms = sorted(organism2id.keys())

    fine_tuned_options = create_styled_options(
        fine_tuned_organisms, organism2id, is_fine_tuned=True
    )
    all_organisms_options = create_styled_options(
        all_organisms, organism2id, is_fine_tuned=False
    )

    return (
        [""]
        + ["Selected Organisms"]
        + fine_tuned_options
        + [""]
        + ["All Organisms"]
        + all_organisms_options
    )


def create_organism_dropdown(container: UserContainer) -> widgets.Dropdown:
    """
    Create and configure the organism dropdown widget.

    Args:
        container (UserContainer): Container to store the selected organism.

    Returns:
        widgets.Dropdown: Configured dropdown widget.
    """
    dropdown = widgets.Dropdown(
        options=create_dropdown_options(ORGANISM2ID),
        description="",
        layout=widgets.Layout(width="40%", margin="0 0 10px 0"),
        style={"description_width": "initial"},
    )

    def show_organism(change: Dict[str, str]) -> None:
        """
        Update the container with the selected organism.

        Args:
            change (Dict[str, str]): Information about the change in dropdown value.
        """
        dropdown_choice = change["new"]
        if dropdown_choice and dropdown_choice not in [
            "Selected Organisms",
            "All Organisms",
        ]:
            # Extract the numeric ID from the styled option label
            organism = "".join(filter(str.isdigit, dropdown_choice))
            organism_id = ID2ORGANISM[int(organism)]
            container.organism = organism_id
        else:
            container.organism = None

    dropdown.observe(show_organism, names="value")
    return dropdown


def get_dropdown_style() -> str:
    """
    Return the custom CSS style for the dropdown widget.

    Returns:
        str: CSS style string.
    """
    return """
    <style>
        .widget-dropdown > select {
            font-size: 16px;
            font-weight: normal;
            background-color: #f0f0f0;
            border-radius: 5px;
            padding: 5px;
        }
        .widget-label {
            font-size: 18px;
            font-weight: bold;
        }
        .custom-container {
            display: flex;
            flex-direction: column;
            align-items: flex-start;
        }
        .widget-dropdown option[value^="\u200b"] {
            font-family: sans-serif;
            font-weight: bold;
            font-size: 18px;
            padding: 5px 10px;
        }
        .widget-dropdown option[value*="Selected Organisms"],
        .widget-dropdown option[value*="All Organisms"] {
            text-align: center;
            font-family: Arial, sans-serif;
            font-weight: bold;
            font-size: 20px;
            color: #6900A1;
            background-color: #00D8A1;
        }
    </style>
    """


def display_organism_dropdown(container: UserContainer) -> None:
    """
    Display the organism dropdown widget and apply custom styles.

    Args:
        container (UserContainer): Container to store the selected organism.
    """
    dropdown = create_organism_dropdown(container)
    header = widgets.HTML(
        '<b style="font-size:20px;">Select Organism:</b>'
        '<div style="height:10px;"></div>'
    )
    container_widget = widgets.VBox(
        [header, dropdown],
        layout=widgets.Layout(padding="12px 0 12px 25px"),
    )
    display(container_widget)
    display(HTML(get_dropdown_style()))


def display_protein_input(container: UserContainer) -> None:
    """
    Display a widget for entering a protein sequence and save it to the container.

    Args:
        container (UserContainer): A container to store the entered protein sequence.
    """
    protein_input = widgets.Textarea(
        value="",
        placeholder="Enter here...",
        description="",
        layout=widgets.Layout(width="100%", height="100px", margin="0 0 10px 0"),
        style={"description_width": "initial"},
    )

    # Custom CSS for the input widget
    input_style = """
    <style>
        .widget-textarea > textarea {
            font-size: 12px;
            font-family: Arial, sans-serif;
            font-weight: normal;
            background-color: #f0f0f0;
            border-radius: 5px;
            padding: 10px;
        }
        .widget-label {
            font-size: 18px;
            font-weight: bold;
        }
        .custom-container {
            display: flex;
            flex-direction: column;
            align-items: flex-start;
        }
    </style>
    """

    # Function to save the input protein sequence to the container
    def save_protein(change: Dict[str, str]) -> None:
        """
        Save the input protein sequence to the container.

        Args:
            change (Dict[str, str]): A dictionary containing information about
                the change in textarea value.
        """
        container.protein = (
            change["new"]
            .upper()
            .strip()
            .replace("\n", "")
            .replace(" ", "")
            .replace("\t", "")
        )

    # Attach the function to the input widget
    protein_input.observe(save_protein, names="value")

    # Display the input widget
    header = widgets.HTML(
        '<b style="font-size:20px;">Enter Protein Sequence:</b>'
        '<div style="height:18px;"></div>'
    )
    container_widget = widgets.VBox(
        [header, protein_input], layout=widgets.Layout(padding="12px 12px 0 25px")
    )

    display(container_widget)
    display(widgets.HTML(input_style))


def format_model_output(output: DNASequencePrediction) -> str:
    """
    Format DNA sequence prediction output in an appealing and easy-to-read manner.

    This function takes the prediction output and formats it into
    a structured string with clear section headers and separators.

    Args:
        output (DNASequencePrediction): Object containing the prediction output.
            Expected attributes:
            - organism (str): The organism name.
            - protein (str): The input protein sequence.
            - processed_input (str): The processed input sequence.
            - predicted_dna (str): The predicted DNA sequence.

    Returns:
        str: A formatted string containing the organized output.
    """

    def format_section(title: str, content: str) -> str:
        """Helper function to format an individual section."""
        separator = "-" * 29
        title_line = f"| {title.center(25)} |"
        return f"{separator}\n{title_line}\n{separator}\n{content}\n\n"

    sections: List[Tuple[str, str]] = [
        ("Organism", output.organism),
        ("Input Protein", output.protein),
        ("Processed Input", output.processed_input),
        ("Predicted DNA", output.predicted_dna),
    ]

    formatted_output = ""
    for title, content in sections:
        formatted_output += format_section(title, content)

    # Remove the trailing newlines to avoid extra space at the end
    return formatted_output.rstrip()
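
# Usage sketch: formatting a prediction result. Assumes DNASequencePrediction
# is a simple container exposing the four attributes listed in the docstring.
#
#     prediction = DNASequencePrediction(
#         organism="Escherichia coli general",
#         protein="MAV",
#         processed_input="M_UNK A_UNK V_UNK __UNK",
#         predicted_dna="ATGGCTGTGTAA",
#     )
#     print(format_model_output(prediction))
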
CodonTransformer/CodonPostProcessing.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ File: CodonPostProcessing.py
3
+ ---------------------------
4
+ Post-processing utilities for codon optimization using DNAChisel.
5
+ This module provides sequence polishing capabilities to fix restriction sites,
6
+ homopolymers, and other constraints while preserving CAI and GC content.
7
+ """
8
+
9
+ import warnings
10
+ import numpy as np
11
+
12
+ try:
13
+ from dnachisel import (
14
+ DnaOptimizationProblem,
15
+ AvoidPattern,
16
+ EnforceGCContent,
17
+ EnforceTranslation,
18
+ CodonOptimize,
19
+ )
20
+ DNACHISEL_AVAILABLE = True
21
+ except ImportError:
22
+ DNACHISEL_AVAILABLE = False
23
+ # This warning will be shown when the module is first imported.
24
+ warnings.warn(
25
+ "DNAChisel is not installed. Post-processing features will be disabled."
26
+ )
27
+
28
+ def polish_sequence_with_dnachisel(
29
+ dna_sequence: str,
30
+ protein_sequence: str,
31
+ gc_bounds: tuple = (45.0, 55.0),
32
+ cai_species: str = "e_coli",
33
+ avoid_homopolymers_length: int = 6,
34
+ enzymes_to_avoid: Optional[list] = None
35
+ ):
36
+ """
37
+ Polishes a DNA sequence using DNAChisel to meet lab synthesis constraints.
38
+ """
39
+ if not DNACHISEL_AVAILABLE:
40
+ warnings.warn("DNAChisel not available, skipping post-processing.")
41
+ return dna_sequence
42
+
43
+ if enzymes_to_avoid is None:
44
+ # Common cloning enzymes
45
+ enzymes_to_avoid = ["EcoRI", "XbaI", "SpeI", "PstI", "NotI"]
46
+
47
+ try:
48
+ # Start with the basic, essential constraints
49
+ constraints = [
50
+ EnforceTranslation(translation=protein_sequence),
51
+ EnforceGCContent(mini=gc_bounds[0] / 100.0, maxi=gc_bounds[1] / 100.0),
52
+ ]
53
+
54
+ # Add enzyme avoidance constraints safely
55
+ for enzyme in enzymes_to_avoid:
56
+ try:
57
+ # This is the modern way to avoid enzyme sites
58
+ constraints.append(AvoidPattern.from_enzyme_name(enzyme))
59
+ except Exception:
60
+ warnings.warn(f"Could not find enzyme '{enzyme}' in DNAChisel library.")
61
+
62
+ # Add homopolymer avoidance constraints
63
+ for base in "ATGC":
64
+ constraints.append(AvoidPattern(base * avoid_homopolymers_length))
65
+
66
+ # Define the optimization problem
67
+ problem = DnaOptimizationProblem(
68
+ sequence=dna_sequence,
69
+ constraints=constraints,
70
+ objectives=[CodonOptimize(species=cai_species, method="match_codon_usage")]
71
+ )
72
+
73
+ # Solve the problem
74
+ problem.resolve_constraints()
75
+ problem.optimize()
76
+
77
+ # Return the polished sequence
78
+ return problem.sequence
79
+
80
+ except Exception as e:
81
+ warnings.warn(f"DNAChisel post-processing failed with an error: {e}")
82
+ # Return the original sequence if polishing fails
83
+ return dna_sequence
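A usage sketch for the polisher (assumes DNAChisel is installed; the toy sequences are chosen so the DNA translates to the protein):

    polished = polish_sequence_with_dnachisel(
        dna_sequence="ATGGCTGCTAAAGAA",   # translates to MAAKE
        protein_sequence="MAAKE",
        gc_bounds=(40.0, 60.0),
        enzymes_to_avoid=["EcoRI"],
    )
    # On any DNAChisel failure the function warns and returns the input unchanged.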
CodonTransformer/CodonPrediction.py ADDED
@@ -0,0 +1,1374 @@
1
+ """
2
+ File: CodonPrediction.py
3
+ ---------------------------
4
+ Includes functions to tokenize input, load models, infer predicted DNA sequences, and
5
+ helper functions related to processing data for passing to the model.
6
+ """
7
+
8
+ import warnings
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+ import heapq
11
+ from dataclasses import dataclass
12
+
13
+ import numpy as np
14
+ import onnxruntime as rt
15
+ import torch
16
+ import transformers
17
+ from transformers import (
18
+ AutoTokenizer,
19
+ BatchEncoding,
20
+ BigBirdConfig,
21
+ BigBirdForMaskedLM,
22
+ PreTrainedTokenizerFast,
23
+ )
24
+
25
+ from CodonTransformer.CodonData import get_merged_seq
26
+ from CodonTransformer.CodonUtils import (
27
+ AMINO_ACID_TO_INDEX,
28
+ INDEX2TOKEN,
29
+ NUM_ORGANISMS,
30
+ ORGANISM2ID,
31
+ TOKEN2INDEX,
32
+ DNASequencePrediction,
33
+ GC_COUNTS_PER_TOKEN,
34
+ CODON_GC_CONTENT,
35
+ AA_MIN_GC,
36
+ AA_MAX_GC,
37
+ )
38
+
39
+
40
+ def predict_dna_sequence(
41
+ protein: str,
42
+ organism: Union[int, str],
43
+ device: torch.device,
44
+ tokenizer: Optional[Union[str, PreTrainedTokenizerFast]] = None,
45
+ model: Optional[Union[str, torch.nn.Module]] = None,
46
+ attention_type: str = "original_full",
47
+ deterministic: bool = True,
48
+ temperature: float = 0.2,
49
+ top_p: float = 0.95,
50
+ num_sequences: int = 1,
51
+ match_protein: bool = False,
52
+ use_constrained_search: bool = False,
53
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
54
+ beam_size: int = 5,
55
+ length_penalty: float = 1.0,
56
+ diversity_penalty: float = 0.0,
57
+ ) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
58
+ """
59
+ Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
60
+
61
+ This function takes a protein sequence and an organism (as ID or name) as input
62
+ and returns the predicted DNA sequence(s) using the CodonTransformer model. It can use
63
+ either provided tokenizer and model objects or load them from specified paths.
64
+
65
+ Args:
66
+ protein (str): The input protein sequence for which to predict the DNA sequence.
67
+ organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
68
+ "Escherichia coli general"). If a string is provided, it will be converted
69
+ to the corresponding ID using ORGANISM2ID.
70
+ device (torch.device): The device (CPU or GPU) to run the model on.
71
+ tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file
72
+ path to load the tokenizer from, a pre-loaded tokenizer object, or None. If
73
+ None, it will be loaded from HuggingFace. Defaults to None.
74
+ model (Union[str, torch.nn.Module, None], optional): Either a file path to load
75
+ the model from, a pre-loaded model object, or None. If None, it will be
76
+ loaded from HuggingFace. Defaults to None.
77
+ attention_type (str, optional): The type of attention mechanism to use in the
78
+ model. Can be either 'block_sparse' or 'original_full'. Defaults to
79
+ "original_full".
80
+ deterministic (bool, optional): Whether to use deterministic decoding (most
81
+ likely tokens). If False, samples tokens according to their probabilities
82
+ adjusted by the temperature. Defaults to True.
83
+ temperature (float, optional): A value controlling the randomness of predictions
84
+ during non-deterministic decoding. Lower values (e.g., 0.2) make the model
85
+ more conservative, while higher values (e.g., 0.8) increase randomness.
86
+ Using high temperatures may result in prediction of DNA sequences that
87
+ do not translate to the input protein.
88
+ Recommended values are:
89
+ - Low randomness: 0.2
90
+ - Medium randomness: 0.5
91
+ - High randomness: 0.8
92
+ The temperature must be a positive float. Defaults to 0.2.
93
+ top_p (float, optional): The cumulative probability threshold for nucleus sampling.
94
+ Tokens with cumulative probability up to top_p are considered for sampling.
95
+ This parameter helps balance diversity and coherence in the predicted DNA sequences.
96
+ The value must be a float between 0 and 1. Defaults to 0.95.
97
+ num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
98
+ when deterministic is False. Defaults to 1.
99
+ match_protein (bool, optional): Ensures the predicted DNA sequence is translated
100
+ to the input protein sequence by sampling from only the respective codons of
101
+ given amino acids. Defaults to False.
102
+ use_constrained_search (bool, optional): Whether to use constrained beam search
103
+ with GC content bounds. Defaults to False.
104
+ gc_bounds (Tuple[float, float], optional): GC content bounds (min, max) for
105
+ constrained search. Defaults to (0.30, 0.70).
106
+ beam_size (int, optional): Beam size for constrained search. Defaults to 5.
107
+ length_penalty (float, optional): Length penalty for beam search scoring.
108
+ Defaults to 1.0.
109
+ diversity_penalty (float, optional): Diversity penalty to reduce repetitive
110
+ sequences. Defaults to 0.0.
111
+
112
+ Returns:
113
+ Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
114
+ containing the prediction results:
115
+ - organism (str): Name of the organism used for prediction.
116
+ - protein (str): Input protein sequence for which DNA sequence is predicted.
117
+ - processed_input (str): Processed input sequence (merged protein and DNA).
118
+ - predicted_dna (str): Predicted DNA sequence.
119
+
120
+ Raises:
121
+ ValueError: If the protein sequence is empty, if the organism is invalid,
122
+ if the temperature is not a positive float, if top_p is not between 0 and 1,
123
+ or if num_sequences is less than 1 or used with deterministic mode.
124
+
125
+ Note:
126
+ This function uses ORGANISM2ID, INDEX2TOKEN, and AMINO_ACID_TO_INDEX dictionaries
127
+ imported from CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their
128
+ corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to
129
+ respective codons. AMINO_ACID_TO_INDEX maps each amino acid and stop symbol to indices
130
+ of codon tokens that translate to it.
131
+
132
+ Example:
133
+ >>> import torch
134
+ >>> from transformers import AutoTokenizer, BigBirdForMaskedLM
135
+ >>> from CodonTransformer.CodonPrediction import predict_dna_sequence
136
+ >>> from CodonTransformer.CodonJupyter import format_model_output
137
+ >>>
138
+ >>> # Set up device
139
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
140
+ >>>
141
+ >>> # Load tokenizer and model
142
+ >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
143
+ >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
144
+ >>> model = model.to(device)
145
+ >>>
146
+ >>> # Define protein sequence and organism
147
+ >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"
148
+ >>> organism = "Escherichia coli general"
149
+ >>>
150
+ >>> # Predict DNA sequence with deterministic decoding (single sequence)
151
+ >>> output = predict_dna_sequence(
152
+ ... protein=protein,
153
+ ... organism=organism,
154
+ ... device=device,
155
+ ... tokenizer=tokenizer,
156
+ ... model=model,
157
+ ... attention_type="original_full",
158
+ ... deterministic=True
159
+ ... )
160
+ >>>
161
+ >>> # Predict DNA sequence with constrained beam search
162
+ >>> output_constrained = predict_dna_sequence(
163
+ ... protein=protein,
164
+ ... organism=organism,
165
+ ... device=device,
166
+ ... tokenizer=tokenizer,
167
+ ... model=model,
168
+ ... use_constrained_search=True,
169
+ ... gc_bounds=(0.40, 0.60),
170
+ ... beam_size=10,
171
+ ... length_penalty=1.2,
172
+ ... diversity_penalty=0.1
173
+ ... )
174
+ >>>
175
+ >>> # Predict multiple DNA sequences with low randomness and top_p sampling
176
+ >>> output_random = predict_dna_sequence(
177
+ ... protein=protein,
178
+ ... organism=organism,
179
+ ... device=device,
180
+ ... tokenizer=tokenizer,
181
+ ... model=model,
182
+ ... attention_type="original_full",
183
+ ... deterministic=False,
184
+ ... temperature=0.2,
185
+ ... top_p=0.95,
186
+ ... num_sequences=3
187
+ ... )
188
+ >>>
189
+ >>> print(format_model_output(output))
190
+ >>> for i, seq in enumerate(output_random, 1):
191
+ ... print(f"Sequence {i}:")
192
+ ... print(format_model_output(seq))
193
+ ... print()
194
+ """
195
+ if not protein:
196
+ raise ValueError("Protein sequence cannot be empty.")
197
+
198
+ if not isinstance(temperature, (float, int)) or temperature <= 0:
199
+ raise ValueError("Temperature must be a positive float.")
200
+
201
+ if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
202
+ raise ValueError("top_p must be a float between 0 and 1.")
203
+
204
+ if not isinstance(num_sequences, int) or num_sequences < 1:
205
+ raise ValueError("num_sequences must be a positive integer.")
206
+
207
+ if use_constrained_search:
208
+ if not isinstance(gc_bounds, tuple) or len(gc_bounds) != 2:
209
+ raise ValueError("gc_bounds must be a tuple of (min_gc, max_gc).")
210
+
211
+ if not (0.0 <= gc_bounds[0] <= gc_bounds[1] <= 1.0):
212
+ raise ValueError("gc_bounds must be between 0.0 and 1.0 with min <= max.")
213
+
214
+ if not isinstance(beam_size, int) or beam_size < 1:
215
+ raise ValueError("beam_size must be a positive integer.")
216
+
217
+ if deterministic and num_sequences > 1 and not use_constrained_search:
218
+ raise ValueError(
219
+ "Multiple sequences can only be generated in non-deterministic mode "
220
+ "(unless using constrained search)."
221
+ )
222
+
223
+ if use_constrained_search and num_sequences > 1:
224
+ raise ValueError(
225
+ "Constrained beam search currently supports only single sequence generation."
226
+ )
227
+
228
+ # Load tokenizer
229
+ if not isinstance(tokenizer, PreTrainedTokenizerFast):
230
+ tokenizer = load_tokenizer(tokenizer)
231
+
232
+ # Load model
233
+ if not isinstance(model, torch.nn.Module):
234
+ model = load_model(model_path=model, device=device, attention_type=attention_type)
235
+ else:
236
+ model.eval()
237
+ model.bert.set_attention_type(attention_type)
238
+ model.to(device)
239
+
240
+ # Validate organism and convert to organism_id and organism_name
241
+ organism_id, organism_name = validate_and_convert_organism(organism)
242
+
243
+ # Inference loop
244
+ with torch.no_grad():
245
+ # Tokenize the input sequence
246
+ merged_seq = get_merged_seq(protein=protein, dna="")
247
+ input_dict = {
248
+ "idx": 0, # sample index
249
+ "codons": merged_seq,
250
+ "organism": organism_id,
251
+ }
252
+ tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(device)
253
+
254
+ # Get the model predictions
255
+ output_dict = model(**tokenized_input, return_dict=True)
256
+ logits = output_dict.logits.detach().cpu()
257
+ logits = logits[:, 1:-1, :] # Remove [CLS] and [SEP] tokens
258
+
259
+ # Mask the logits of codons that do not correspond to the input protein sequence
260
+ if match_protein:
261
+ possible_tokens_per_position = [
262
+ AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ")
263
+ ]
264
+ seq_len = logits.shape[1]
265
+ if len(possible_tokens_per_position) > seq_len:
266
+ possible_tokens_per_position = possible_tokens_per_position[:seq_len]
267
+
268
+ mask = torch.full_like(logits, float("-inf"))
269
+
270
+ for pos, possible_tokens in enumerate(possible_tokens_per_position):
271
+ mask[:, pos, possible_tokens] = 0
272
+
273
+ logits = mask + logits
274
+
275
+ predictions = []
276
+ for _ in range(num_sequences):
277
+ # Decode the predicted DNA sequence from the model output
278
+ if use_constrained_search:
279
+ # Use constrained beam search with GC bounds
280
+ predicted_indices = constrained_beam_search_simple(
281
+ logits=logits.squeeze(0),
282
+ protein_sequence=protein,
283
+ gc_bounds=gc_bounds,
284
+ max_attempts=50,
285
+ )
286
+ elif deterministic:
287
+ predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
288
+ else:
289
+ predicted_indices = sample_non_deterministic(
290
+ logits=logits, temperature=temperature, top_p=top_p
291
+ )
292
+
293
+ predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
294
+ predicted_dna = (
295
+ "".join([token[-3:] for token in predicted_dna]).strip().upper()
296
+ )
297
+
298
+ predictions.append(
299
+ DNASequencePrediction(
300
+ organism=organism_name,
301
+ protein=protein,
302
+ processed_input=merged_seq,
303
+ predicted_dna=predicted_dna,
304
+ )
305
+ )
306
+
307
+ return predictions[0] if num_sequences == 1 else predictions
308
+
309
+
310
+ @dataclass
311
+ class BeamCandidate:
312
+ """Represents a candidate sequence in the beam search."""
313
+ tokens: List[int]
314
+ score: float
315
+ gc_count: int
316
+ length: int
317
+
318
+ def __post_init__(self):
319
+ self.gc_ratio = self.gc_count / max(self.length, 1)
320
+
321
+ def __lt__(self, other):
322
+ return self.score < other.score
323
+
324
+
325
+ def _calculate_true_future_gc_range(
326
+ current_pos: int,
327
+ protein_sequence: str,
328
+ current_gc_count: int,
329
+ current_length: int
330
+ ) -> Tuple[float, float]:
331
+ """
332
+ Calculate the true minimum and maximum possible final GC content
333
+ given current state and remaining amino acids (perfect foresight).
334
+
335
+ Args:
336
+ current_pos: Current position in protein sequence
337
+ protein_sequence: Full protein sequence
338
+ current_gc_count: Current GC count in partial sequence
339
+ current_length: Current length in nucleotides
340
+
341
+ Returns:
342
+ Tuple of (min_possible_final_gc_ratio, max_possible_final_gc_ratio)
343
+ """
344
+ if current_pos >= len(protein_sequence):
345
+ # Already at end, return current ratio
346
+ final_ratio = current_gc_count / max(current_length, 1)
347
+ return final_ratio, final_ratio
348
+
349
+ # Calculate remaining amino acids
350
+ remaining_aas = protein_sequence[current_pos:]
351
+
352
+ # Calculate min/max possible GC from remaining amino acids
353
+ min_future_gc = 0
354
+ max_future_gc = 0
355
+
356
+ for aa in remaining_aas:
357
+ if aa.upper() in AA_MIN_GC and aa.upper() in AA_MAX_GC:
358
+ min_future_gc += AA_MIN_GC[aa.upper()]
359
+ max_future_gc += AA_MAX_GC[aa.upper()]
360
+ else:
361
+ # If amino acid not found, assume moderate GC (1-2 range)
362
+ min_future_gc += 1
363
+ max_future_gc += 2
364
+
365
+ # Calculate final sequence length
366
+ final_length = current_length + len(remaining_aas) * 3
367
+
368
+ # Calculate min/max possible final GC ratios
369
+ min_final_gc_ratio = (current_gc_count + min_future_gc) / final_length
370
+ max_final_gc_ratio = (current_gc_count + max_future_gc) / final_length
371
+
372
+ return min_final_gc_ratio, max_final_gc_ratio
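A hypothetical worked example of this foresight bound: with 2 of 5 residues decoded, 3 G/C in the 6 nt so far, and "GGG" remaining (glycine codons carry 2-3 G/C each, per AA_MIN_GC/AA_MAX_GC):

    lo, hi = _calculate_true_future_gc_range(2, "MKGGG", 3, 6)
    # min_future_gc = 6, max_future_gc = 9, final_length = 15,
    # so lo = (3 + 6) / 15 = 0.60 and hi = (3 + 9) / 15 = 0.80;
    # a candidate survives only if [lo, hi] overlaps the target GC bounds.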
373
+
374
+
375
+ def constrained_beam_search_simple(
376
+ logits: torch.Tensor,
377
+ protein_sequence: str,
378
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
379
+ max_attempts: int = 100,
380
+ ) -> List[int]:
381
+ """
382
+ Simple constrained search: draw one greedy sample and several stochastic ones, then return the best sequence whose final GC content falls within the given bounds.
383
+ """
384
+ min_gc, max_gc = gc_bounds
385
+ seq_len = min(logits.shape[0], len(protein_sequence))
386
+
387
+ # Convert to probabilities
388
+ probs = torch.softmax(logits, dim=-1)
389
+
390
+ valid_sequences = []
391
+
392
+ for attempt in range(max_attempts):
393
+ tokens = []
394
+ total_gc = 0
395
+
396
+ # Generate sequence position by position
397
+ for pos in range(seq_len):
398
+ aa = protein_sequence[pos]
399
+ possible_tokens = AMINO_ACID_TO_INDEX.get(aa, [])
400
+
401
+ if not possible_tokens:
402
+ continue
403
+
404
+ # Filter tokens by current constraints and get probabilities
405
+ candidates = []
406
+ for token_idx in possible_tokens:
407
+ if token_idx < len(probs[pos]) and token_idx < len(GC_COUNTS_PER_TOKEN):
408
+ prob = probs[pos][token_idx].item()
409
+ gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
410
+
411
+ # Check if this token could still lead to a valid final sequence (perfect foresight)
412
+ new_gc_total = total_gc + gc_contribution
413
+ new_length = (pos + 1) * 3
414
+
415
+ # Calculate what's possible for the final sequence given this choice
416
+ min_final_gc, max_final_gc = _calculate_true_future_gc_range(
417
+ pos + 1, protein_sequence, new_gc_total, new_length
418
+ )
419
+
420
+ # Only prune if there's NO OVERLAP between possible final range and target bounds
421
+ if max_final_gc >= min_gc and min_final_gc <= max_gc:
422
+ # Calculate gentle GC penalty to steer toward target center
423
+ target_gc = (min_gc + max_gc) / 2 # Target center (e.g., 0.50 for bounds 0.45-0.55)
424
+ current_projected_gc = (min_final_gc + max_final_gc) / 2 # Projected center
425
+
426
+ # Only apply penalty if we're significantly off-target AND late in sequence
427
+ sequence_progress = (pos + 1) / seq_len
428
+ if sequence_progress > 0.3: # Only apply penalty after 30% of sequence
429
+ gc_deviation = abs(current_projected_gc - target_gc)
430
+ if gc_deviation > 0.05: # Only if >5% deviation from target
431
+ # Gentle penalty: reduce probability by small factor
432
+ penalty_factor = max(0.7, 1.0 - 0.3 * gc_deviation) # 0.7-1.0 range
433
+ prob = prob * penalty_factor
434
+
435
+ candidates.append((token_idx, prob, gc_contribution))
436
+
437
+ if not candidates:
438
+ # If no valid candidates, break and try next attempt
439
+ break
440
+
441
+ # Sample from valid candidates (with temperature)
442
+ if attempt == 0:
443
+ # First attempt: greedy (highest probability)
444
+ best_token = max(candidates, key=lambda x: x[1])
445
+ else:
446
+ # Other attempts: sample with some randomness
447
+ probs_list = [c[1] for c in candidates]
448
+ if sum(probs_list) > 0:
449
+ # Normalize probabilities
450
+ probs_array = np.array(probs_list)
451
+ probs_array = probs_array / probs_array.sum()
452
+ # Sample
453
+ chosen_idx = np.random.choice(len(candidates), p=probs_array)
454
+ best_token = candidates[chosen_idx]
455
+ else:
456
+ best_token = candidates[0]
457
+
458
+ tokens.append(best_token[0])
459
+ total_gc += best_token[2]
460
+
461
+ # Check if we got a complete sequence
462
+ if len(tokens) == seq_len:
463
+ final_gc_ratio = total_gc / (seq_len * 3)
464
+ if min_gc <= final_gc_ratio <= max_gc:
465
+ # Calculate sequence score (sum of log probabilities)
466
+ score = sum(np.log(probs[i][tokens[i]].item() + 1e-8) for i in range(len(tokens)))
467
+ valid_sequences.append((tokens, score, final_gc_ratio))
468
+
469
+ if not valid_sequences:
470
+ raise ValueError(f"Could not generate valid sequence within GC bounds {gc_bounds} after {max_attempts} attempts")
471
+
472
+ # Return the sequence with highest score
473
+ best_sequence = max(valid_sequences, key=lambda x: x[1])
474
+ return best_sequence[0]
475
+
476
+
477
+ def constrained_beam_search(
478
+ logits: torch.Tensor,
479
+ protein_sequence: str,
480
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
481
+ beam_size: int = 5,
482
+ length_penalty: float = 1.0,
483
+ diversity_penalty: float = 0.0,
484
+ temperature: float = 1.0,
485
+ max_candidates: int = 100,
486
+ position_aware_gc_penalty: bool = True,
487
+ gc_penalty_strength: float = 2.0,
488
+ ) -> List[int]:
489
+ """
490
+ Constrained beam search with exact per-residue GC bounds tracking.
491
+
492
+ Priority #1: Exact per-residue GC bounds tracking
493
+ - Tracks cumulative GC content after each codon selection
494
+ - Prunes candidates that would violate GC bounds
495
+ - Maintains beam of valid candidates
496
+
497
+ Priority #2: Position-aware GC penalty mechanism
498
+ - Applies variable penalty weights based on sequence position
499
+ - Preserves flexibility early, applies pressure when necessary
500
+ - Uses progressive penalty scaling based on deviation severity
501
+
502
+ Args:
503
+ logits (torch.Tensor): Model logits of shape [seq_len, vocab_size]
504
+ protein_sequence (str): Input protein sequence
505
+ gc_bounds (Tuple[float, float]): (min_gc, max_gc) bounds
506
+ beam_size (int): Number of candidates to maintain
507
+ length_penalty (float): Length penalty for scoring
508
+ diversity_penalty (float): Diversity penalty for scoring
509
+ temperature (float): Temperature for probability scaling
510
+ max_candidates (int): Maximum candidates to consider per position
511
+ position_aware_gc_penalty (bool): Whether to use position-aware GC penalties
512
+ gc_penalty_strength (float): Strength of GC penalty adjustment
513
+
514
+ Returns:
515
+ List[int]: Best sequence token indices
516
+ """
517
+ min_gc, max_gc = gc_bounds
518
+ seq_len = logits.shape[0]
519
+ protein_len = len(protein_sequence)
520
+
521
+ # Ensure we don't go beyond the protein sequence
522
+ if seq_len > protein_len:
523
+ print(f"Warning: logits length ({seq_len}) > protein length ({protein_len}). Truncating to protein length.")
524
+ seq_len = protein_len
525
+ logits = logits[:protein_len]
526
+
527
+ # Initialize beam with empty candidate
528
+ beam = [BeamCandidate(tokens=[], score=0.0, gc_count=0, length=0)]
529
+
530
+ # Apply temperature scaling
531
+ if temperature != 1.0:
532
+ logits = logits / temperature
533
+
534
+ # Convert to probabilities
535
+ probs = torch.softmax(logits, dim=-1)
536
+
537
+ for pos in range(min(seq_len, len(protein_sequence))):
538
+ # Get possible tokens for current amino acid
539
+ aa = protein_sequence[pos]
540
+ possible_tokens = AMINO_ACID_TO_INDEX.get(aa, [])
541
+
542
+ if not possible_tokens:
543
+ # Fallback to all tokens if amino acid not found
544
+ possible_tokens = list(range(probs.shape[1]))
545
+
546
+ # Get top candidates for this position
547
+ pos_probs = probs[pos]
548
+ top_candidates = []
549
+
550
+ for token_idx in possible_tokens:
551
+ if token_idx < len(pos_probs) and token_idx < len(GC_COUNTS_PER_TOKEN):
552
+ prob = pos_probs[token_idx].item()
553
+ gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
554
+ # Only include tokens with valid probabilities
555
+ if prob > 1e-10: # Avoid extremely low probabilities
556
+ top_candidates.append((token_idx, prob, gc_contribution))
557
+
558
+ # Sort by probability and take top max_candidates
559
+ top_candidates.sort(key=lambda x: x[1], reverse=True)
560
+ top_candidates = top_candidates[:max_candidates]
561
+
562
+ # If no valid candidates found, fallback to all possible tokens for this amino acid
563
+ if not top_candidates:
564
+ for token_idx in possible_tokens[:min(len(possible_tokens), max_candidates)]:
565
+ if token_idx < len(pos_probs) and token_idx < len(GC_COUNTS_PER_TOKEN):
566
+ prob = max(pos_probs[token_idx].item(), 1e-10) # Ensure minimum probability
567
+ gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
568
+ top_candidates.append((token_idx, prob, gc_contribution))
569
+
570
+ # Generate new beam candidates
571
+ new_beam = []
572
+
573
+ for candidate in beam:
574
+ for token_idx, prob, gc_contribution in top_candidates:
575
+ # Calculate new GC stats
576
+ new_gc_count = candidate.gc_count + gc_contribution
577
+ new_length = candidate.length + 3 # Each codon is 3 nucleotides
578
+ new_gc_ratio = new_gc_count / new_length
579
+
580
+ # Priority #2: Position-aware GC penalty mechanism
581
+ gc_penalty = 0.0
582
+ if position_aware_gc_penalty:
583
+ # Calculate position weight (more penalty towards end of sequence)
584
+ position_weight = (pos + 1) / seq_len
585
+
586
+ # Calculate GC deviation severity
587
+ target_gc = (min_gc + max_gc) / 2
588
+ gc_deviation = abs(new_gc_ratio - target_gc)
589
+ deviation_severity = gc_deviation / ((max_gc - min_gc) / 2)
590
+
591
+ # Apply progressive penalty
592
+ if deviation_severity > 0.5: # Soft penalty zone
593
+ gc_penalty = gc_penalty_strength * position_weight * (deviation_severity - 0.5) ** 2
594
+
595
+ # Hard constraint: still prune sequences that exceed bounds
596
+ if new_gc_ratio < min_gc or new_gc_ratio > max_gc:
597
+ continue # Prune invalid candidates
598
+ else:
599
+ # Priority #1: Hard GC bounds only
600
+ if new_gc_ratio < min_gc or new_gc_ratio > max_gc:
601
+ continue # Prune invalid candidates
602
+
603
+ # Calculate score with GC penalty
604
+ new_score = candidate.score + np.log(prob + 1e-8) - gc_penalty
605
+
606
+ # Apply length penalty
607
+ if length_penalty != 1.0:
608
+ length_norm = ((pos + 1) ** length_penalty)
609
+ normalized_score = new_score / length_norm
610
+ else:
611
+ normalized_score = new_score
612
+
613
+ # Create new candidate
614
+ new_candidate = BeamCandidate(
615
+ tokens=candidate.tokens + [token_idx],
616
+ score=normalized_score,
617
+ gc_count=new_gc_count,
618
+ length=new_length
619
+ )
620
+
621
+ new_beam.append(new_candidate)
622
+
623
+ # Apply diversity penalty if specified
624
+ if diversity_penalty > 0.0:
625
+ new_beam = _apply_diversity_penalty(new_beam, diversity_penalty)
626
+
627
+ # Keep top beam_size candidates
628
+ beam = sorted(new_beam, key=lambda x: x.score, reverse=True)[:beam_size]
629
+
630
+ # Priority #3: Adaptive beam rescue for difficult sequences
631
+ if not beam:
632
+ # Attempt beam rescue by relaxing constraints progressively
633
+ rescue_attempts = 0
634
+ max_rescue_attempts = 3
635
+
636
+ while not beam and rescue_attempts < max_rescue_attempts:
637
+ rescue_attempts += 1
638
+
639
+ # Progressive relaxation strategy
640
+ if rescue_attempts == 1:
641
+ # First attempt: increase beam size and relax GC bounds slightly
642
+ temp_beam_size = min(beam_size * 2, max_candidates)
643
+ temp_gc_bounds = (min_gc * 0.95, max_gc * 1.05)
644
+ elif rescue_attempts == 2:
645
+ # Second attempt: further relax GC bounds and increase candidates
646
+ temp_beam_size = min(beam_size * 3, max_candidates)
647
+ temp_gc_bounds = (min_gc * 0.9, max_gc * 1.1)
648
+ else:
649
+ # Final attempt: maximum relaxation
650
+ temp_beam_size = max_candidates
651
+ temp_gc_bounds = (min_gc * 0.85, max_gc * 1.15)
652
+
653
+ # Retry beam generation with relaxed parameters
654
+ rescue_beam = []
655
+ # Use previous beam state or start fresh if this is the first position with no beam
656
+ previous_beam = beam if beam else [BeamCandidate(tokens=[], score=0.0, gc_count=0, length=0)]
657
+ for candidate in previous_beam:
658
+ for token_idx, prob, gc_contribution in top_candidates:
659
+ new_gc_count = candidate.gc_count + gc_contribution
660
+ new_length = candidate.length + 3
661
+ new_gc_ratio = new_gc_count / new_length
662
+
663
+ # Check relaxed bounds
664
+ if temp_gc_bounds[0] <= new_gc_ratio <= temp_gc_bounds[1]:
665
+ # Apply reduced GC penalty for rescue
666
+ gc_penalty = 0.0
667
+ if position_aware_gc_penalty:
668
+ position_weight = (pos + 1) / seq_len
669
+ target_gc = (min_gc + max_gc) / 2
670
+ gc_deviation = abs(new_gc_ratio - target_gc)
671
+ deviation_severity = gc_deviation / ((max_gc - min_gc) / 2)
672
+
673
+ # Reduced penalty for rescue
674
+ if deviation_severity > 0.7:
675
+ gc_penalty = (gc_penalty_strength * 0.5) * position_weight * (deviation_severity - 0.7) ** 2
676
+
677
+ new_score = candidate.score + np.log(prob + 1e-8) - gc_penalty
678
+
679
+ if length_penalty != 1.0:
680
+ length_norm = ((pos + 1) ** length_penalty)
681
+ normalized_score = new_score / length_norm
682
+ else:
683
+ normalized_score = new_score
684
+
685
+ rescue_candidate = BeamCandidate(
686
+ tokens=candidate.tokens + [token_idx],
687
+ score=normalized_score,
688
+ gc_count=new_gc_count,
689
+ length=new_length
690
+ )
691
+ rescue_beam.append(rescue_candidate)
692
+
693
+ # Keep top candidates from rescue attempt
694
+ if rescue_beam:
695
+ beam = sorted(rescue_beam, key=lambda x: x.score, reverse=True)[:temp_beam_size]
696
+ break
697
+
698
+ # If all rescue attempts failed, raise error
699
+ if not beam:
700
+ raise ValueError(
701
+ f"Beam rescue failed at position {pos} after {max_rescue_attempts} attempts. "
702
+ f"The GC constraints {gc_bounds} may be too restrictive for this protein sequence. "
703
+ f"Consider relaxing constraints or using a different approach."
704
+ )
705
+
706
+ # Return best candidate
707
+ best_candidate = max(beam, key=lambda x: x.score)
708
+ return best_candidate.tokens
709
+
710
+
711
+ # Wrapper function that tries simple approach first
712
+ def constrained_beam_search_wrapper(
713
+ logits: torch.Tensor,
714
+ protein_sequence: str,
715
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
716
+ **kwargs
717
+ ) -> List[int]:
718
+ """Wrapper that tries simple approach first, falls back to complex beam search."""
719
+ try:
720
+ # Try simple approach first
721
+ return constrained_beam_search_simple(logits, protein_sequence, gc_bounds)
722
+ except ValueError:
723
+ # Fall back to complex beam search
724
+ return constrained_beam_search(logits, protein_sequence, gc_bounds, **kwargs)
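The wrapper keeps the cheap sampler on the happy path and only pays for the full beam search when the sampler cannot satisfy the bounds. A call sketch (inputs as in predict_dna_sequence; the keyword arguments are forwarded only on fallback):

    tokens = constrained_beam_search_wrapper(
        logits=logits.squeeze(0),     # [seq_len, vocab_size]
        protein_sequence=protein,
        gc_bounds=(0.45, 0.55),
        beam_size=10,
    )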
725
+
726
+
727
+ def _apply_diversity_penalty(candidates: List[BeamCandidate], penalty: float) -> List[BeamCandidate]:
728
+ """
729
+ Apply diversity penalty to reduce repetitive sequences.
730
+
731
+ Args:
732
+ candidates (List[BeamCandidate]): List of candidates
733
+ penalty (float): Diversity penalty strength
734
+
735
+ Returns:
736
+ List[BeamCandidate]: Candidates with diversity penalty applied
737
+ """
738
+ if not candidates:
739
+ return candidates
740
+
741
+ # Count token occurrences
742
+ token_counts = {}
743
+ for candidate in candidates:
744
+ for token in candidate.tokens:
745
+ token_counts[token] = token_counts.get(token, 0) + 1
746
+
747
+ # Apply penalty
748
+ for candidate in candidates:
749
+ diversity_score = 0.0
750
+ for token in candidate.tokens:
751
+ if token_counts[token] > 1:
752
+ diversity_score += penalty * np.log(token_counts[token])
753
+ candidate.score -= diversity_score
754
+
755
+ return candidates
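A worked example of the penalty arithmetic above: a token that occurs 3 times across the beam contributes penalty * ln(3) per occurrence, so with penalty = 0.1 a candidate containing it twice loses 2 * 0.1 * ln(3) ≈ 0.22 from its score.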
756
+
757
+
758
+ def sample_non_deterministic(
759
+ logits: torch.Tensor,
760
+ temperature: float = 0.2,
761
+ top_p: float = 0.95,
762
+ ) -> List[int]:
763
+ """
764
+ Sample token indices from logits using temperature scaling and nucleus (top-p) sampling.
765
+
766
+ This function applies temperature scaling to the logits, computes probabilities,
767
+ and then performs nucleus sampling to select token indices. It is used for
768
+ non-deterministic decoding in language models to introduce randomness while
769
+ maintaining coherence in the generated sequences.
770
+
771
+ Args:
772
+ logits (torch.Tensor): The logits output from the model of shape
773
+ [seq_len, vocab_size] or [batch_size, seq_len, vocab_size].
774
+ temperature (float, optional): Temperature value for scaling logits.
775
+ Must be a positive float. Defaults to 0.2.
776
+ top_p (float, optional): Cumulative probability threshold for nucleus sampling.
777
+ Must be a float between 0 and 1. Tokens with cumulative probability up to
778
+ `top_p` are considered for sampling. Defaults to 0.95.
779
+
780
+ Returns:
781
+ List[int]: A list of sampled token indices corresponding to the predicted tokens.
782
+
783
+ Raises:
784
+ ValueError: If `temperature` is not a positive float or if `top_p` is not between 0 and 1.
785
+
786
+ Example:
787
+ >>> logits = model_output.logits # Assume logits is a tensor of shape [seq_len, vocab_size]
788
+ >>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9)
789
+ """
790
+ if not isinstance(temperature, (float, int)) or temperature <= 0:
791
+ raise ValueError("Temperature must be a positive float.")
792
+
793
+ if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
794
+ raise ValueError("top_p must be a float between 0 and 1.")
795
+
796
+ # Compute probabilities using temperature scaling
797
+ probs = torch.softmax(logits / temperature, dim=-1)
798
+
799
+
800
+ # Remove batch dimension if present
801
+ if probs.dim() == 3:
802
+ probs = probs.squeeze(0) # Shape: [seq_len, vocab_size]
803
+
804
+ # Sort probabilities in descending order
805
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
806
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
807
+ mask = probs_sum - probs_sort > top_p
808
+
809
+ # Zero out probabilities for tokens beyond the top-p threshold
810
+ probs_sort[mask] = 0.0
811
+
812
+ # Renormalize the probabilities
813
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
814
+ next_token = torch.multinomial(probs_sort, num_samples=1)
815
+ predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1)
816
+
817
+ return predicted_indices.tolist()
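A toy illustration of the top-p filter (probabilities are made up): with sorted probabilities [0.5, 0.3, 0.15, 0.05] the cumulative sums are [0.5, 0.8, 0.95, 1.0], and the mask drops a token only once the cumulative mass before it exceeds top_p, so top_p = 0.8 keeps the first three tokens and renormalizes them to about [0.53, 0.32, 0.16]:

    import torch
    logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))  # one position, four tokens
    idx = sample_non_deterministic(logits, temperature=1.0, top_p=0.8)
    # idx is a one-element list drawn from tokens {0, 1, 2}; token 3 is never sampled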
818
+
819
+
820
+ def load_model(
821
+ model_path: Optional[str] = None,
822
+ device: torch.device = None,
823
+ attention_type: str = "original_full",
824
+ num_organisms: int = None,
825
+ remove_prefix: bool = True,
826
+ ) -> torch.nn.Module:
827
+ """
828
+ Load a BigBirdForMaskedLM model from a model file, checkpoint, or HuggingFace.
829
+
830
+ Args:
831
+ model_path (Optional[str]): Path to the model file or checkpoint. If None,
832
+ load from HuggingFace.
833
+ device (torch.device, optional): The device to load the model onto.
834
+ attention_type (str, optional): The type of attention, 'block_sparse'
835
+ or 'original_full'.
836
+ num_organisms (int, optional): Number of organisms, needed if loading from a
837
+ checkpoint that requires this.
838
+ remove_prefix (bool, optional): Whether to remove the "model." prefix from the
839
+ keys in the state dict.
840
+
841
+ Returns:
842
+ torch.nn.Module: The loaded model.
843
+ """
844
+ if not model_path:
845
+ warnings.warn("Model path not provided. Loading from HuggingFace.", UserWarning)
846
+ model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
847
+ elif model_path.endswith(".ckpt"):
848
+ checkpoint = torch.load(model_path, map_location="cpu")
849
+
850
+ # Detect Lightning checkpoint vs raw state dict
851
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
852
+ state_dict = checkpoint["state_dict"]
853
+ if remove_prefix:
854
+ state_dict = {
855
+ k.replace("model.", ""): v for k, v in state_dict.items()
856
+ }
857
+ else:
858
+ # assume checkpoint itself is state_dict
859
+ state_dict = checkpoint
860
+
861
+ if num_organisms is None:
862
+ num_organisms = NUM_ORGANISMS
863
+
864
+ # Load model configuration and instantiate the model
865
+ config = load_bigbird_config(num_organisms)
866
+ model = BigBirdForMaskedLM(config=config)
867
+ model.load_state_dict(state_dict, strict=False)
868
+
869
+ elif model_path.endswith(".pt"):
870
+ state_dict = torch.load(model_path)
871
+ config = state_dict.pop("self.config")
872
+ model = BigBirdForMaskedLM(config=config)
873
+ model.load_state_dict(state_dict, strict=False)
874
+
875
+ else:
876
+ raise ValueError(
877
+ "Unsupported file type. Please provide a .ckpt or .pt file, "
878
+ "or None to load from HuggingFace."
879
+ )
880
+
881
+ # Prepare model for evaluation
882
+ model.bert.set_attention_type(attention_type)
883
+ model.eval()
884
+ if device:
885
+ model.to(device)
886
+
887
+ return model
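A loading sketch (the checkpoint filename is a placeholder):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(device=device)                       # pulls from HuggingFace
    # model = load_model("coliformer.ckpt", device=device)  # or a local Lightning checkpoint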
888
+
889
+
890
+ def load_bigbird_config(num_organisms: int) -> BigBirdConfig:
891
+ """
892
+ Load the config object used to train the BigBird transformer.
893
+
894
+ Args:
895
+ num_organisms (int): The number of organisms.
896
+
897
+ Returns:
898
+ BigBirdConfig: The configuration object for BigBird.
899
+ """
900
+ config = transformers.BigBirdConfig(
901
+ vocab_size=len(TOKEN2INDEX), # Equal to len(tokenizer)
902
+ type_vocab_size=num_organisms,
903
+ sep_token_id=2,
904
+ )
905
+ return config
906
+
907
+
908
+ def create_model_from_checkpoint(
909
+ checkpoint_dir: str, output_model_dir: str, num_organisms: int
910
+ ) -> None:
911
+ """
912
+ Save a model to disk using a previous checkpoint.
913
+
914
+ Args:
915
+ checkpoint_dir (str): Directory where the checkpoint is stored.
916
+ output_model_dir (str): Directory where the model will be saved.
917
+ num_organisms (int): Number of organisms.
918
+ """
919
+ checkpoint = load_model(model_path=checkpoint_dir, num_organisms=num_organisms)
920
+ state_dict = checkpoint.state_dict()
921
+ state_dict["self.config"] = load_bigbird_config(num_organisms=num_organisms)
922
+
923
+ # Save the model state dict to the output directory
924
+ torch.save(state_dict, output_model_dir)
925
+
926
+
927
+ def load_tokenizer(tokenizer_path: Optional[Union[str, PreTrainedTokenizerFast]] = None) -> PreTrainedTokenizerFast:
928
+ """
929
+ Create and return a tokenizer object from tokenizer path or HuggingFace.
930
+
931
+ Args:
932
+ tokenizer_path (Optional[Union[str, PreTrainedTokenizerFast]]): Path to the tokenizer file,
933
+ a pre-loaded tokenizer object, or None. If None, load from HuggingFace.
934
+
935
+ Returns:
936
+ PreTrainedTokenizerFast: The tokenizer object.
937
+ """
938
+ # If a tokenizer object is already provided, return it
939
+ if isinstance(tokenizer_path, PreTrainedTokenizerFast):
940
+ return tokenizer_path
941
+
942
+ # If no path is provided, load from HuggingFace
943
+ if not tokenizer_path:
944
+ warnings.warn(
945
+ "Tokenizer path not provided. Loading from HuggingFace.", UserWarning
946
+ )
947
+ return AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
948
+
949
+ # Load from file path
950
+ return transformers.PreTrainedTokenizerFast(
951
+ tokenizer_file=tokenizer_path,
952
+ bos_token="[CLS]",
953
+ eos_token="[SEP]",
954
+ unk_token="[UNK]",
955
+ sep_token="[SEP]",
956
+ pad_token="[PAD]",
957
+ cls_token="[CLS]",
958
+ mask_token="[MASK]",
959
+ )
960
+
961
+
962
+ def tokenize(
963
+ batch: List[Dict[str, Any]],
964
+ tokenizer: Union[PreTrainedTokenizerFast, str] = None,
965
+ max_len: int = 2048,
966
+ ) -> BatchEncoding:
967
+ """
968
+ Return the tokenized sequences given a batch of input data.
969
+ Each data in the batch is expected to be a dictionary with "codons" and
970
+ "organism" keys.
971
+
972
+ Args:
973
+ batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and
974
+ "organism" keys.
975
+ tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or
976
+ path to the tokenizer file.
977
+ max_len (int, optional): Maximum length of the tokenized sequence.
978
+
979
+ Returns:
980
+ BatchEncoding: The tokenized batch.
981
+ """
982
+ if not isinstance(tokenizer, PreTrainedTokenizerFast):
983
+ tokenizer = load_tokenizer(tokenizer)
984
+
985
+ tokenized = tokenizer(
986
+ [data["codons"] for data in batch],
987
+ return_attention_mask=True,
988
+ return_token_type_ids=True,
989
+ truncation=True,
990
+ padding=True,
991
+ max_length=max_len,
992
+ return_tensors="pt",
993
+ )
994
+
995
+ # Add token type IDs for species
996
+ seq_len = tokenized["input_ids"].shape[-1]
997
+ species_index = torch.tensor([[data["organism"]] for data in batch])
998
+ tokenized["token_type_ids"] = species_index.repeat(1, seq_len)
999
+
1000
+ return tokenized
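A single-sample sketch (the organism ID and the merged-codon string are illustrative; real inputs come from get_merged_seq and ORGANISM2ID):

    batch = [{"idx": 0, "codons": "M_UNK K_UNK T_UNK", "organism": 51}]
    enc = tokenize(batch, tokenizer=tokenizer)
    # enc["input_ids"] has shape [1, seq_len];
    # enc["token_type_ids"] repeats the organism ID across every position.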
1001
+
1002
+
1003
+ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]:
1004
+ """
1005
+ Validate and convert the organism input to both ID and name.
1006
+
1007
+ This function takes either an organism ID or name as input and returns both
1008
+ the ID and name. It performs validation to ensure the input corresponds to
1009
+ a valid organism in the ORGANISM2ID dictionary.
1010
+
1011
+ Args:
1012
+ organism (Union[int, str]): Either the ID of the organism (int) or its
1013
+ name (str).
1014
+
1015
+ Returns:
1016
+ Tuple[int, str]: A tuple containing the organism ID (int) and name (str).
1017
+
1018
+ Raises:
1019
+ ValueError: If the input is neither a string nor an integer, if the
1020
+ organism name is not found in ORGANISM2ID, if the organism ID is not a
1021
+ value in ORGANISM2ID, or if no name is found for a given ID.
1022
+
1023
+ Note:
1024
+ This function relies on the ORGANISM2ID dictionary imported from
1025
+ CodonTransformer.CodonUtils, which maps organism names to their
1026
+ corresponding IDs.
1027
+ """
1028
+ if isinstance(organism, str):
1029
+ if organism not in ORGANISM2ID:
1030
+ raise ValueError(
1031
+ f"Invalid organism name: {organism}. "
1032
+ "Please use a valid organism name or ID."
1033
+ )
1034
+ organism_id = ORGANISM2ID[organism]
1035
+ organism_name = organism
1036
+
1037
+ elif isinstance(organism, int):
1038
+ if organism not in ORGANISM2ID.values():
1039
+ raise ValueError(
1040
+ f"Invalid organism ID: {organism}. "
1041
+ "Please use a valid organism name or ID."
1042
+ )
1043
+
1044
+ organism_id = organism
1045
+ organism_name = next(
1046
+ (name for name, org_id in ORGANISM2ID.items() if org_id == organism), None
1047
+ )
1048
+ if organism_name is None:
1049
+ raise ValueError(f"No organism name found for ID: {organism}")
1050
+
1051
+ return organism_id, organism_name
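Both spellings resolve to the same pair (the ID shown is illustrative):

    validate_and_convert_organism("Escherichia coli general")  # -> (51, "Escherichia coli general")
    validate_and_convert_organism(51)                          # -> (51, "Escherichia coli general")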
1052
+
1053
+
1054
+ def get_high_frequency_choice_sequence(
1055
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1056
+ ) -> str:
1057
+ """
1058
+ Return the DNA sequence optimized using High Frequency Choice (HFC) approach
1059
+ in which the most frequent codon for a given amino acid is always chosen.
1060
+
1061
+ Args:
1062
+ protein (str): The protein sequence.
1063
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1064
+ frequencies for each amino acid.
1065
+
1066
+ Returns:
1067
+ str: The optimized DNA sequence.
1068
+ """
1069
+ # Select the most frequent codon for each amino acid in the protein sequence
1070
+ dna_codons = [
1071
+ codon_frequencies[aminoacid][0][np.argmax(codon_frequencies[aminoacid][1])]
1072
+ for aminoacid in protein
1073
+ ]
1074
+ return "".join(dna_codons)
1075
+
1076
+
1077
+ def precompute_most_frequent_codons(
1078
+ codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
1079
+ ) -> Dict[str, str]:
1080
+ """
1081
+ Precompute the most frequent codon for each amino acid.
1082
+
1083
+ Args:
1084
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1085
+ frequencies for each amino acid.
1086
+
1087
+ Returns:
1088
+ Dict[str, str]: The most frequent codon for each amino acid.
1089
+ """
1090
+ # Create a dictionary mapping each amino acid to its most frequent codon
1091
+ return {
1092
+ aminoacid: codons[np.argmax(frequencies)]
1093
+ for aminoacid, (codons, frequencies) in codon_frequencies.items()
1094
+ }
1095
+
1096
+
1097
+ def get_high_frequency_choice_sequence_optimized(
1098
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1099
+ ) -> str:
1100
+ """
1101
+ Efficient implementation of get_high_frequency_choice_sequence that uses
1102
+ vectorized operations and helper functions, achieving up to x10 faster speed.
1103
+
1104
+ Args:
1105
+ protein (str): The protein sequence.
1106
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1107
+ frequencies for each amino acid.
1108
+
1109
+ Returns:
1110
+ str: The optimized DNA sequence.
1111
+ """
1112
+ # Precompute the most frequent codons for each amino acid
1113
+ most_frequent_codons = precompute_most_frequent_codons(codon_frequencies)
1114
+
1115
+ return "".join(most_frequent_codons[aminoacid] for aminoacid in protein)
1116
+
1117
+
1118
+ def get_background_frequency_choice_sequence(
1119
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1120
+ ) -> str:
1121
+ """
1122
+ Return the DNA sequence optimized using Background Frequency Choice (BFC)
1123
+ approach in which a random codon for a given amino acid is chosen using
1124
+ the codon frequencies probability distribution.
1125
+
1126
+ Args:
1127
+ protein (str): The protein sequence.
1128
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1129
+ frequencies for each amino acid.
1130
+
1131
+ Returns:
1132
+ str: The optimized DNA sequence.
1133
+ """
1134
+ # Select a random codon for each amino acid based on the codon frequencies
1135
+ # probability distribution
1136
+ dna_codons = [
1137
+ np.random.choice(
1138
+ codon_frequencies[aminoacid][0], p=codon_frequencies[aminoacid][1]
1139
+ )
1140
+ for aminoacid in protein
1141
+ ]
1142
+ return "".join(dna_codons)
1143
+
1144
+
1145
+ def precompute_cdf(
1146
+ codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
1147
+ ) -> Dict[str, Tuple[List[str], Any]]:
1148
+ """
1149
+ Precompute the cumulative distribution function (CDF) for each amino acid.
1150
+
1151
+ Args:
1152
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1153
+ frequencies for each amino acid.
1154
+
1155
+ Returns:
1156
+ Dict[str, Tuple[List[str], Any]]: CDFs for each amino acid.
1157
+ """
1158
+ cdf = {}
1159
+
1160
+ # Calculate the cumulative distribution function for each amino acid
1161
+ for aminoacid, (codons, frequencies) in codon_frequencies.items():
1162
+ cdf[aminoacid] = (codons, np.cumsum(frequencies))
1163
+
1164
+ return cdf
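The inverse-CDF trick in miniature: with frequencies [0.75, 0.25] the CDF is [0.75, 1.0], so a uniform draw below 0.75 selects the first codon and anything above it the second, which is exactly what np.searchsorted does in the optimized sampler below:

    import numpy as np
    assert np.searchsorted([0.75, 1.0], 0.4) == 0
    assert np.searchsorted([0.75, 1.0], 0.9) == 1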
1165
+
1166
+
1167
+ def get_background_frequency_choice_sequence_optimized(
1168
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1169
+ ) -> str:
1170
+ """
1171
+ Efficient implementation of get_background_frequency_choice_sequence that uses
1172
+ vectorized operations and helper functions, achieving up to x8 faster speed.
1173
+
1174
+ Args:
1175
+ protein (str): The protein sequence.
1176
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1177
+ frequencies for each amino acid.
1178
+
1179
+ Returns:
1180
+ str: The optimized DNA sequence.
1181
+ """
1182
+ dna_codons = []
1183
+ cdf = precompute_cdf(codon_frequencies)
1184
+
1185
+ # Select a random codon for each amino acid using the precomputed CDFs
1186
+ for aminoacid in protein:
1187
+ codons, cumulative_prob = cdf[aminoacid]
1188
+ selected_codon_index = np.searchsorted(cumulative_prob, np.random.rand())
1189
+ dna_codons.append(codons[selected_codon_index])
1190
+
1191
+ return "".join(dna_codons)
1192
+
1193
+
1194
+ def get_uniform_random_choice_sequence(
1195
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1196
+ ) -> str:
1197
+ """
1198
+ Return the DNA sequence optimized using Uniform Random Choice (URC) approach
1199
+ in which a random codon for a given amino acid is chosen using a uniform
1200
+ prior.
1201
+
1202
+ Args:
1203
+ protein (str): The protein sequence.
1204
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1205
+ frequencies for each amino acid.
1206
+
1207
+ Returns:
1208
+ str: The optimized DNA sequence.
1209
+ """
1210
+ # Select a random codon for each amino acid using a uniform prior distribution
1211
+ dna_codons = []
1212
+ for aminoacid in protein:
1213
+ codons = codon_frequencies[aminoacid][0]
1214
+ random_index = np.random.randint(0, len(codons))
1215
+ dna_codons.append(codons[random_index])
1216
+ return "".join(dna_codons)
1217
+
1218
+
1219
+ def get_icor_prediction(input_seq: str, model_path: str, stop_symbol: str) -> str:
1220
+ """
1221
+ Return the optimized codon sequence for the given protein sequence using ICOR.
1222
+
1223
+ Credit: ICOR: improving codon optimization with recurrent neural networks
1224
+ Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas
1225
+ Densmore
1226
+
1227
+ Args:
1228
+ input_seq (str): The input protein sequence.
1229
+ model_path (str): The path to the ICOR model.
1230
+ stop_symbol (str): The symbol representing stop codons in the sequence.
1231
+
1232
+ Returns:
1233
+ str: The optimized DNA sequence.
1234
+ """
1235
+ input_seq = input_seq.strip().upper()
1236
+ input_seq = input_seq.replace(stop_symbol, "*")
1237
+
1238
+ # Define categorical labels from when model was trained.
1239
+ labels = [
1240
+ "AAA",
1241
+ "AAC",
1242
+ "AAG",
1243
+ "AAT",
1244
+ "ACA",
1245
+ "ACG",
1246
+ "ACT",
1247
+ "AGC",
1248
+ "ATA",
1249
+ "ATC",
1250
+ "ATG",
1251
+ "ATT",
1252
+ "CAA",
1253
+ "CAC",
1254
+ "CAG",
1255
+ "CCG",
1256
+ "CCT",
1257
+ "CTA",
1258
+ "CTC",
1259
+ "CTG",
1260
+ "CTT",
1261
+ "GAA",
1262
+ "GAT",
1263
+ "GCA",
1264
+ "GCC",
1265
+ "GCG",
1266
+ "GCT",
1267
+ "GGA",
1268
+ "GGC",
1269
+ "GTC",
1270
+ "GTG",
1271
+ "GTT",
1272
+ "TAA",
1273
+ "TAT",
1274
+ "TCA",
1275
+ "TCG",
1276
+ "TCT",
1277
+ "TGG",
1278
+ "TGT",
1279
+ "TTA",
1280
+ "TTC",
1281
+ "TTG",
1282
+ "TTT",
1283
+ "ACC",
1284
+ "CAT",
1285
+ "CCA",
1286
+ "CGG",
1287
+ "CGT",
1288
+ "GAC",
1289
+ "GAG",
1290
+ "GGT",
1291
+ "AGT",
1292
+ "GGG",
1293
+ "GTA",
1294
+ "TGC",
1295
+ "CCC",
1296
+ "CGA",
1297
+ "CGC",
1298
+ "TAC",
1299
+ "TAG",
1300
+ "TCC",
1301
+ "AGA",
1302
+ "AGG",
1303
+ "TGA",
1304
+ ]
1305
+
1306
+ # Define aa to integer table
1307
+ def aa2int(seq: str) -> List[int]:
1308
+ _aa2int = {
1309
+ "A": 1,
1310
+ "R": 2,
1311
+ "N": 3,
1312
+ "D": 4,
1313
+ "C": 5,
1314
+ "Q": 6,
1315
+ "E": 7,
1316
+ "G": 8,
1317
+ "H": 9,
1318
+ "I": 10,
1319
+ "L": 11,
1320
+ "K": 12,
1321
+ "M": 13,
1322
+ "F": 14,
1323
+ "P": 15,
1324
+ "S": 16,
1325
+ "T": 17,
1326
+ "W": 18,
1327
+ "Y": 19,
1328
+ "V": 20,
1329
+ "B": 21,
1330
+ "Z": 22,
1331
+ "X": 23,
1332
+ "*": 24,
1333
+ "-": 25,
1334
+ "?": 26,
1335
+ }
1336
+ return [_aa2int[i] for i in seq]
1337
+
1338
+ # Create empty array to fill
1339
+ oh_array = np.zeros(shape=(26, len(input_seq)))
1340
+
1341
+ # Load placements from aa2int
1342
+ aa_placement = aa2int(input_seq)
1343
+
1344
+ # One-hot encode the amino acid sequence:
1345
+
1346
+ for i, aa_idx in enumerate(aa_placement):
1347
+ oh_array[aa_idx, i] = 1
1350
+
1351
+ oh_array = [oh_array]
1352
+ x = np.array(np.transpose(oh_array))
1353
+
1354
+ y = x.astype(np.float32)
1355
+
1356
+ y = np.reshape(y, (y.shape[0], 1, 26))
1357
+
1358
+ # Start ICOR session using model.
1359
+ sess = rt.InferenceSession(model_path)
1360
+ input_name = sess.get_inputs()[0].name
1361
+
1362
+ # Get prediction:
1363
+ pred_onx = sess.run(None, {input_name: y})
1364
+
1365
+ # Get the index of the highest probability from softmax output:
1366
+ pred_indices = []
1367
+ for pred in pred_onx[0]:
1368
+ pred_indices.append(np.argmax(pred))
1369
+
1370
+ out_str = ""
1371
+ for index in pred_indices:
1372
+ out_str += labels[index]
1373
+
1374
+ return out_str
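A call sketch (the ONNX path is a placeholder for wherever the exported ICOR model lives):

    optimized_dna = get_icor_prediction(
        "MKT_", model_path="models/icor.onnx", stop_symbol="_"
    )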
CodonTransformer/CodonUtils.py ADDED
@@ -0,0 +1,871 @@
+ """
+ File: CodonUtils.py
+ ---------------------
+ Includes constants and helper functions used by other Python scripts.
+ """
+
+ import itertools
+ import json
+ import os
+ import pickle
+ import re
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+ import pandas as pd
+ import requests
+ import torch
+
+ # List of all amino acids
+ AMINO_ACIDS: List[str] = [
+     "A",  # Alanine
+     "C",  # Cysteine
+     "D",  # Aspartic acid
+     "E",  # Glutamic acid
+     "F",  # Phenylalanine
+     "G",  # Glycine
+     "H",  # Histidine
+     "I",  # Isoleucine
+     "K",  # Lysine
+     "L",  # Leucine
+     "M",  # Methionine
+     "N",  # Asparagine
+     "P",  # Proline
+     "Q",  # Glutamine
+     "R",  # Arginine
+     "S",  # Serine
+     "T",  # Threonine
+     "V",  # Valine
+     "W",  # Tryptophan
+     "Y",  # Tyrosine
+ ]
+ STOP_SYMBOLS = ["_", "*"]  # Stop codon symbols
+
+ # Dictionary mapping ambiguous amino acids to standard amino acids
+ AMBIGUOUS_AMINOACID_MAP: Dict[str, List[str]] = {
+     "B": ["N", "D"],  # Asparagine (N) or Aspartic acid (D)
+     "Z": ["Q", "E"],  # Glutamine (Q) or Glutamic acid (E)
+     "X": ["A"],  # Any amino acid (typically replaced with Alanine)
+     "J": ["L", "I"],  # Leucine (L) or Isoleucine (I)
+     "U": ["C"],  # Selenocysteine (typically replaced with Cysteine)
+     "O": ["K"],  # Pyrrolysine (typically replaced with Lysine)
+ }
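This map backs the `standardize_deterministic` and `standardize_random` behaviors configured through `ProteinConfig` further down. A minimal sketch of how it can be applied; the `standardize` helper here is illustrative, not a function in this module:

```python
# Sketch: replace ambiguous residues deterministically (first candidate) or randomly.
import random
from typing import Dict, List

AMBIGUOUS_AMINOACID_MAP: Dict[str, List[str]] = {"B": ["N", "D"], "Z": ["Q", "E"], "X": ["A"]}

def standardize(seq: str, deterministic: bool = True) -> str:
    out = []
    for aa in seq:
        candidates = AMBIGUOUS_AMINOACID_MAP.get(aa)
        if candidates is None:
            out.append(aa)  # already a standard residue
        else:
            out.append(candidates[0] if deterministic else random.choice(candidates))
    return "".join(out)

print(standardize("MKBZ"))  # -> "MKNQ" with deterministic replacement
```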
+
+ # List of all possible start and stop codons
+ START_CODONS: List[str] = ["ATG", "TTG", "CTG", "GTG"]
+ STOP_CODONS: List[str] = ["TAA", "TAG", "TGA"]
+
+ # Token-to-index mapping for amino acids and special tokens
+ TOKEN2INDEX: Dict[str, int] = {
+     "[UNK]": 0,
+     "[CLS]": 1,
+     "[SEP]": 2,
+     "[PAD]": 3,
+     "[MASK]": 4,
+     "a_unk": 5,
+     "c_unk": 6,
+     "d_unk": 7,
+     "e_unk": 8,
+     "f_unk": 9,
+     "g_unk": 10,
+     "h_unk": 11,
+     "i_unk": 12,
+     "k_unk": 13,
+     "l_unk": 14,
+     "m_unk": 15,
+     "n_unk": 16,
+     "p_unk": 17,
+     "q_unk": 18,
+     "r_unk": 19,
+     "s_unk": 20,
+     "t_unk": 21,
+     "v_unk": 22,
+     "w_unk": 23,
+     "y_unk": 24,
+     "__unk": 25,
+     "k_aaa": 26,
+     "n_aac": 27,
+     "k_aag": 28,
+     "n_aat": 29,
+     "t_aca": 30,
+     "t_acc": 31,
+     "t_acg": 32,
+     "t_act": 33,
+     "r_aga": 34,
+     "s_agc": 35,
+     "r_agg": 36,
+     "s_agt": 37,
+     "i_ata": 38,
+     "i_atc": 39,
+     "m_atg": 40,
+     "i_att": 41,
+     "q_caa": 42,
+     "h_cac": 43,
+     "q_cag": 44,
+     "h_cat": 45,
+     "p_cca": 46,
+     "p_ccc": 47,
+     "p_ccg": 48,
+     "p_cct": 49,
+     "r_cga": 50,
+     "r_cgc": 51,
+     "r_cgg": 52,
+     "r_cgt": 53,
+     "l_cta": 54,
+     "l_ctc": 55,
+     "l_ctg": 56,
+     "l_ctt": 57,
+     "e_gaa": 58,
+     "d_gac": 59,
+     "e_gag": 60,
+     "d_gat": 61,
+     "a_gca": 62,
+     "a_gcc": 63,
+     "a_gcg": 64,
+     "a_gct": 65,
+     "g_gga": 66,
+     "g_ggc": 67,
+     "g_ggg": 68,
+     "g_ggt": 69,
+     "v_gta": 70,
+     "v_gtc": 71,
+     "v_gtg": 72,
+     "v_gtt": 73,
+     "__taa": 74,
+     "y_tac": 75,
+     "__tag": 76,
+     "y_tat": 77,
+     "s_tca": 78,
+     "s_tcc": 79,
+     "s_tcg": 80,
+     "s_tct": 81,
+     "__tga": 82,
+     "c_tgc": 83,
+     "w_tgg": 84,
+     "c_tgt": 85,
+     "l_tta": 86,
+     "f_ttc": 87,
+     "l_ttg": 88,
+     "f_ttt": 89,
+ }
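Each codon token has the form `<amino acid>_<codon>` in lowercase (stop positions use the `_` stop symbol), so a protein/DNA pair tokenizes by simple pairing. A small sketch with the mapping abbreviated:

```python
# Sketch: tokenize a protein/DNA pair into `aa_codon` tokens and look them up.
TOKEN2INDEX = {"m_atg": 40, "k_aaa": 26, "__taa": 74}  # abbreviated

protein, dna = "MK_", "ATGAAATAA"
tokens = [
    f"{protein[i].lower()}_{dna[3 * i : 3 * i + 3].lower()}"
    for i in range(len(protein))
]
print(tokens)                            # ['m_atg', 'k_aaa', '__taa']
print([TOKEN2INDEX[t] for t in tokens])  # [40, 26, 74]
```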
+
+ # Index-to-token mapping, reverse of TOKEN2INDEX
+ INDEX2TOKEN: Dict[int, str] = {i: c for c, i in TOKEN2INDEX.items()}
+
+ # Dictionary mapping each codon to its GC content
+ CODON_GC_CONTENT: Dict[str, int] = {
+     token.split("_")[1]: token.split("_")[1].upper().count("G")
+     + token.split("_")[1].upper().count("C")
+     for token in TOKEN2INDEX
+     if "_" in token and len(token.split("_")[1]) == 3
+ }
+
+ # Tensor with GC counts for each token in the vocabulary
+ GC_COUNTS_PER_TOKEN = torch.zeros(len(TOKEN2INDEX))
+ for token, index in TOKEN2INDEX.items():
+     if "_" in token and len(token.split("_")[1]) == 3:
+         codon = token.split("_")[1].upper()
+         gc_count = codon.count("G") + codon.count("C")
+         GC_COUNTS_PER_TOKEN[index] = gc_count
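With `GC_COUNTS_PER_TOKEN`, the GC fraction of a predicted sequence can be computed directly from token ids. Note that stop tokens such as `__taa` yield an empty string from `split("_")[1]`, so they fail the length check and stay at zero. A sketch with an abbreviated vocabulary:

```python
# Sketch: per-sequence GC fraction from codon token ids.
import torch

TOKEN2INDEX = {"m_atg": 40, "k_aaa": 26, "g_ggc": 67}  # abbreviated
GC_COUNTS_PER_TOKEN = torch.zeros(90)
for token, index in TOKEN2INDEX.items():
    codon = token.split("_")[1].upper()
    GC_COUNTS_PER_TOKEN[index] = codon.count("G") + codon.count("C")

token_ids = torch.tensor([40, 26, 67])           # m_atg, k_aaa, g_ggc
gc = GC_COUNTS_PER_TOKEN[token_ids].sum().item() # 1 + 0 + 3 = 4 G/C bases
print(gc / (3 * len(token_ids)))                 # 4/9 ≈ 0.444 GC fraction
```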
+
+ G_indices = [idx for token, idx in TOKEN2INDEX.items() if "g" in token.split("_")[-1]]
+ C_indices = [idx for token, idx in TOKEN2INDEX.items() if "c" in token.split("_")[-1]]
+
+ # Dictionary mapping each amino acid and stop symbol to indices of codon tokens that translate to it
+ AMINO_ACID_TO_INDEX = {
+     aa: sorted(
+         [i for t, i in TOKEN2INDEX.items() if t[0].upper() == aa and t[-3:] != "unk"]
+     )
+     for aa in (AMINO_ACIDS + STOP_SYMBOLS)
+ }
+
+ # Dictionary mapping each amino acid to min/max GC content across all possible codons
+ AA_MIN_GC: Dict[str, int] = {}
+ AA_MAX_GC: Dict[str, int] = {}
+
+ for aa, token_indices in AMINO_ACID_TO_INDEX.items():
+     if token_indices:  # Skip if no tokens for this amino acid
+         gc_counts = []
+         for token_idx in token_indices:
+             token = INDEX2TOKEN[token_idx]
+             if "_" in token and len(token.split("_")[1]) == 3:
+                 codon = token.split("_")[1]
+                 if codon in CODON_GC_CONTENT:
+                     gc_counts.append(CODON_GC_CONTENT[codon])
+
+         if gc_counts:
+             AA_MIN_GC[aa] = min(gc_counts)
+             AA_MAX_GC[aa] = max(gc_counts)
+
+ # Mask token mapping
+ TOKEN2MASK: Dict[int, int] = {
+     0: 0,
+     1: 1,
+     2: 2,
+     3: 3,
+     4: 4,
+     5: 5,
+     6: 6,
+     7: 7,
+     8: 8,
+     9: 9,
+     10: 10,
+     11: 11,
+     12: 12,
+     13: 13,
+     14: 14,
+     15: 15,
+     16: 16,
+     17: 17,
+     18: 18,
+     19: 19,
+     20: 20,
+     21: 21,
+     22: 22,
+     23: 23,
+     24: 24,
+     25: 25,
+     26: 13,
+     27: 16,
+     28: 13,
+     29: 16,
+     30: 21,
+     31: 21,
+     32: 21,
+     33: 21,
+     34: 19,
+     35: 20,
+     36: 19,
+     37: 20,
+     38: 12,
+     39: 12,
+     40: 15,
+     41: 12,
+     42: 18,
+     43: 11,
+     44: 18,
+     45: 11,
+     46: 17,
+     47: 17,
+     48: 17,
+     49: 17,
+     50: 19,
+     51: 19,
+     52: 19,
+     53: 19,
+     54: 14,
+     55: 14,
+     56: 14,
+     57: 14,
+     58: 8,
+     59: 7,
+     60: 8,
+     61: 7,
+     62: 5,
+     63: 5,
+     64: 5,
+     65: 5,
+     66: 10,
+     67: 10,
+     68: 10,
+     69: 10,
+     70: 22,
+     71: 22,
+     72: 22,
+     73: 22,
+     74: 25,
+     75: 24,
+     76: 25,
+     77: 24,
+     78: 20,
+     79: 20,
+     80: 20,
+     81: 20,
+     82: 25,
+     83: 6,
+     84: 23,
+     85: 6,
+     86: 14,
+     87: 9,
+     88: 14,
+     89: 9,
+ }
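`TOKEN2MASK` maps every codon token index to the `*_unk` token of the same amino acid (e.g. 26 → 13 is `k_aaa` → `k_unk`), so masking hides codon identity while keeping amino acid identity. A sketch with an abbreviated vocabulary:

```python
# Sketch: amino-acid-preserving masking with TOKEN2MASK (abbreviated here).
TOKEN2INDEX = {"k_unk": 13, "m_unk": 15, "k_aaa": 26, "m_atg": 40}
INDEX2TOKEN = {i: t for t, i in TOKEN2INDEX.items()}
TOKEN2MASK = {26: 13, 40: 15}

token_ids = [40, 26]                     # m_atg, k_aaa
masked = [TOKEN2MASK[i] for i in token_ids]
print([INDEX2TOKEN[i] for i in masked])  # ['m_unk', 'k_unk'] — codons hidden, amino acids kept
```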
+
+ # List of organisms used for fine-tuning
+ FINE_TUNE_ORGANISMS: List[str] = [
+     "Arabidopsis thaliana",
+     "Bacillus subtilis",
+     "Caenorhabditis elegans",
+     "Chlamydomonas reinhardtii",
+     "Chlamydomonas reinhardtii chloroplast",
+     "Danio rerio",
+     "Drosophila melanogaster",
+     "Homo sapiens",
+     "Mus musculus",
+     "Nicotiana tabacum",
+     "Nicotiana tabacum chloroplast",
+     "Pseudomonas putida",
+     "Saccharomyces cerevisiae",
+     "Escherichia coli O157-H7 str. Sakai",
+     "Escherichia coli general",
+     "Escherichia coli str. K-12 substr. MG1655",
+     "Thermococcus barophilus MPT",
+ ]
+
+ # List of organisms most commonly used for codon optimization
+ COMMON_ORGANISMS: List[str] = [
+     "Arabidopsis thaliana",
+     "Bacillus subtilis",
+     "Caenorhabditis elegans",
+     "Chlamydomonas reinhardtii",
+     "Danio rerio",
+     "Drosophila melanogaster",
+     "Homo sapiens",
+     "Mus musculus",
+     "Nicotiana tabacum",
+     "Pseudomonas putida",
+     "Saccharomyces cerevisiae",
+     "Escherichia coli general",
+ ]
+
+ # Dictionary mapping each organism name to respective organism id
+ ORGANISM2ID: Dict[str, int] = {
+     "Arabidopsis thaliana": 0,
+     "Atlantibacter hermannii": 1,
+     "Bacillus subtilis": 2,
+     "Brenneria goodwinii": 3,
+     "Buchnera aphidicola (Schizaphis graminum)": 4,
+     "Caenorhabditis elegans": 5,
+     "Candidatus Erwinia haradaeae": 6,
+     "Candidatus Hamiltonella defensa 5AT (Acyrthosiphon pisum)": 7,
+     "Chlamydomonas reinhardtii": 8,
+     "Chlamydomonas reinhardtii chloroplast": 9,
+     "Citrobacter amalonaticus": 10,
+     "Citrobacter braakii": 11,
+     "Citrobacter cronae": 12,
+     "Citrobacter europaeus": 13,
+     "Citrobacter farmeri": 14,
+     "Citrobacter freundii": 15,
+     "Citrobacter koseri ATCC BAA-895": 16,
+     "Citrobacter portucalensis": 17,
+     "Citrobacter werkmanii": 18,
+     "Citrobacter youngae": 19,
+     "Cronobacter dublinensis subsp. dublinensis LMG 23823": 20,
+     "Cronobacter malonaticus LMG 23826": 21,
+     "Cronobacter sakazakii": 22,
+     "Cronobacter turicensis": 23,
+     "Danio rerio": 24,
+     "Dickeya dadantii 3937": 25,
+     "Dickeya dianthicola": 26,
+     "Dickeya fangzhongdai": 27,
+     "Dickeya solani": 28,
+     "Dickeya zeae": 29,
+     "Drosophila melanogaster": 30,
+     "Edwardsiella anguillarum ET080813": 31,
+     "Edwardsiella ictaluri": 32,
+     "Edwardsiella piscicida": 33,
+     "Edwardsiella tarda": 34,
+     "Enterobacter asburiae": 35,
+     "Enterobacter bugandensis": 36,
+     "Enterobacter cancerogenus": 37,
+     "Enterobacter chengduensis": 38,
+     "Enterobacter cloacae": 39,
+     "Enterobacter hormaechei": 40,
+     "Enterobacter kobei": 41,
+     "Enterobacter ludwigii": 42,
+     "Enterobacter mori": 43,
+     "Enterobacter quasiroggenkampii": 44,
+     "Enterobacter roggenkampii": 45,
+     "Enterobacter sichuanensis": 46,
+     "Erwinia amylovora CFBP1430": 47,
+     "Erwinia persicina": 48,
+     "Escherichia albertii": 49,
+     "Escherichia coli O157-H7 str. Sakai": 50,
+     "Escherichia coli general": 51,
+     "Escherichia coli str. K-12 substr. MG1655": 52,
+     "Escherichia fergusonii": 53,
+     "Escherichia marmotae": 54,
+     "Escherichia ruysiae": 55,
+     "Ewingella americana": 56,
+     "Hafnia alvei": 57,
+     "Hafnia paralvei": 58,
+     "Homo sapiens": 59,
+     "Kalamiella piersonii": 60,
+     "Klebsiella aerogenes": 61,
+     "Klebsiella grimontii": 62,
+     "Klebsiella michiganensis": 63,
+     "Klebsiella oxytoca": 64,
+     "Klebsiella pasteurii": 65,
+     "Klebsiella pneumoniae subsp. pneumoniae HS11286": 66,
+     "Klebsiella quasipneumoniae": 67,
+     "Klebsiella quasivariicola": 68,
+     "Klebsiella variicola": 69,
+     "Kosakonia cowanii": 70,
+     "Kosakonia radicincitans": 71,
+     "Leclercia adecarboxylata": 72,
+     "Lelliottia amnigena": 73,
+     "Lonsdalea populi": 74,
+     "Moellerella wisconsensis": 75,
+     "Morganella morganii": 76,
+     "Mus musculus": 77,
+     "Nicotiana tabacum": 78,
+     "Nicotiana tabacum chloroplast": 79,
+     "Obesumbacterium proteus": 80,
+     "Pantoea agglomerans": 81,
+     "Pantoea allii": 82,
+     "Pantoea ananatis PA13": 83,
+     "Pantoea dispersa": 84,
+     "Pantoea stewartii": 85,
+     "Pantoea vagans": 86,
+     "Pectobacterium aroidearum": 87,
+     "Pectobacterium atrosepticum": 88,
+     "Pectobacterium brasiliense": 89,
+     "Pectobacterium carotovorum": 90,
+     "Pectobacterium odoriferum": 91,
+     "Pectobacterium parmentieri": 92,
+     "Pectobacterium polaris": 93,
+     "Pectobacterium versatile": 94,
+     "Photorhabdus laumondii subsp. laumondii TTO1": 95,
+     "Plesiomonas shigelloides": 96,
+     "Pluralibacter gergoviae": 97,
+     "Proteus faecis": 98,
+     "Proteus mirabilis HI4320": 99,
+     "Proteus penneri": 100,
+     "Proteus terrae subsp. cibarius": 101,
+     "Proteus vulgaris": 102,
+     "Providencia alcalifaciens": 103,
+     "Providencia heimbachae": 104,
+     "Providencia rettgeri": 105,
+     "Providencia rustigianii": 106,
+     "Providencia stuartii": 107,
+     "Providencia thailandensis": 108,
+     "Pseudomonas putida": 109,
+     "Pyrococcus furiosus": 110,
+     "Pyrococcus horikoshii": 111,
+     "Pyrococcus yayanosii": 112,
+     "Rahnella aquatilis CIP 78.65 = ATCC 33071": 113,
+     "Raoultella ornithinolytica": 114,
+     "Raoultella planticola": 115,
+     "Raoultella terrigena": 116,
+     "Rosenbergiella epipactidis": 117,
+     "Rouxiella badensis": 118,
+     "Saccharolobus solfataricus": 119,
+     "Saccharomyces cerevisiae": 120,
+     "Salmonella bongori N268-08": 121,
+     "Salmonella enterica subsp. enterica serovar Typhimurium str. LT2": 122,
+     "Serratia bockelmannii": 123,
+     "Serratia entomophila": 124,
+     "Serratia ficaria": 125,
+     "Serratia fonticola": 126,
+     "Serratia grimesii": 127,
+     "Serratia liquefaciens": 128,
+     "Serratia marcescens": 129,
+     "Serratia nevei": 130,
+     "Serratia plymuthica AS9": 131,
+     "Serratia proteamaculans": 132,
+     "Serratia quinivorans": 133,
+     "Serratia rubidaea": 134,
+     "Serratia ureilytica": 135,
+     "Shigella boydii": 136,
+     "Shigella dysenteriae": 137,
+     "Shigella flexneri 2a str. 301": 138,
+     "Shigella sonnei": 139,
+     "Thermoccoccus kodakarensis": 140,
+     "Thermococcus barophilus MPT": 141,
+     "Thermococcus chitonophagus": 142,
+     "Thermococcus gammatolerans": 143,
+     "Thermococcus litoralis": 144,
+     "Thermococcus onnurineus": 145,
+     "Thermococcus sibiricus": 146,
+     "Xenorhabdus bovienii str. feltiae Florida": 147,
+     "Yersinia aldovae 670-83": 148,
+     "Yersinia aleksiciae": 149,
+     "Yersinia alsatica": 150,
+     "Yersinia enterocolitica": 151,
+     "Yersinia frederiksenii ATCC 33641": 152,
+     "Yersinia intermedia": 153,
+     "Yersinia kristensenii": 154,
+     "Yersinia massiliensis CCUG 53443": 155,
+     "Yersinia mollaretii ATCC 43969": 156,
+     "Yersinia pestis A1122": 157,
+     "Yersinia proxima": 158,
+     "Yersinia pseudotuberculosis IP 32953": 159,
+     "Yersinia rochesterensis": 160,
+     "Yersinia rohdei": 161,
+     "Yersinia alsatica": 150,
+     "Yersinia ruckeri": 162,
+     "Yokenella regensburgei": 163,
+ }
+
+ # Dictionary mapping each organism id to respective organism name
+ ID2ORGANISM = {v: k for k, v in ORGANISM2ID.items()}
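These two mappings resolve organism names and the integer ids the model was trained with in either direction. A quick sketch with the mapping abbreviated:

```python
# Sketch: resolving organism names/ids (abbreviated mapping).
ORGANISM2ID = {"Escherichia coli general": 51, "Homo sapiens": 59}
ID2ORGANISM = {v: k for k, v in ORGANISM2ID.items()}

print(ORGANISM2ID["Escherichia coli general"])  # 51
print(ID2ORGANISM[59])                          # Homo sapiens
```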
+
+ # Type alias for amino acid to codon mapping
+ AMINO2CODON_TYPE = Dict[str, Tuple[List[str], List[float]]]
+
+ # Constants for the number of organisms and sequence lengths
+ NUM_ORGANISMS = 164
+ MAX_LEN = 2048
+ MAX_AMINO_ACIDS = MAX_LEN - 2  # Without special tokens [CLS] and [SEP]
+ STOP_SYMBOL = "_"
+
+
+ @dataclass
+ class DNASequencePrediction:
+     """
+     A class to hold the output of the DNA sequence prediction.
+
+     Attributes:
+         organism (str): Name of the organism used for prediction.
+         protein (str): Input protein sequence for which DNA sequence is predicted.
+         processed_input (str): Processed input sequence (merged protein and DNA).
+         predicted_dna (str): Predicted DNA sequence.
+     """
+
+     organism: str
+     protein: str
+     processed_input: str
+     predicted_dna: str
+
+
+ class IterableData(torch.utils.data.IterableDataset):
+     """
+     Defines the logic for iterable datasets (working over streams of
+     data) in parallel multi-processing environments, e.g., multi-GPU.
+
+     Args:
+         dist_env (Optional[str]): The distribution environment identifier
+             (e.g., "slurm").
+
+     Credit: Guillaume Filion
+     """
+
+     def __init__(self, dist_env: Optional[str] = None):
+         super().__init__()
+         if dist_env is None:
+             self.world_size_handle, self.rank_handle = ("WORLD_SIZE", "LOCAL_RANK")
+         else:
+             self.world_size_handle, self.rank_handle = {
+                 "slurm": ("SLURM_NTASKS", "SLURM_PROCID")
+             }.get(dist_env, ("WORLD_SIZE", "LOCAL_RANK"))
+
+     @property
+     def iterator(self) -> Iterator:
+         """Define the stream logic for the dataset. Implement in subclasses."""
+         raise NotImplementedError
+
+     def __iter__(self) -> Iterator:
+         """
+         Create an iterator for the dataset, handling multi-processing contexts.
+
+         Returns:
+             Iterator: The iterator for the dataset.
+         """
+         worker_info = torch.utils.data.get_worker_info()
+         if worker_info is None:
+             return self.iterator
+
+         # In multi-processing context, use 'os.environ' to
+         # find global worker rank. Then use 'islice' to allocate
+         # the items of the stream to the workers.
+         world_size = int(os.environ.get(self.world_size_handle, "1"))
+         global_rank = int(os.environ.get(self.rank_handle, "0"))
+         local_rank = worker_info.id
+         local_num_workers = worker_info.num_workers
+
+         # Assume that each process has the same number of local workers.
+         worker_rk = global_rank * local_num_workers + local_rank
+         worker_nb = world_size * local_num_workers
+         return itertools.islice(self.iterator, worker_rk, None, worker_nb)
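The `islice` call hands worker k every (world_size × num_workers)-th item of the stream starting at offset k, so each item is seen by exactly one worker and none are duplicated. A self-contained sketch of the same scheme:

```python
# Sketch: round-robin stream sharding as done by IterableData.__iter__ above.
import itertools

stream = range(10)              # stands in for the data stream
worker_nb = 3                   # total workers across all processes
assignments = {
    worker_rk: list(itertools.islice(iter(stream), worker_rk, None, worker_nb))
    for worker_rk in range(worker_nb)
}
print(assignments)  # {0: [0, 3, 6, 9], 1: [1, 4, 7], 2: [2, 5, 8]}
```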
+
+
+ class IterableJSONData(IterableData):
+     """
+     Iterate over the lines of a JSON file and uncompress if needed.
+
+     Args:
+         data_path (str): The path to the JSON data file.
+         train (bool): Flag indicating if the dataset is for training.
+         **kwargs: Additional keyword arguments for the base class.
+     """
+
+     def __init__(self, data_path: str, train: bool = True, **kwargs):
+         super().__init__(**kwargs)
+         self.data_path = data_path
+         self.train = train
+         with open(os.path.join(self.data_path, "finetune_set.json"), "r") as f:
+             self.records = [json.loads(line) for line in f]
+
+     def __len__(self):
+         return len(self.records)
+
+     @property
+     def iterator(self) -> Iterator:
+         """Define the stream logic for the dataset."""
+         for record in self.records:
+             yield record
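A minimal usage sketch, assuming a `data/` directory containing the `finetune_set.json` JSONL file this class expects, with records in the `{"idx", "codons", "organism"}` format produced by `prepare_training_data`:

```python
# Sketch: streaming the fine-tuning set through a PyTorch DataLoader.
from torch.utils.data import DataLoader

dataset = IterableJSONData("data/", train=True)   # expects data/finetune_set.json
loader = DataLoader(dataset, batch_size=None, num_workers=2)
for record in loader:
    print(record["codons"], record["organism"])
    break
```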
+
+
+ class ConfigManager(ABC):
+     """
+     Abstract base class for managing configuration settings.
+     """
+
+     _config: Dict[str, Any]
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         if exc_type is not None:
+             print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}")
+         self.reset_config()
+
+     @abstractmethod
+     def reset_config(self) -> None:
+         """Reset the configuration to default values."""
+         pass
+
+     def get(self, key: str) -> Any:
+         """
+         Get the value of a configuration key.
+
+         Args:
+             key (str): The key to retrieve the value for.
+
+         Returns:
+             Any: The value of the configuration key.
+         """
+         return self._config.get(key)
+
+     def set(self, key: str, value: Any) -> None:
+         """
+         Set the value of a configuration key.
+
+         Args:
+             key (str): The key to set the value for.
+             value (Any): The value to set for the key.
+         """
+         self.validate_inputs(key, value)
+         self._config[key] = value
+
+     def update(self, config_dict: dict) -> None:
+         """
+         Update the configuration with a dictionary of key-value pairs after
+         validating them.
+
+         Args:
+             config_dict (dict): A dictionary of key-value pairs to update the
+                 configuration.
+         """
+         for key, value in config_dict.items():
+             self.validate_inputs(key, value)
+         self._config.update(config_dict)
+
+     @abstractmethod
+     def validate_inputs(self, key: str, value: Any) -> None:
+         """Validate the inputs for the configuration."""
+         pass
+
+
+ class ProteinConfig(ConfigManager):
+     """
+     A class to manage configuration settings for protein sequences.
+
+     This class ensures that the configuration is a singleton.
+     It provides methods to get, set, and update configuration values.
+
+     Attributes:
+         _instance (Optional[ConfigManager]): The singleton instance of the ConfigManager.
+         _config (Dict[str, Any]): The configuration dictionary.
+     """
+
+     _instance = None
+
+     def __new__(cls):
+         """
+         Create a new instance of the ProteinConfig class.
+
+         Returns:
+             ProteinConfig: The singleton instance of the ProteinConfig.
+         """
+         if cls._instance is None:
+             cls._instance = super(ProteinConfig, cls).__new__(cls)
+             cls._instance.reset_config()
+         return cls._instance
+
+     def validate_inputs(self, key: str, value: Any) -> None:
+         """
+         Validate the inputs for the configuration.
+
+         Args:
+             key (str): The key to validate.
+             value (Any): The value to validate.
+
+         Raises:
+             ValueError: If the value is invalid.
+             TypeError: If the value is of the wrong type.
+         """
+         if key == "ambiguous_aminoacid_behavior":
+             if value not in [
+                 "raise_error",
+                 "standardize_deterministic",
+                 "standardize_random",
+             ]:
+                 raise ValueError(
+                     f"Invalid value for ambiguous_aminoacid_behavior: {value}."
+                 )
+         elif key == "ambiguous_aminoacid_map_override":
+             if not isinstance(value, dict):
+                 raise TypeError(
+                     f"Invalid type for ambiguous_aminoacid_map_override: {value}."
+                 )
+             for ambiguous_aminoacid, aminoacids in value.items():
+                 if not isinstance(aminoacids, list):
+                     raise TypeError(f"Invalid type for aminoacids: {aminoacids}.")
+                 if not aminoacids:
+                     raise ValueError(
+                         f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list."
+                     )
+                 if ambiguous_aminoacid not in AMBIGUOUS_AMINOACID_MAP:
+                     raise ValueError(
+                         f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}"
+                     )
+         else:
+             raise ValueError(f"Invalid configuration key: {key}")
+
+     def reset_config(self) -> None:
+         """
+         Reset the configuration to the default values.
+         """
+         self._config = {
+             "ambiguous_aminoacid_behavior": "standardize_random",
+             "ambiguous_aminoacid_map_override": {},
+         }
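Because `ProteinConfig` is a singleton, settings apply process-wide; used as a context manager, `__exit__` calls `reset_config()` so the defaults are restored afterwards. A usage sketch:

```python
# Sketch: temporarily forcing deterministic handling of ambiguous residues.
with ProteinConfig() as config:
    config.update({
        "ambiguous_aminoacid_behavior": "standardize_deterministic",
        "ambiguous_aminoacid_map_override": {"X": ["M"]},
    })
    # ... run optimization here with the temporary settings ...
# on exit, reset_config() has restored the defaults
```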
+
+
+ def load_python_object_from_disk(file_path: str) -> Any:
+     """
+     Load a Pickle object from disk and return it as a Python object.
+
+     Args:
+         file_path (str): The path to the Pickle file.
+
+     Returns:
+         Any: The loaded Python object.
+     """
+     with open(file_path, "rb") as file:
+         return pickle.load(file)
+
+
+ def save_python_object_to_disk(input_object: Any, file_path: str) -> None:
+     """
+     Save a Python object to disk using Pickle.
+
+     Args:
+         input_object (Any): The Python object to save.
+         file_path (str): The path where the object will be saved.
+     """
+     with open(file_path, "wb") as file:
+         pickle.dump(input_object, file)
+
+
+ def find_pattern_in_fasta(keyword: str, text: str) -> str:
+     """
+     Find a specific keyword pattern in text. Helpful for identifying parts
+     of a FASTA sequence.
+
+     Args:
+         keyword (str): The keyword pattern to search for.
+         text (str): The text to search within.
+
+     Returns:
+         str: The found pattern or an empty string if not found.
+     """
+     # Search for the keyword pattern in the text using regex
+     result = re.search(keyword + r"=(.*?)]", text)
+     return result.group(1) if result else ""
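The regex targets `[key=value]`-style attributes as they appear in NCBI CDS FASTA headers. A small self-contained sketch with an illustrative header:

```python
# Sketch: pulling bracketed attributes out of an NCBI-style FASTA description.
import re

def find_pattern_in_fasta(keyword: str, text: str) -> str:
    result = re.search(keyword + r"=(.*?)]", text)
    return result.group(1) if result else ""

header = ">lcl|NC_000913.3_cds_1 [gene=thrA] [protein_id=NP_414543.1]"
print(find_pattern_in_fasta("gene", header))        # thrA
print(find_pattern_in_fasta("protein_id", header))  # NP_414543.1
```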
+
+
+ def get_organism2id_dict(organism_reference: str) -> Dict[str, int]:
+     """
+     Return a dictionary mapping each organism in training data to an index
+     used for training.
+
+     Args:
+         organism_reference (str): Path to a CSV file containing a list of
+             all organisms. The format of the CSV file should be as follows:
+
+             0,Escherichia coli
+             1,Homo sapiens
+             2,Mus musculus
+
+     Returns:
+         Dict[str, int]: Dictionary mapping organism names to their respective indices.
+     """
+     # Read the CSV file and create a dictionary mapping organisms to their indices
+     organisms = pd.read_csv(organism_reference, index_col=0, header=None)
+     organism2id = {organisms.iloc[i].values[0]: i for i in organisms.index}
+
+     return organism2id
+
+
+ def get_taxonomy_id(
+     taxonomy_reference: str, organism: Optional[str] = None, return_dict: bool = False
+ ) -> Any:
+     """
+     Return the taxonomy id of a given organism using a reference file.
+     Optionally, return the whole dictionary instead if return_dict is True.
+
+     Args:
+         taxonomy_reference (str): Path to the taxonomy reference file.
+         organism (Optional[str]): The name of the organism to look up.
+         return_dict (bool): Whether to return the entire dictionary.
+
+     Returns:
+         Any: The taxonomy id of the organism or the entire dictionary.
+     """
+     # Load the organism-to-taxonomy mapping from a Pickle file
+     organism2taxonomy = load_python_object_from_disk(taxonomy_reference)
+
+     if return_dict:
+         return dict(sorted(organism2taxonomy.items()))
+
+     return organism2taxonomy[organism]
+
+
+ def sort_amino2codon_skeleton(amino2codon: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Sort the amino2codon dictionary alphabetically by amino acid and by codon name.
+
+     Args:
+         amino2codon (Dict[str, Any]): The amino2codon dictionary to sort.
+
+     Returns:
+         Dict[str, Any]: The sorted amino2codon dictionary.
+     """
+     # Sort the dictionary by amino acid and then by codon name
+     amino2codon = dict(sorted(amino2codon.items()))
+     amino2codon = {
+         amino: (
+             [codon for codon, _ in sorted(zip(codons, frequencies))],
+             [freq for _, freq in sorted(zip(codons, frequencies))],
+         )
+         for amino, (codons, frequencies) in amino2codon.items()
+     }
+
+     return amino2codon
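Both list comprehensions sort `zip(codons, frequencies)` by the same key, so each frequency stays attached to its codon after reordering. A quick sketch of that co-sort:

```python
# Sketch: co-sorting codons and frequencies as sort_amino2codon_skeleton does.
codons, frequencies = ["TTG", "CTA", "CTG"], [0.1, 0.2, 0.7]
sorted_codons = [c for c, _ in sorted(zip(codons, frequencies))]
sorted_freqs = [f for _, f in sorted(zip(codons, frequencies))]
print(sorted_codons)  # ['CTA', 'CTG', 'TTG']
print(sorted_freqs)   # [0.2, 0.7, 0.1] — each frequency still matches its codon
```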
+
+
+ def load_pkl_from_url(url: str) -> Any:
+     """
+     Download a Pickle file from a URL and return the loaded object.
+
+     Args:
+         url (str): The URL to download the Pickle file from.
+
+     Returns:
+         Any: The loaded Python object from the Pickle file.
+     """
+     response = requests.get(url)
+     response.raise_for_status()  # Ensure the request was successful
+
+     # Load the Pickle object from the response content
+     return pickle.loads(response.content)
CodonTransformer/__init__.py ADDED
@@ -0,0 +1 @@
+ """CodonTransformer package."""