Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
e809d91
1
Parent(s):
9957ccd
Fix resource management and memory leaks
Browse files- Add try-finally blocks for proper model cleanup even on errors
- Fix matplotlib memory leak with proper figure cleanup (close, clf, cla)
- Limit figure size to prevent excessive memory usage
- Remove unused imports (tempfile, atexit, threading)
- Remove unused global variables (active_inferences, inference_lock)
- Add better error handling for file cleanup operations
- Add error handling for initial repository setup
- Ensure model is always deleted from memory after inference
app.py
CHANGED
@@ -16,9 +16,6 @@ from huggingface_hub import hf_hub_download
|
|
16 |
import zipfile
|
17 |
import shutil
|
18 |
import uuid
|
19 |
-
import tempfile
|
20 |
-
import atexit
|
21 |
-
import threading
|
22 |
import gc
|
23 |
|
24 |
# Add current directory to path
|
@@ -27,12 +24,20 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
27 |
# Check if we need to download and extract the tranception module
|
28 |
if not os.path.exists("tranception"):
|
29 |
print("Downloading Tranception repository...")
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
import tranception
|
38 |
from tranception import config, model_pytorch
|
@@ -110,7 +115,10 @@ def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_st
|
|
110 |
|
111 |
# Continue with visualization
|
112 |
mutation_range_len = mutation_range_end - mutation_range_start + 1
|
113 |
-
|
|
|
|
|
|
|
114 |
scores_dict = {}
|
115 |
valid_mutant_set=set(filtered_scores.mutant)
|
116 |
ax.tick_params(bottom=True, top=True, left=True, right=True)
|
@@ -144,11 +152,15 @@ def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_st
|
|
144 |
|
145 |
# Set x-axis labels (amino acids) - ensuring correct number
|
146 |
heat.set_xticklabels(list(AA_vocab), fontsize=fontsize)
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
152 |
|
153 |
def suggest_mutations(scores):
|
154 |
intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
|
@@ -181,10 +193,6 @@ def check_valid_mutant(sequence,mutant,AA_vocab=AA_vocab):
|
|
181 |
if to_AA not in AA_vocab: valid=False
|
182 |
return valid
|
183 |
|
184 |
-
# Global variable to track active inference threads
|
185 |
-
active_inferences = {}
|
186 |
-
inference_lock = threading.Lock()
|
187 |
-
|
188 |
def cleanup_old_files(max_age_minutes=30):
|
189 |
"""Clean up old inference files"""
|
190 |
import glob
|
@@ -194,14 +202,20 @@ def cleanup_old_files(max_age_minutes=30):
|
|
194 |
"fitness_scoring_substitution_matrix_*.csv",
|
195 |
"all_mutations_fitness_scores_*.csv"]
|
196 |
|
|
|
197 |
for pattern in patterns:
|
198 |
for file_path in glob.glob(pattern):
|
199 |
try:
|
200 |
file_age = current_time - os.path.getmtime(file_path)
|
201 |
if file_age > max_age_minutes * 60:
|
202 |
os.remove(file_path)
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
def get_mutated_protein(sequence,mutant):
|
207 |
if not check_valid_mutant(sequence,mutant):
|
@@ -257,55 +271,58 @@ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutat
|
|
257 |
# Reduce batch size for CPU inference
|
258 |
batch_size_inference = min(batch_size_inference, 10)
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
gc.collect()
|
305 |
-
if torch.cuda.is_available():
|
306 |
-
torch.cuda.empty_cache()
|
307 |
|
308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
def extract_sequence(protein_id, taxon, sequence):
|
311 |
return sequence
|
|
|
16 |
import zipfile
|
17 |
import shutil
|
18 |
import uuid
|
|
|
|
|
|
|
19 |
import gc
|
20 |
|
21 |
# Add current directory to path
|
|
|
24 |
# Check if we need to download and extract the tranception module
|
25 |
if not os.path.exists("tranception"):
|
26 |
print("Downloading Tranception repository...")
|
27 |
+
try:
|
28 |
+
# Clone the repository structure
|
29 |
+
result = os.system("git clone https://github.com/OATML-Markslab/Tranception.git temp_tranception")
|
30 |
+
if result != 0:
|
31 |
+
raise Exception("Failed to clone Tranception repository")
|
32 |
+
# Move the tranception module to current directory
|
33 |
+
shutil.move("temp_tranception/tranception", "tranception")
|
34 |
+
# Clean up
|
35 |
+
shutil.rmtree("temp_tranception")
|
36 |
+
except Exception as e:
|
37 |
+
print(f"Error setting up Tranception: {e}")
|
38 |
+
if os.path.exists("temp_tranception"):
|
39 |
+
shutil.rmtree("temp_tranception")
|
40 |
+
raise
|
41 |
|
42 |
import tranception
|
43 |
from tranception import config, model_pytorch
|
|
|
115 |
|
116 |
# Continue with visualization
|
117 |
mutation_range_len = mutation_range_end - mutation_range_start + 1
|
118 |
+
# Limit figure size to prevent memory issues
|
119 |
+
fig_width = min(50, len(AA_vocab) * 0.8)
|
120 |
+
fig_height = min(mutation_range_len, 50)
|
121 |
+
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
|
122 |
scores_dict = {}
|
123 |
valid_mutant_set=set(filtered_scores.mutant)
|
124 |
ax.tick_params(bottom=True, top=True, left=True, right=True)
|
|
|
152 |
|
153 |
# Set x-axis labels (amino acids) - ensuring correct number
|
154 |
heat.set_xticklabels(list(AA_vocab), fontsize=fontsize)
|
155 |
+
try:
|
156 |
+
plt.tight_layout()
|
157 |
+
image_path = 'fitness_scoring_substitution_matrix_{}_{}.png'.format(unique_id, image_index)
|
158 |
+
plt.savefig(image_path,dpi=100)
|
159 |
+
return image_path, csv_path
|
160 |
+
finally:
|
161 |
+
plt.close('all') # Ensure all figures are closed
|
162 |
+
plt.clf() # Clear the current figure
|
163 |
+
plt.cla() # Clear the current axes
|
164 |
|
165 |
def suggest_mutations(scores):
|
166 |
intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
|
|
|
193 |
if to_AA not in AA_vocab: valid=False
|
194 |
return valid
|
195 |
|
|
|
|
|
|
|
|
|
196 |
def cleanup_old_files(max_age_minutes=30):
|
197 |
"""Clean up old inference files"""
|
198 |
import glob
|
|
|
202 |
"fitness_scoring_substitution_matrix_*.csv",
|
203 |
"all_mutations_fitness_scores_*.csv"]
|
204 |
|
205 |
+
cleaned_count = 0
|
206 |
for pattern in patterns:
|
207 |
for file_path in glob.glob(pattern):
|
208 |
try:
|
209 |
file_age = current_time - os.path.getmtime(file_path)
|
210 |
if file_age > max_age_minutes * 60:
|
211 |
os.remove(file_path)
|
212 |
+
cleaned_count += 1
|
213 |
+
except Exception as e:
|
214 |
+
# Log error but continue cleaning other files
|
215 |
+
print(f"Warning: Could not remove {file_path}: {e}")
|
216 |
+
|
217 |
+
if cleaned_count > 0:
|
218 |
+
print(f"Cleaned up {cleaned_count} old files")
|
219 |
|
220 |
def get_mutated_protein(sequence,mutant):
|
221 |
if not check_valid_mutant(sequence,mutant):
|
|
|
271 |
# Reduce batch size for CPU inference
|
272 |
batch_size_inference = min(batch_size_inference, 10)
|
273 |
|
274 |
+
try:
|
275 |
+
model.eval()
|
276 |
+
model.config.tokenizer = tokenizer
|
277 |
+
|
278 |
+
all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
|
279 |
+
|
280 |
+
with torch.no_grad():
|
281 |
+
scores = model.score_mutants(DMS_data=all_single_mutants,
|
282 |
+
target_seq=sequence,
|
283 |
+
scoring_mirror=scoring_mirror,
|
284 |
+
batch_size_inference=batch_size_inference,
|
285 |
+
num_workers=num_workers,
|
286 |
+
indel_mode=False
|
287 |
+
)
|
288 |
+
|
289 |
+
scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
|
290 |
+
scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
|
291 |
+
scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
|
292 |
+
|
293 |
+
score_heatmaps = []
|
294 |
+
csv_files = []
|
295 |
+
mutation_range = mutation_range_end - mutation_range_start + 1
|
296 |
+
number_heatmaps = int((mutation_range - 1) / max_number_positions_per_heatmap) + 1
|
297 |
+
image_index = 0
|
298 |
+
window_start = mutation_range_start
|
299 |
+
window_end = min(mutation_range_end,mutation_range_start+max_number_positions_per_heatmap-1)
|
300 |
+
|
301 |
+
for image_index in range(number_heatmaps):
|
302 |
+
image_path, csv_path = create_scoring_matrix_visual(scores,sequence,image_index,window_start,window_end,AA_vocab,unique_id=unique_id)
|
303 |
+
score_heatmaps.append(image_path)
|
304 |
+
csv_files.append(csv_path)
|
305 |
+
window_start += max_number_positions_per_heatmap
|
306 |
+
window_end = min(mutation_range_end,window_start+max_number_positions_per_heatmap-1)
|
307 |
+
|
308 |
+
# Also save a comprehensive CSV with all mutations
|
309 |
+
comprehensive_csv_path = 'all_mutations_fitness_scores_{}.csv'.format(unique_id)
|
310 |
+
scores_export = scores[['mutant', 'position', 'target_AA', 'avg_score', 'mutated_sequence']].copy()
|
311 |
+
scores_export['original_AA'] = scores_export['mutant'].str[0]
|
312 |
+
scores_export = scores_export.rename(columns={'avg_score': 'fitness_score'})
|
313 |
+
scores_export = scores_export[['position', 'original_AA', 'target_AA', 'mutant', 'fitness_score', 'mutated_sequence']]
|
314 |
+
scores_export.to_csv(comprehensive_csv_path, index=False)
|
315 |
+
csv_files.append(comprehensive_csv_path)
|
316 |
+
|
317 |
+
return score_heatmaps, suggest_mutations(scores), csv_files
|
|
|
|
|
|
|
318 |
|
319 |
+
finally:
|
320 |
+
# Always clean up model from memory
|
321 |
+
if 'model' in locals():
|
322 |
+
del model
|
323 |
+
gc.collect()
|
324 |
+
if torch.cuda.is_available():
|
325 |
+
torch.cuda.empty_cache()
|
326 |
|
327 |
def extract_sequence(protein_id, taxon, sequence):
|
328 |
return sequence
|