Spaces:
Running
on
Zero
Running
on
Zero
Anton Bushuiev
commited on
Commit
·
2af3eaa
1
Parent(s):
e0dc24a
Make modified cosine optional
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ License: MIT
|
|
11 |
|
12 |
import gradio as gr
|
13 |
import spaces
|
|
|
14 |
import urllib.request
|
15 |
from datetime import datetime
|
16 |
from functools import partial
|
@@ -300,17 +301,17 @@ def _predict_gpu(in_pth, progress):
|
|
300 |
Returns:
|
301 |
numpy.ndarray: DreaMS embeddings
|
302 |
"""
|
303 |
-
progress(0.
|
304 |
msdata = MSData.load(in_pth)
|
305 |
|
306 |
-
progress(0.
|
307 |
embs = dreams_embeddings(msdata)
|
308 |
print(f'Shape of the query embeddings: {embs.shape}')
|
309 |
|
310 |
return embs
|
311 |
|
312 |
|
313 |
-
def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
|
314 |
"""
|
315 |
Create a single result row for the DataFrame
|
316 |
|
@@ -323,6 +324,7 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
|
|
323 |
sims: Similarity matrix
|
324 |
cos_sim: Cosine similarity calculator
|
325 |
embs: Query embeddings
|
|
|
326 |
|
327 |
Returns:
|
328 |
dict: Result row data
|
@@ -331,7 +333,8 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
|
|
331 |
spec1 = msdata.get_spectra(i)
|
332 |
spec2 = msdata_lib.get_spectra(j)
|
333 |
|
334 |
-
|
|
|
335 |
'feature_id': i + 1,
|
336 |
'precursor_mz': msdata.get_prec_mzs(i),
|
337 |
'topk': n + 1,
|
@@ -342,25 +345,32 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
|
|
342 |
'Spectrum_raw': su.unpad_peak_list(spec1),
|
343 |
'library_ID': msdata_lib.get_values('IDENTIFIER', j),
|
344 |
'DreaMS_similarity': sims[i, j],
|
345 |
-
'Modified_cosine_similarity': cos_sim(
|
346 |
-
spec1=spec1,
|
347 |
-
prec_mz1=msdata.get_prec_mzs(i),
|
348 |
-
spec2=spec2,
|
349 |
-
prec_mz2=msdata_lib.get_prec_mzs(j),
|
350 |
-
),
|
351 |
'i': i,
|
352 |
'j': j,
|
353 |
'DreaMS_embedding': embs[i],
|
354 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
|
357 |
-
def _process_results_dataframe(df, in_pth):
|
358 |
"""
|
359 |
Process and clean the results DataFrame
|
360 |
|
361 |
Args:
|
362 |
df: Raw results DataFrame
|
363 |
in_pth: Input file path for CSV export
|
|
|
364 |
|
365 |
Returns:
|
366 |
tuple: (processed_df, csv_path)
|
@@ -372,7 +382,11 @@ def _process_results_dataframe(df, in_pth):
|
|
372 |
# Remove unnecessary columns and round similarity scores
|
373 |
df = df.drop(columns=['i', 'j', 'library_j'])
|
374 |
df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
|
375 |
-
|
|
|
|
|
|
|
|
|
376 |
df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
|
377 |
|
378 |
# Rename columns for display
|
@@ -386,9 +400,13 @@ def _process_results_dataframe(df, in_pth):
|
|
386 |
"Spectrum": "Spectrum",
|
387 |
"Spectrum_raw": "Input Spectrum",
|
388 |
"DreaMS_similarity": "DreaMS similarity",
|
389 |
-
"Modified_cosine_similarity": "Modified cos similarity",
|
390 |
"DreaMS_embedding": "DreaMS embedding",
|
391 |
}
|
|
|
|
|
|
|
|
|
|
|
392 |
df = df.rename(columns=column_mapping)
|
393 |
|
394 |
# Save full results to CSV
|
@@ -409,13 +427,14 @@ def _process_results_dataframe(df, in_pth):
|
|
409 |
return df, str(df_path)
|
410 |
|
411 |
|
412 |
-
def _predict_core(lib_pth, in_pth, progress):
|
413 |
"""
|
414 |
Core prediction function that orchestrates the entire prediction pipeline
|
415 |
|
416 |
Args:
|
417 |
lib_pth: Library file path
|
418 |
in_pth: Input file path
|
|
|
419 |
progress: Gradio progress tracker
|
420 |
|
421 |
Returns:
|
@@ -425,65 +444,77 @@ def _predict_core(lib_pth, in_pth, progress):
|
|
425 |
|
426 |
# Clear cache at start to prevent memory buildup
|
427 |
clear_smiles_cache()
|
428 |
-
|
429 |
-
#
|
430 |
-
progress(0, desc="
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
sims = cosine_similarity(embs, embs_lib)
|
441 |
-
print(f'Shape of the similarity matrix: {sims.shape}')
|
442 |
-
|
443 |
-
# Get top-k candidates
|
444 |
-
k = 1
|
445 |
-
topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
|
446 |
-
|
447 |
-
# Load query data for processing
|
448 |
-
msdata = MSData.load(in_pth)
|
449 |
-
print(f'Available columns: {msdata.columns()}')
|
450 |
-
|
451 |
-
# Construct results DataFrame
|
452 |
-
progress(0.5, desc="Constructing results table...")
|
453 |
-
df = []
|
454 |
-
cos_sim = su.PeakListModifiedCosine()
|
455 |
-
total_spectra = len(topk_cands)
|
456 |
-
|
457 |
-
for i, topk in enumerate(topk_cands):
|
458 |
-
progress(0.5 + 0.4 * (i / total_spectra),
|
459 |
-
desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
|
460 |
|
461 |
-
|
462 |
-
|
463 |
-
df.append(row_data)
|
464 |
|
465 |
-
#
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
476 |
|
477 |
-
|
|
|
|
|
|
|
478 |
|
479 |
|
480 |
-
def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
|
481 |
"""
|
482 |
Main prediction function with error handling
|
483 |
|
484 |
Args:
|
485 |
lib_pth: Library file path
|
486 |
in_pth: Input file path
|
|
|
487 |
progress: Gradio progress tracker
|
488 |
|
489 |
Returns:
|
@@ -501,7 +532,9 @@ def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
|
|
501 |
if not Path(lib_pth).exists():
|
502 |
raise gr.Error("Spectral library not found. Please ensure the library file exists.")
|
503 |
|
504 |
-
|
|
|
|
|
505 |
|
506 |
except gr.Error:
|
507 |
# Re-raise Gradio errors as-is
|
@@ -573,20 +606,28 @@ def _create_gradio_interface():
|
|
573 |
label="Examples (click on a file to load as input)",
|
574 |
)
|
575 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
576 |
# Prediction button
|
577 |
predict_button = gr.Button(value="Run DreaMS", variant="primary")
|
578 |
|
579 |
-
#
|
580 |
gr.Markdown("## Predictions")
|
581 |
df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
|
582 |
|
583 |
# Results table
|
584 |
df = gr.Dataframe(
|
585 |
headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
|
586 |
-
"Library ID", "DreaMS similarity"
|
587 |
-
datatype=["number", "number", "number", "html", "html", "str", "number"
|
588 |
-
col_count=(
|
589 |
-
column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px"
|
590 |
max_height=1000,
|
591 |
show_fullscreen_button=True,
|
592 |
show_row_numbers=False,
|
@@ -594,8 +635,29 @@ def _create_gradio_interface():
|
|
594 |
)
|
595 |
|
596 |
# Connect prediction logic
|
597 |
-
inputs = [in_pth]
|
598 |
outputs = [df, df_file]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
599 |
predict_func = partial(predict, LIBRARY_PATH)
|
600 |
predict_button.click(predict_func, inputs=inputs, outputs=outputs, show_progress="first")
|
601 |
|
|
|
11 |
|
12 |
import gradio as gr
|
13 |
import spaces
|
14 |
+
import shutil
|
15 |
import urllib.request
|
16 |
from datetime import datetime
|
17 |
from functools import partial
|
|
|
301 |
Returns:
|
302 |
numpy.ndarray: DreaMS embeddings
|
303 |
"""
|
304 |
+
progress(0.2, desc="Loading spectra data...")
|
305 |
msdata = MSData.load(in_pth)
|
306 |
|
307 |
+
progress(0.3, desc="Computing DreaMS embeddings...")
|
308 |
embs = dreams_embeddings(msdata)
|
309 |
print(f'Shape of the query embeddings: {embs.shape}')
|
310 |
|
311 |
return embs
|
312 |
|
313 |
|
314 |
+
def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calculate_modified_cosine=False):
|
315 |
"""
|
316 |
Create a single result row for the DataFrame
|
317 |
|
|
|
324 |
sims: Similarity matrix
|
325 |
cos_sim: Cosine similarity calculator
|
326 |
embs: Query embeddings
|
327 |
+
calculate_modified_cosine: Whether to calculate modified cosine similarity
|
328 |
|
329 |
Returns:
|
330 |
dict: Result row data
|
|
|
333 |
spec1 = msdata.get_spectra(i)
|
334 |
spec2 = msdata_lib.get_spectra(j)
|
335 |
|
336 |
+
# Base row data
|
337 |
+
row_data = {
|
338 |
'feature_id': i + 1,
|
339 |
'precursor_mz': msdata.get_prec_mzs(i),
|
340 |
'topk': n + 1,
|
|
|
345 |
'Spectrum_raw': su.unpad_peak_list(spec1),
|
346 |
'library_ID': msdata_lib.get_values('IDENTIFIER', j),
|
347 |
'DreaMS_similarity': sims[i, j],
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
'i': i,
|
349 |
'j': j,
|
350 |
'DreaMS_embedding': embs[i],
|
351 |
}
|
352 |
+
|
353 |
+
# Add modified cosine similarity only if enabled
|
354 |
+
if calculate_modified_cosine:
|
355 |
+
modified_cosine_sim = cos_sim(
|
356 |
+
spec1=spec1,
|
357 |
+
prec_mz1=msdata.get_prec_mzs(i),
|
358 |
+
spec2=spec2,
|
359 |
+
prec_mz2=msdata_lib.get_prec_mzs(j),
|
360 |
+
)
|
361 |
+
row_data['Modified_cosine_similarity'] = modified_cosine_sim
|
362 |
+
|
363 |
+
return row_data
|
364 |
|
365 |
|
366 |
+
def _process_results_dataframe(df, in_pth, calculate_modified_cosine=False):
|
367 |
"""
|
368 |
Process and clean the results DataFrame
|
369 |
|
370 |
Args:
|
371 |
df: Raw results DataFrame
|
372 |
in_pth: Input file path for CSV export
|
373 |
+
calculate_modified_cosine: Whether modified cosine similarity was calculated
|
374 |
|
375 |
Returns:
|
376 |
tuple: (processed_df, csv_path)
|
|
|
382 |
# Remove unnecessary columns and round similarity scores
|
383 |
df = df.drop(columns=['i', 'j', 'library_j'])
|
384 |
df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
|
385 |
+
|
386 |
+
# Handle modified cosine similarity column conditionally
|
387 |
+
if calculate_modified_cosine and 'Modified_cosine_similarity' in df.columns:
|
388 |
+
df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
|
389 |
+
|
390 |
df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
|
391 |
|
392 |
# Rename columns for display
|
|
|
400 |
"Spectrum": "Spectrum",
|
401 |
"Spectrum_raw": "Input Spectrum",
|
402 |
"DreaMS_similarity": "DreaMS similarity",
|
|
|
403 |
"DreaMS_embedding": "DreaMS embedding",
|
404 |
}
|
405 |
+
|
406 |
+
# Add modified cosine similarity to column mapping only if it exists
|
407 |
+
if calculate_modified_cosine and 'Modified_cosine_similarity' in df.columns:
|
408 |
+
column_mapping["Modified_cosine_similarity"] = "Modified cos similarity"
|
409 |
+
|
410 |
df = df.rename(columns=column_mapping)
|
411 |
|
412 |
# Save full results to CSV
|
|
|
427 |
return df, str(df_path)
|
428 |
|
429 |
|
430 |
+
def _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress):
|
431 |
"""
|
432 |
Core prediction function that orchestrates the entire prediction pipeline
|
433 |
|
434 |
Args:
|
435 |
lib_pth: Library file path
|
436 |
in_pth: Input file path
|
437 |
+
calculate_modified_cosine: Whether to calculate modified cosine similarity
|
438 |
progress: Gradio progress tracker
|
439 |
|
440 |
Returns:
|
|
|
444 |
|
445 |
# Clear cache at start to prevent memory buildup
|
446 |
clear_smiles_cache()
|
447 |
+
|
448 |
+
# Create temporary copy of library file to allow multiple processes
|
449 |
+
progress(0, desc="Creating temporary library copy...")
|
450 |
+
temp_lib_path = Path(lib_pth).parent / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{Path(lib_pth).name}"
|
451 |
+
shutil.copy2(lib_pth, temp_lib_path)
|
452 |
+
|
453 |
+
try:
|
454 |
+
# Load library data
|
455 |
+
progress(0.1, desc="Loading library data...")
|
456 |
+
msdata_lib = MSData.load(temp_lib_path)
|
457 |
+
embs_lib = msdata_lib[DREAMS_EMBEDDING]
|
458 |
+
print(f'Shape of the library embeddings: {embs_lib.shape}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
|
460 |
+
# Get query embeddings
|
461 |
+
embs = _predict_gpu(in_pth, progress)
|
|
|
462 |
|
463 |
+
# Compute similarity matrix
|
464 |
+
progress(0.4, desc="Computing similarity matrix...")
|
465 |
+
sims = cosine_similarity(embs, embs_lib)
|
466 |
+
print(f'Shape of the similarity matrix: {sims.shape}')
|
467 |
+
|
468 |
+
# Get top-k candidates
|
469 |
+
k = 1
|
470 |
+
topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
|
471 |
+
|
472 |
+
# Load query data for processing
|
473 |
+
msdata = MSData.load(in_pth)
|
474 |
+
print(f'Available columns: {msdata.columns()}')
|
475 |
+
|
476 |
+
# Construct results DataFrame
|
477 |
+
progress(0.5, desc="Constructing results table...")
|
478 |
+
df = []
|
479 |
+
cos_sim = su.PeakListModifiedCosine()
|
480 |
+
total_spectra = len(topk_cands)
|
481 |
+
|
482 |
+
for i, topk in enumerate(topk_cands):
|
483 |
+
progress(0.5 + 0.4 * (i / total_spectra),
|
484 |
+
desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
|
485 |
+
|
486 |
+
for n, j in enumerate(topk):
|
487 |
+
row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calculate_modified_cosine)
|
488 |
+
df.append(row_data)
|
489 |
+
|
490 |
+
# Clear cache every 100 spectra to prevent memory buildup
|
491 |
+
if (i + 1) % 100 == 0:
|
492 |
+
clear_smiles_cache()
|
493 |
+
|
494 |
+
df = pd.DataFrame(df)
|
495 |
+
|
496 |
+
# Process and clean results
|
497 |
+
progress(0.9, desc="Post-processing results...")
|
498 |
+
df, csv_path = _process_results_dataframe(df, in_pth, calculate_modified_cosine)
|
499 |
+
|
500 |
+
progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
|
501 |
+
|
502 |
+
return df, csv_path
|
503 |
|
504 |
+
finally:
|
505 |
+
# Clean up temporary library file
|
506 |
+
if temp_lib_path.exists():
|
507 |
+
temp_lib_path.unlink()
|
508 |
|
509 |
|
510 |
+
def predict(lib_pth, in_pth, calculate_modified_cosine=False, progress=gr.Progress(track_tqdm=True)):
|
511 |
"""
|
512 |
Main prediction function with error handling
|
513 |
|
514 |
Args:
|
515 |
lib_pth: Library file path
|
516 |
in_pth: Input file path
|
517 |
+
calculate_modified_cosine: Whether to calculate modified cosine similarity
|
518 |
progress: Gradio progress tracker
|
519 |
|
520 |
Returns:
|
|
|
532 |
if not Path(lib_pth).exists():
|
533 |
raise gr.Error("Spectral library not found. Please ensure the library file exists.")
|
534 |
|
535 |
+
df, csv_path = _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress)
|
536 |
+
|
537 |
+
return df, csv_path
|
538 |
|
539 |
except gr.Error:
|
540 |
# Re-raise Gradio errors as-is
|
|
|
606 |
label="Examples (click on a file to load as input)",
|
607 |
)
|
608 |
|
609 |
+
# Settings section
|
610 |
+
with gr.Accordion("⚙️ Settings", open=False):
|
611 |
+
calculate_modified_cosine = gr.Checkbox(
|
612 |
+
label="Calculate modified cosine similarity",
|
613 |
+
value=False,
|
614 |
+
info="Enable to calculate traditional modified cosine similarity scores (slower)"
|
615 |
+
)
|
616 |
+
|
617 |
# Prediction button
|
618 |
predict_button = gr.Button(value="Run DreaMS", variant="primary")
|
619 |
|
620 |
+
# Results table
|
621 |
gr.Markdown("## Predictions")
|
622 |
df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
|
623 |
|
624 |
# Results table
|
625 |
df = gr.Dataframe(
|
626 |
headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
|
627 |
+
"Library ID", "DreaMS similarity"],
|
628 |
+
datatype=["number", "number", "number", "html", "html", "str", "number"],
|
629 |
+
col_count=(7, "fixed"),
|
630 |
+
column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px"],
|
631 |
max_height=1000,
|
632 |
show_fullscreen_button=True,
|
633 |
show_row_numbers=False,
|
|
|
635 |
)
|
636 |
|
637 |
# Connect prediction logic
|
638 |
+
inputs = [in_pth, calculate_modified_cosine]
|
639 |
outputs = [df, df_file]
|
640 |
+
|
641 |
+
# Function to update dataframe headers based on setting
|
642 |
+
def update_headers(show_cosine):
|
643 |
+
if show_cosine:
|
644 |
+
return gr.update(headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
|
645 |
+
"Library ID", "DreaMS similarity", "Modified cosine similarity"],
|
646 |
+
col_count=(8, "fixed"),
|
647 |
+
column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px", "40px"])
|
648 |
+
else:
|
649 |
+
return gr.update(headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
|
650 |
+
"Library ID", "DreaMS similarity"],
|
651 |
+
col_count=(7, "fixed"),
|
652 |
+
column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px"])
|
653 |
+
|
654 |
+
# Update headers when setting changes
|
655 |
+
calculate_modified_cosine.change(
|
656 |
+
fn=update_headers,
|
657 |
+
inputs=[calculate_modified_cosine],
|
658 |
+
outputs=[df]
|
659 |
+
)
|
660 |
+
|
661 |
predict_func = partial(predict, LIBRARY_PATH)
|
662 |
predict_button.click(predict_func, inputs=inputs, outputs=outputs, show_progress="first")
|
663 |
|