MoraxCheng commited on
Commit
ac2c54a
·
1 Parent(s): 4990b34

Add resource management and cancellation support

Browse files

- Add automatic cleanup of old files (30 minutes)
- Free model memory after inference with gc.collect()
- Clear CUDA cache if available
- Configure queue with size limits and status updates
- Limit concurrent threads to 2 for better resource management
- Add proper error handling for file operations
- Prevent API access to avoid external requests

Files changed (1) hide show
  1. app.py +49 -2
app.py CHANGED
@@ -16,6 +16,10 @@ from huggingface_hub import hf_hub_download
16
  import zipfile
17
  import shutil
18
  import uuid
 
 
 
 
19
 
20
  # Add current directory to path
21
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -187,6 +191,28 @@ def check_valid_mutant(sequence,mutant,AA_vocab=AA_vocab):
187
  if to_AA not in AA_vocab: valid=False
188
  return valid
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  def get_mutated_protein(sequence,mutant):
191
  if not check_valid_mutant(sequence,mutant):
192
  return "The mutant is not valid"
@@ -195,6 +221,9 @@ def get_mutated_protein(sequence,mutant):
195
  return ''.join(mutated_sequence)
196
 
197
  def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Large",scoring_mirror=False,batch_size_inference=20,max_number_positions_per_heatmap=50,num_workers=0,AA_vocab=AA_vocab):
 
 
 
198
  # Generate unique ID for this request
199
  unique_id = str(uuid.uuid4())
200
 
@@ -275,6 +304,12 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
275
  scores_export = scores_export[['position', 'original_AA', 'target_AA', 'mutant', 'fitness_score', 'mutated_sequence']]
276
  scores_export.to_csv(comprehensive_csv_path, index=False)
277
  csv_files.append(comprehensive_csv_path)
 
 
 
 
 
 
278
 
279
  return score_heatmaps, suggest_mutations(scores), csv_files
280
 
@@ -383,5 +418,17 @@ with tranception_design:
383
  gr.Markdown("Links: <a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Paper</a> <a href='https://github.com/OATML-Markslab/Tranception' target='_blank'>Code</a> <a href='https://sites.google.com/view/proteingym/substitutions' target='_blank'>ProteinGym</a>")
384
 
385
  if __name__ == "__main__":
386
- tranception_design.queue()
387
- tranception_design.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  import zipfile
17
  import shutil
18
  import uuid
19
+ import tempfile
20
+ import atexit
21
+ import threading
22
+ import gc
23
 
24
  # Add current directory to path
25
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
191
  if to_AA not in AA_vocab: valid=False
192
  return valid
193
 
194
+ # Global variable to track active inference threads
195
+ active_inferences = {}
196
+ inference_lock = threading.Lock()
197
+
198
+ def cleanup_old_files(max_age_minutes=30):
199
+ """Clean up old inference files"""
200
+ import glob
201
+ import time
202
+ current_time = time.time()
203
+ patterns = ["fitness_scoring_substitution_matrix_*.png",
204
+ "fitness_scoring_substitution_matrix_*.csv",
205
+ "all_mutations_fitness_scores_*.csv"]
206
+
207
+ for pattern in patterns:
208
+ for file_path in glob.glob(pattern):
209
+ try:
210
+ file_age = current_time - os.path.getmtime(file_path)
211
+ if file_age > max_age_minutes * 60:
212
+ os.remove(file_path)
213
+ except:
214
+ pass
215
+
216
  def get_mutated_protein(sequence,mutant):
217
  if not check_valid_mutant(sequence,mutant):
218
  return "The mutant is not valid"
 
221
  return ''.join(mutated_sequence)
222
 
223
  def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Large",scoring_mirror=False,batch_size_inference=20,max_number_positions_per_heatmap=50,num_workers=0,AA_vocab=AA_vocab):
224
+ # Clean up old files periodically
225
+ cleanup_old_files()
226
+
227
  # Generate unique ID for this request
228
  unique_id = str(uuid.uuid4())
229
 
 
304
  scores_export = scores_export[['position', 'original_AA', 'target_AA', 'mutant', 'fitness_score', 'mutated_sequence']]
305
  scores_export.to_csv(comprehensive_csv_path, index=False)
306
  csv_files.append(comprehensive_csv_path)
307
+
308
+ # Clean up model from memory after inference
309
+ del model
310
+ gc.collect()
311
+ if torch.cuda.is_available():
312
+ torch.cuda.empty_cache()
313
 
314
  return score_heatmaps, suggest_mutations(scores), csv_files
315
 
 
418
  gr.Markdown("Links: <a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Paper</a> <a href='https://github.com/OATML-Markslab/Tranception' target='_blank'>Code</a> <a href='https://sites.google.com/view/proteingym/substitutions' target='_blank'>ProteinGym</a>")
419
 
420
  if __name__ == "__main__":
421
+ # Configure queue for better resource management
422
+ tranception_design.queue(
423
+ max_size=10, # Limit queue size
424
+ status_update_rate="auto", # Show status updates
425
+ api_open=False # Disable API to prevent external requests
426
+ )
427
+
428
+ # Launch with appropriate settings for HF Spaces
429
+ tranception_design.launch(
430
+ max_threads=2, # Limit concurrent threads
431
+ show_error=True,
432
+ server_name="0.0.0.0",
433
+ server_port=7860
434
+ )