MoraxCheng committed
Commit e809d91 · Parent: 9957ccd

Fix resource management and memory leaks


- Add try-finally blocks for proper model cleanup even on errors (see the sketch after this list)
- Fix matplotlib memory leak with proper figure cleanup (close, clf, cla)
- Limit figure size to prevent excessive memory usage
- Remove unused imports (tempfile, atexit, threading)
- Remove unused global variables (active_inferences, inference_lock)
- Add better error handling for file cleanup operations
- Add error handling for initial repository setup
- Ensure model is always deleted from memory after inference
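
The commit boils down to two defensive patterns: wrap inference in try/finally so the model is released on every exit path, and close matplotlib figures in a finally block while capping their size. The sketch below is a minimal, self-contained illustration of both patterns, not the actual app.py code; the small linear model and random matrix are stand-ins for the Tranception checkpoint and the scoring grid.

# Minimal sketch of the two cleanup patterns (stand-in model and data, not app.py code)
import gc
import torch
import matplotlib
matplotlib.use("Agg")  # headless backend, as on a Space
import matplotlib.pyplot as plt

def score_with_cleanup(inputs):
    model = torch.nn.Linear(8, 1)  # stand-in for loading the Tranception checkpoint
    try:
        model.eval()
        with torch.no_grad():
            return model(inputs)  # stand-in for model.score_mutants(...)
    finally:
        # Runs on success and on error: drop the reference, then reclaim memory
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def save_heatmap(matrix, path):
    # Cap the figure size so a long mutation range cannot allocate a huge canvas
    fig, ax = plt.subplots(figsize=(min(50, matrix.shape[1] * 0.8), min(matrix.shape[0], 50)))
    try:
        ax.imshow(matrix, aspect="auto")
        plt.tight_layout()
        fig.savefig(path, dpi=100)
        return path
    finally:
        plt.close(fig)  # the commit closes all figures and additionally calls plt.clf()/plt.cla()

if __name__ == "__main__":
    print(score_with_cleanup(torch.zeros(4, 8)).shape)
    print(save_heatmap(torch.rand(20, 20).numpy(), "demo_heatmap.png"))

In app.py the same structure appears in score_and_create_matrix_all_singles and create_scoring_matrix_visual, as shown in the diff below.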

Files changed (1)
  1. app.py +86 -69
app.py CHANGED
@@ -16,9 +16,6 @@ from huggingface_hub import hf_hub_download
 import zipfile
 import shutil
 import uuid
-import tempfile
-import atexit
-import threading
 import gc
 
 # Add current directory to path
@@ -27,12 +24,20 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 # Check if we need to download and extract the tranception module
 if not os.path.exists("tranception"):
     print("Downloading Tranception repository...")
-    # Clone the repository structure
-    os.system("git clone https://github.com/OATML-Markslab/Tranception.git temp_tranception")
-    # Move the tranception module to current directory
-    shutil.move("temp_tranception/tranception", "tranception")
-    # Clean up
-    shutil.rmtree("temp_tranception")
+    try:
+        # Clone the repository structure
+        result = os.system("git clone https://github.com/OATML-Markslab/Tranception.git temp_tranception")
+        if result != 0:
+            raise Exception("Failed to clone Tranception repository")
+        # Move the tranception module to current directory
+        shutil.move("temp_tranception/tranception", "tranception")
+        # Clean up
+        shutil.rmtree("temp_tranception")
+    except Exception as e:
+        print(f"Error setting up Tranception: {e}")
+        if os.path.exists("temp_tranception"):
+            shutil.rmtree("temp_tranception")
+        raise
 
 import tranception
 from tranception import config, model_pytorch
@@ -110,7 +115,10 @@ def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_st
 
     # Continue with visualization
     mutation_range_len = mutation_range_end - mutation_range_start + 1
-    fig, ax = plt.subplots(figsize=(50,mutation_range_len))
+    # Limit figure size to prevent memory issues
+    fig_width = min(50, len(AA_vocab) * 0.8)
+    fig_height = min(mutation_range_len, 50)
+    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
     scores_dict = {}
     valid_mutant_set=set(filtered_scores.mutant)
     ax.tick_params(bottom=True, top=True, left=True, right=True)
@@ -144,11 +152,15 @@ def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_st
 
     # Set x-axis labels (amino acids) - ensuring correct number
     heat.set_xticklabels(list(AA_vocab), fontsize=fontsize)
-    plt.tight_layout()
-    image_path = 'fitness_scoring_substitution_matrix_{}_{}.png'.format(unique_id, image_index)
-    plt.savefig(image_path,dpi=100)
-    plt.close()
-    return image_path, csv_path
+    try:
+        plt.tight_layout()
+        image_path = 'fitness_scoring_substitution_matrix_{}_{}.png'.format(unique_id, image_index)
+        plt.savefig(image_path,dpi=100)
+        return image_path, csv_path
+    finally:
+        plt.close('all')  # Ensure all figures are closed
+        plt.clf()  # Clear the current figure
+        plt.cla()  # Clear the current axes
 
 def suggest_mutations(scores):
     intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
@@ -181,10 +193,6 @@ def check_valid_mutant(sequence,mutant,AA_vocab=AA_vocab):
         if to_AA not in AA_vocab: valid=False
     return valid
 
-# Global variable to track active inference threads
-active_inferences = {}
-inference_lock = threading.Lock()
-
 def cleanup_old_files(max_age_minutes=30):
     """Clean up old inference files"""
     import glob
@@ -194,14 +202,20 @@ def cleanup_old_files(max_age_minutes=30):
                 "fitness_scoring_substitution_matrix_*.csv",
                 "all_mutations_fitness_scores_*.csv"]
 
+    cleaned_count = 0
     for pattern in patterns:
        for file_path in glob.glob(pattern):
             try:
                 file_age = current_time - os.path.getmtime(file_path)
                 if file_age > max_age_minutes * 60:
                     os.remove(file_path)
-            except:
-                pass
+                    cleaned_count += 1
+            except Exception as e:
+                # Log error but continue cleaning other files
+                print(f"Warning: Could not remove {file_path}: {e}")
+
+    if cleaned_count > 0:
+        print(f"Cleaned up {cleaned_count} old files")
 
 def get_mutated_protein(sequence,mutant):
     if not check_valid_mutant(sequence,mutant):
@@ -257,55 +271,58 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
     # Reduce batch size for CPU inference
     batch_size_inference = min(batch_size_inference, 10)
 
-    model.eval()
-    model.config.tokenizer = tokenizer
-
-    all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
-
-    with torch.no_grad():
-        scores = model.score_mutants(DMS_data=all_single_mutants,
-                                     target_seq=sequence,
-                                     scoring_mirror=scoring_mirror,
-                                     batch_size_inference=batch_size_inference,
-                                     num_workers=num_workers,
-                                     indel_mode=False
-                                     )
-
-    scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
-    scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
-    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
-
-    score_heatmaps = []
-    csv_files = []
-    mutation_range = mutation_range_end - mutation_range_start + 1
-    number_heatmaps = int((mutation_range - 1) / max_number_positions_per_heatmap) + 1
-    image_index = 0
-    window_start = mutation_range_start
-    window_end = min(mutation_range_end,mutation_range_start+max_number_positions_per_heatmap-1)
-
-    for image_index in range(number_heatmaps):
-        image_path, csv_path = create_scoring_matrix_visual(scores,sequence,image_index,window_start,window_end,AA_vocab,unique_id=unique_id)
-        score_heatmaps.append(image_path)
-        csv_files.append(csv_path)
-        window_start += max_number_positions_per_heatmap
-        window_end = min(mutation_range_end,window_start+max_number_positions_per_heatmap-1)
-
-    # Also save a comprehensive CSV with all mutations
-    comprehensive_csv_path = 'all_mutations_fitness_scores_{}.csv'.format(unique_id)
-    scores_export = scores[['mutant', 'position', 'target_AA', 'avg_score', 'mutated_sequence']].copy()
-    scores_export['original_AA'] = scores_export['mutant'].str[0]
-    scores_export = scores_export.rename(columns={'avg_score': 'fitness_score'})
-    scores_export = scores_export[['position', 'original_AA', 'target_AA', 'mutant', 'fitness_score', 'mutated_sequence']]
-    scores_export.to_csv(comprehensive_csv_path, index=False)
-    csv_files.append(comprehensive_csv_path)
-
-    # Clean up model from memory after inference
-    del model
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    try:
+        model.eval()
+        model.config.tokenizer = tokenizer
+
+        all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
+
+        with torch.no_grad():
+            scores = model.score_mutants(DMS_data=all_single_mutants,
+                                         target_seq=sequence,
+                                         scoring_mirror=scoring_mirror,
+                                         batch_size_inference=batch_size_inference,
+                                         num_workers=num_workers,
+                                         indel_mode=False
+                                         )
+
+        scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
+        scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
+        scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
+
+        score_heatmaps = []
+        csv_files = []
+        mutation_range = mutation_range_end - mutation_range_start + 1
+        number_heatmaps = int((mutation_range - 1) / max_number_positions_per_heatmap) + 1
+        image_index = 0
+        window_start = mutation_range_start
+        window_end = min(mutation_range_end,mutation_range_start+max_number_positions_per_heatmap-1)
+
+        for image_index in range(number_heatmaps):
+            image_path, csv_path = create_scoring_matrix_visual(scores,sequence,image_index,window_start,window_end,AA_vocab,unique_id=unique_id)
+            score_heatmaps.append(image_path)
+            csv_files.append(csv_path)
+            window_start += max_number_positions_per_heatmap
+            window_end = min(mutation_range_end,window_start+max_number_positions_per_heatmap-1)
+
+        # Also save a comprehensive CSV with all mutations
+        comprehensive_csv_path = 'all_mutations_fitness_scores_{}.csv'.format(unique_id)
+        scores_export = scores[['mutant', 'position', 'target_AA', 'avg_score', 'mutated_sequence']].copy()
+        scores_export['original_AA'] = scores_export['mutant'].str[0]
+        scores_export = scores_export.rename(columns={'avg_score': 'fitness_score'})
+        scores_export = scores_export[['position', 'original_AA', 'target_AA', 'mutant', 'fitness_score', 'mutated_sequence']]
+        scores_export.to_csv(comprehensive_csv_path, index=False)
+        csv_files.append(comprehensive_csv_path)
+
+        return score_heatmaps, suggest_mutations(scores), csv_files
 
-    return score_heatmaps, suggest_mutations(scores), csv_files
+    finally:
+        # Always clean up model from memory
+        if 'model' in locals():
+            del model
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
 def extract_sequence(protein_id, taxon, sequence):
     return sequence
 
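A quick way to confirm the fix is to watch process memory around repeated scoring calls; with the finally blocks in place, usage should return to roughly its pre-call level even when a call raises. The helper below is not part of the commit, only a hedged sketch that assumes psutil is available in the Space; it can wrap any app.py call, for example score_and_create_matrix_all_singles.

# Hypothetical verification helper (not in this commit): report process and CUDA memory around a call
import gc
import psutil
import torch

def rss_mb():
    # Resident set size of the current process, in megabytes
    return psutil.Process().memory_info().rss / 1e6

def check_memory(fn, *args, **kwargs):
    gc.collect()
    before = rss_mb()
    try:
        return fn(*args, **kwargs)
    finally:
        gc.collect()
        print(f"RSS {before:.0f} MB -> {rss_mb():.0f} MB")
        if torch.cuda.is_available():
            print(f"CUDA allocated: {torch.cuda.memory_allocated()} bytes")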