saketh11 committed on
Commit
6049bfc
·
1 Parent(s): e67a8f9

Refactor model and data downloading in app.py to include Hugging Face token support

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -102,18 +102,16 @@ def load_model_and_tokenizer():
102
 
103
  status_text.text("Loading fine-tuned model from Hugging Face...")
104
  progress_bar.progress(50)
105
- # Try to download and load fine-tuned model from Hugging Face
106
  try:
107
- # Download the checkpoint file from Hugging Face
108
  from huggingface_hub import hf_hub_download
109
-
110
  status_text.text("⬇️ Downloading model from saketh11/ColiFormer...")
111
  model_path = hf_hub_download(
112
  repo_id="saketh11/ColiFormer",
113
  filename="balanced_alm_finetune.ckpt",
114
- cache_dir="./hf_cache"
 
115
  )
116
-
117
  status_text.text("🔄 Loading downloaded model...")
118
  st.session_state.model = load_model(
119
  model_path=model_path,
@@ -127,10 +125,7 @@ def load_model_and_tokenizer():
127
  status_text.text("Loading base model as fallback...")
128
  st.session_state.model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
129
  if isinstance(st.session_state.model, torch.nn.Module):
130
- if isinstance(st.session_state.model, torch.nn.Module):
131
- st.session_state.model.to(st.session_state.device)
132
- else:
133
- st.warning("Fallback model loaded is not a PyTorch module. Cannot move to device.")
134
  else:
135
  st.warning("Fallback model loaded is not a PyTorch module. Cannot move to device.")
136
  st.session_state.model_type = "base"
@@ -145,17 +140,18 @@ def load_model_and_tokenizer():
145
  def download_reference_data():
146
  """Download and cache reference data from Hugging Face"""
147
  try:
148
- # Download the processed genes file from Hugging Face
 
149
  file_path = hf_hub_download(
150
  repo_id="saketh11/ColiFormer-Data",
151
  filename="ecoli_processed_genes.csv",
152
- repo_type="dataset"
 
153
  )
154
  df = pd.read_csv(file_path)
155
  return df['dna_sequence'].tolist()
156
  except Exception as e:
157
  st.warning(f"Could not download reference data from Hugging Face: {e}")
158
- # Fallback to minimal sequences
159
  return [
160
  "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC",
161
  "ATGAAATTTATTTATTATTATAAATTTATTTATTATTATAAATTTATTTAT",
@@ -166,11 +162,13 @@ def download_reference_data():
166
  def download_tai_weights():
167
  """Download and cache tAI weights from Hugging Face"""
168
  try:
169
- # Download the tAI weights file from Hugging Face
 
170
  file_path = hf_hub_download(
171
  repo_id="saketh11/ColiFormer-Data",
172
  filename="organism_tai_weights.json",
173
- repo_type="dataset"
 
174
  )
175
  with open(file_path, 'r') as f:
176
  all_weights = json.load(f)
 
102
 
103
  status_text.text("Loading fine-tuned model from Hugging Face...")
104
  progress_bar.progress(50)
 
105
  try:
 
106
  from huggingface_hub import hf_hub_download
107
+ hf_token = os.environ.get("HF_TOKEN")
108
  status_text.text("⬇️ Downloading model from saketh11/ColiFormer...")
109
  model_path = hf_hub_download(
110
  repo_id="saketh11/ColiFormer",
111
  filename="balanced_alm_finetune.ckpt",
112
+ cache_dir="./hf_cache",
113
+ token=hf_token
114
  )
 
115
  status_text.text("🔄 Loading downloaded model...")
116
  st.session_state.model = load_model(
117
  model_path=model_path,
 
125
  status_text.text("Loading base model as fallback...")
126
  st.session_state.model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
127
  if isinstance(st.session_state.model, torch.nn.Module):
128
+ st.session_state.model = st.session_state.model.to(st.session_state.device)
 
 
 
129
  else:
130
  st.warning("Fallback model loaded is not a PyTorch module. Cannot move to device.")
131
  st.session_state.model_type = "base"
 
140
  def download_reference_data():
141
  """Download and cache reference data from Hugging Face"""
142
  try:
143
+ from huggingface_hub import hf_hub_download
144
+ hf_token = os.environ.get("HF_TOKEN")
145
  file_path = hf_hub_download(
146
  repo_id="saketh11/ColiFormer-Data",
147
  filename="ecoli_processed_genes.csv",
148
+ repo_type="dataset",
149
+ token=hf_token
150
  )
151
  df = pd.read_csv(file_path)
152
  return df['dna_sequence'].tolist()
153
  except Exception as e:
154
  st.warning(f"Could not download reference data from Hugging Face: {e}")
 
155
  return [
156
  "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC",
157
  "ATGAAATTTATTTATTATTATAAATTTATTTATTATTATAAATTTATTTAT",
 
162
  def download_tai_weights():
163
  """Download and cache tAI weights from Hugging Face"""
164
  try:
165
+ from huggingface_hub import hf_hub_download
166
+ hf_token = os.environ.get("HF_TOKEN")
167
  file_path = hf_hub_download(
168
  repo_id="saketh11/ColiFormer-Data",
169
  filename="organism_tai_weights.json",
170
+ repo_type="dataset",
171
+ token=hf_token
172
  )
173
  with open(file_path, 'r') as f:
174
  all_weights = json.load(f)