saketh11 commited on
Commit
e67a8f9
·
1 Parent(s): c5352e6

Enhance model loading and error handling in app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -126,7 +126,13 @@ def load_model_and_tokenizer():
126
  status_text.text(f"⚠️ Failed to load from Hugging Face: {str(e)[:50]}...")
127
  status_text.text("Loading base model as fallback...")
128
  st.session_state.model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
129
- st.session_state.model = st.session_state.model.to(st.session_state.device)
 
 
 
 
 
 
130
  st.session_state.model_type = "base"
131
 
132
  progress_bar.progress(100)
@@ -296,7 +302,8 @@ def calculate_input_metrics(sequence: str, organism: str, sequence_type: str) ->
296
  metrics['tai'] = None
297
  try:
298
  analysis_dna = metrics['baseline_dna']
299
- metrics['restriction_sites'] = len(scan_for_restriction_sites(analysis_dna))
 
300
  metrics['negative_cis_elements'] = count_negative_cis_elements(analysis_dna)
301
  metrics['homopolymer_runs'] = calculate_homopolymer_runs(analysis_dna)
302
  except:
@@ -608,9 +615,9 @@ def run_optimization(protein: str, organism: str, use_post_processing: bool = Fa
608
  # Create enhanced result object
609
  from CodonTransformer.CodonUtils import DNASequencePrediction
610
  st.session_state.post_processed_results = DNASequencePrediction(
611
- organism=result.organism,
612
- protein=result.protein,
613
- processed_input=result.processed_input,
614
  predicted_dna=polished_sequence
615
  )
616
  except Exception as e:
@@ -706,7 +713,7 @@ def single_sequence_optimization():
706
  st.header("🧬 Input Sequence")
707
  sequence_input = st.text_area(
708
  "Enter Protein or DNA Sequence",
709
- height=150,
710
  placeholder="Enter protein sequence (MKWVT...) or DNA sequence (ATGGCG...)\n\nExample protein: MKWVTFISLLFLFSSAYSRGVFRRDAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTE"
711
  )
712
  analyze_btn = st.button("Analyze Sequence", type="primary")
@@ -809,10 +816,10 @@ def single_sequence_optimization():
809
  if st.button("🚀 Optimize Sequence", type="primary", use_container_width=True):
810
  st.session_state.results = None
811
  if st.session_state.sequence_type == "dna":
812
- protein_sequence = translate_dna_to_protein(st.session_state.sequence_clean)
813
- run_optimization(protein_sequence, st.session_state.organism, use_post_processing)
814
  else:
815
- run_optimization(st.session_state.sequence_clean, st.session_state.organism, use_post_processing)
816
 
817
  # Enhanced progress display
818
  if st.session_state.optimization_running:
@@ -906,7 +913,7 @@ def display_optimization_results(result, organism, original_sequence, sequence_t
906
 
907
  # Optimized DNA sequence display
908
  st.subheader("🧬 Optimized DNA Sequence")
909
- st.text_area("Optimized DNA Sequence", result.predicted_dna, height=100)
910
 
911
  # Enhanced download and export options
912
  col1, col2, col3 = st.columns(3)
 
126
  status_text.text(f"⚠️ Failed to load from Hugging Face: {str(e)[:50]}...")
127
  status_text.text("Loading base model as fallback...")
128
  st.session_state.model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
129
+ if isinstance(st.session_state.model, torch.nn.Module):
130
+ if isinstance(st.session_state.model, torch.nn.Module):
131
+ st.session_state.model.to(st.session_state.device)
132
+ else:
133
+ st.warning("Fallback model loaded is not a PyTorch module. Cannot move to device.")
134
+ else:
135
+ st.warning("Fallback model loaded is not a PyTorch module. Cannot move to device.")
136
  st.session_state.model_type = "base"
137
 
138
  progress_bar.progress(100)
 
302
  metrics['tai'] = None
303
  try:
304
  analysis_dna = metrics['baseline_dna']
305
+ # scan_for_restriction_sites returns an int, not a list, so no need for len()
306
+ metrics['restriction_sites'] = scan_for_restriction_sites(analysis_dna)
307
  metrics['negative_cis_elements'] = count_negative_cis_elements(analysis_dna)
308
  metrics['homopolymer_runs'] = calculate_homopolymer_runs(analysis_dna)
309
  except:
 
615
  # Create enhanced result object
616
  from CodonTransformer.CodonUtils import DNASequencePrediction
617
  st.session_state.post_processed_results = DNASequencePrediction(
618
+ organism=_res.organism,
619
+ protein=_res.protein,
620
+ processed_input=_res.processed_input,
621
  predicted_dna=polished_sequence
622
  )
623
  except Exception as e:
 
713
  st.header("🧬 Input Sequence")
714
  sequence_input = st.text_area(
715
  "Enter Protein or DNA Sequence",
716
+ height=300,
717
  placeholder="Enter protein sequence (MKWVT...) or DNA sequence (ATGGCG...)\n\nExample protein: MKWVTFISLLFLFSSAYSRGVFRRDAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTE"
718
  )
719
  analyze_btn = st.button("Analyze Sequence", type="primary")
 
816
  if st.button("🚀 Optimize Sequence", type="primary", use_container_width=True):
817
  st.session_state.results = None
818
  if st.session_state.sequence_type == "dna":
819
+ protein_sequence = translate_dna_to_protein(str(st.session_state.sequence_clean))
820
+ run_optimization(protein_sequence, str(st.session_state.organism), use_post_processing)
821
  else:
822
+ run_optimization(str(st.session_state.sequence_clean), str(st.session_state.organism), use_post_processing)
823
 
824
  # Enhanced progress display
825
  if st.session_state.optimization_running:
 
913
 
914
  # Optimized DNA sequence display
915
  st.subheader("🧬 Optimized DNA Sequence")
916
+ st.text_area("Optimized DNA Sequence", result.predicted_dna, height=300)
917
 
918
  # Enhanced download and export options
919
  col1, col2, col3 = st.columns(3)