Anton Bushuiev committed on
Commit
2af3eaa
·
1 Parent(s): e0dc24a

Make modified cosine optional

Browse files
Files changed (1) hide show
  1. app.py +131 -69
app.py CHANGED
@@ -11,6 +11,7 @@ License: MIT
11
 
12
  import gradio as gr
13
  import spaces
 
14
  import urllib.request
15
  from datetime import datetime
16
  from functools import partial
@@ -300,17 +301,17 @@ def _predict_gpu(in_pth, progress):
300
  Returns:
301
  numpy.ndarray: DreaMS embeddings
302
  """
303
- progress(0.1, desc="Loading spectra data...")
304
  msdata = MSData.load(in_pth)
305
 
306
- progress(0.2, desc="Computing DreaMS embeddings...")
307
  embs = dreams_embeddings(msdata)
308
  print(f'Shape of the query embeddings: {embs.shape}')
309
 
310
  return embs
311
 
312
 
313
- def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
314
  """
315
  Create a single result row for the DataFrame
316
 
@@ -323,6 +324,7 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
323
  sims: Similarity matrix
324
  cos_sim: Cosine similarity calculator
325
  embs: Query embeddings
 
326
 
327
  Returns:
328
  dict: Result row data
@@ -331,7 +333,8 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
331
  spec1 = msdata.get_spectra(i)
332
  spec2 = msdata_lib.get_spectra(j)
333
 
334
- return {
 
335
  'feature_id': i + 1,
336
  'precursor_mz': msdata.get_prec_mzs(i),
337
  'topk': n + 1,
@@ -342,25 +345,32 @@ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs):
342
  'Spectrum_raw': su.unpad_peak_list(spec1),
343
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
344
  'DreaMS_similarity': sims[i, j],
345
- 'Modified_cosine_similarity': cos_sim(
346
- spec1=spec1,
347
- prec_mz1=msdata.get_prec_mzs(i),
348
- spec2=spec2,
349
- prec_mz2=msdata_lib.get_prec_mzs(j),
350
- ),
351
  'i': i,
352
  'j': j,
353
  'DreaMS_embedding': embs[i],
354
  }
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
 
357
- def _process_results_dataframe(df, in_pth):
358
  """
359
  Process and clean the results DataFrame
360
 
361
  Args:
362
  df: Raw results DataFrame
363
  in_pth: Input file path for CSV export
 
364
 
365
  Returns:
366
  tuple: (processed_df, csv_path)
@@ -372,7 +382,11 @@ def _process_results_dataframe(df, in_pth):
372
  # Remove unnecessary columns and round similarity scores
373
  df = df.drop(columns=['i', 'j', 'library_j'])
374
  df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
375
- df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
 
 
 
 
376
  df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
377
 
378
  # Rename columns for display
@@ -386,9 +400,13 @@ def _process_results_dataframe(df, in_pth):
386
  "Spectrum": "Spectrum",
387
  "Spectrum_raw": "Input Spectrum",
388
  "DreaMS_similarity": "DreaMS similarity",
389
- "Modified_cosine_similarity": "Modified cos similarity",
390
  "DreaMS_embedding": "DreaMS embedding",
391
  }
 
 
 
 
 
392
  df = df.rename(columns=column_mapping)
393
 
394
  # Save full results to CSV
@@ -409,13 +427,14 @@ def _process_results_dataframe(df, in_pth):
409
  return df, str(df_path)
410
 
411
 
412
- def _predict_core(lib_pth, in_pth, progress):
413
  """
414
  Core prediction function that orchestrates the entire prediction pipeline
415
 
416
  Args:
417
  lib_pth: Library file path
418
  in_pth: Input file path
 
419
  progress: Gradio progress tracker
420
 
421
  Returns:
@@ -425,65 +444,77 @@ def _predict_core(lib_pth, in_pth, progress):
425
 
426
  # Clear cache at start to prevent memory buildup
427
  clear_smiles_cache()
428
-
429
- # Load library data
430
- progress(0, desc="Loading library data...")
431
- msdata_lib = MSData.load(lib_pth)
432
- embs_lib = msdata_lib[DREAMS_EMBEDDING]
433
- print(f'Shape of the library embeddings: {embs_lib.shape}')
434
-
435
- # Get query embeddings
436
- embs = _predict_gpu(in_pth, progress)
437
-
438
- # Compute similarity matrix
439
- progress(0.4, desc="Computing similarity matrix...")
440
- sims = cosine_similarity(embs, embs_lib)
441
- print(f'Shape of the similarity matrix: {sims.shape}')
442
-
443
- # Get top-k candidates
444
- k = 1
445
- topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
446
-
447
- # Load query data for processing
448
- msdata = MSData.load(in_pth)
449
- print(f'Available columns: {msdata.columns()}')
450
-
451
- # Construct results DataFrame
452
- progress(0.5, desc="Constructing results table...")
453
- df = []
454
- cos_sim = su.PeakListModifiedCosine()
455
- total_spectra = len(topk_cands)
456
-
457
- for i, topk in enumerate(topk_cands):
458
- progress(0.5 + 0.4 * (i / total_spectra),
459
- desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
460
 
461
- for n, j in enumerate(topk):
462
- row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs)
463
- df.append(row_data)
464
 
465
- # Clear cache every 100 spectra to prevent memory buildup
466
- if (i + 1) % 100 == 0:
467
- clear_smiles_cache()
468
-
469
- df = pd.DataFrame(df)
470
-
471
- # Process and clean results
472
- progress(0.9, desc="Post-processing results...")
473
- df, csv_path = _process_results_dataframe(df, in_pth)
474
-
475
- progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
- return df, csv_path
 
 
 
478
 
479
 
480
- def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
481
  """
482
  Main prediction function with error handling
483
 
484
  Args:
485
  lib_pth: Library file path
486
  in_pth: Input file path
 
487
  progress: Gradio progress tracker
488
 
489
  Returns:
@@ -501,7 +532,9 @@ def predict(lib_pth, in_pth, progress=gr.Progress(track_tqdm=True)):
501
  if not Path(lib_pth).exists():
502
  raise gr.Error("Spectral library not found. Please ensure the library file exists.")
503
 
504
- return _predict_core(lib_pth, in_pth, progress)
 
 
505
 
506
  except gr.Error:
507
  # Re-raise Gradio errors as-is
@@ -573,20 +606,28 @@ def _create_gradio_interface():
573
  label="Examples (click on a file to load as input)",
574
  )
575
 
 
 
 
 
 
 
 
 
576
  # Prediction button
577
  predict_button = gr.Button(value="Run DreaMS", variant="primary")
578
 
579
- # Output section
580
  gr.Markdown("## Predictions")
581
  df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
582
 
583
  # Results table
584
  df = gr.Dataframe(
585
  headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
586
- "Library ID", "DreaMS similarity", "Modified cosine similarity"],
587
- datatype=["number", "number", "number", "html", "html", "str", "number", "number"],
588
- col_count=(8, "fixed"),
589
- column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px", "40px"],
590
  max_height=1000,
591
  show_fullscreen_button=True,
592
  show_row_numbers=False,
@@ -594,8 +635,29 @@ def _create_gradio_interface():
594
  )
595
 
596
  # Connect prediction logic
597
- inputs = [in_pth]
598
  outputs = [df, df_file]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  predict_func = partial(predict, LIBRARY_PATH)
600
  predict_button.click(predict_func, inputs=inputs, outputs=outputs, show_progress="first")
601
 
 
11
 
12
  import gradio as gr
13
  import spaces
14
+ import shutil
15
  import urllib.request
16
  from datetime import datetime
17
  from functools import partial
 
301
  Returns:
302
  numpy.ndarray: DreaMS embeddings
303
  """
304
+ progress(0.2, desc="Loading spectra data...")
305
  msdata = MSData.load(in_pth)
306
 
307
+ progress(0.3, desc="Computing DreaMS embeddings...")
308
  embs = dreams_embeddings(msdata)
309
  print(f'Shape of the query embeddings: {embs.shape}')
310
 
311
  return embs
312
 
313
 
314
+ def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calculate_modified_cosine=False):
315
  """
316
  Create a single result row for the DataFrame
317
 
 
324
  sims: Similarity matrix
325
  cos_sim: Cosine similarity calculator
326
  embs: Query embeddings
327
+ calculate_modified_cosine: Whether to calculate modified cosine similarity
328
 
329
  Returns:
330
  dict: Result row data
 
333
  spec1 = msdata.get_spectra(i)
334
  spec2 = msdata_lib.get_spectra(j)
335
 
336
+ # Base row data
337
+ row_data = {
338
  'feature_id': i + 1,
339
  'precursor_mz': msdata.get_prec_mzs(i),
340
  'topk': n + 1,
 
345
  'Spectrum_raw': su.unpad_peak_list(spec1),
346
  'library_ID': msdata_lib.get_values('IDENTIFIER', j),
347
  'DreaMS_similarity': sims[i, j],
 
 
 
 
 
 
348
  'i': i,
349
  'j': j,
350
  'DreaMS_embedding': embs[i],
351
  }
352
+
353
+ # Add modified cosine similarity only if enabled
354
+ if calculate_modified_cosine:
355
+ modified_cosine_sim = cos_sim(
356
+ spec1=spec1,
357
+ prec_mz1=msdata.get_prec_mzs(i),
358
+ spec2=spec2,
359
+ prec_mz2=msdata_lib.get_prec_mzs(j),
360
+ )
361
+ row_data['Modified_cosine_similarity'] = modified_cosine_sim
362
+
363
+ return row_data
364
 
365
 
366
+ def _process_results_dataframe(df, in_pth, calculate_modified_cosine=False):
367
  """
368
  Process and clean the results DataFrame
369
 
370
  Args:
371
  df: Raw results DataFrame
372
  in_pth: Input file path for CSV export
373
+ calculate_modified_cosine: Whether modified cosine similarity was calculated
374
 
375
  Returns:
376
  tuple: (processed_df, csv_path)
 
382
  # Remove unnecessary columns and round similarity scores
383
  df = df.drop(columns=['i', 'j', 'library_j'])
384
  df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
385
+
386
+ # Handle modified cosine similarity column conditionally
387
+ if calculate_modified_cosine and 'Modified_cosine_similarity' in df.columns:
388
+ df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
389
+
390
  df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
391
 
392
  # Rename columns for display
 
400
  "Spectrum": "Spectrum",
401
  "Spectrum_raw": "Input Spectrum",
402
  "DreaMS_similarity": "DreaMS similarity",
 
403
  "DreaMS_embedding": "DreaMS embedding",
404
  }
405
+
406
+ # Add modified cosine similarity to column mapping only if it exists
407
+ if calculate_modified_cosine and 'Modified_cosine_similarity' in df.columns:
408
+ column_mapping["Modified_cosine_similarity"] = "Modified cos similarity"
409
+
410
  df = df.rename(columns=column_mapping)
411
 
412
  # Save full results to CSV
 
427
  return df, str(df_path)
428
 
429
 
430
+ def _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress):
431
  """
432
  Core prediction function that orchestrates the entire prediction pipeline
433
 
434
  Args:
435
  lib_pth: Library file path
436
  in_pth: Input file path
437
+ calculate_modified_cosine: Whether to calculate modified cosine similarity
438
  progress: Gradio progress tracker
439
 
440
  Returns:
 
444
 
445
  # Clear cache at start to prevent memory buildup
446
  clear_smiles_cache()
447
+
448
+ # Create temporary copy of library file to allow multiple processes
449
+ progress(0, desc="Creating temporary library copy...")
450
+ temp_lib_path = Path(lib_pth).parent / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{Path(lib_pth).name}"
451
+ shutil.copy2(lib_pth, temp_lib_path)
452
+
453
+ try:
454
+ # Load library data
455
+ progress(0.1, desc="Loading library data...")
456
+ msdata_lib = MSData.load(temp_lib_path)
457
+ embs_lib = msdata_lib[DREAMS_EMBEDDING]
458
+ print(f'Shape of the library embeddings: {embs_lib.shape}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
+ # Get query embeddings
461
+ embs = _predict_gpu(in_pth, progress)
 
462
 
463
+ # Compute similarity matrix
464
+ progress(0.4, desc="Computing similarity matrix...")
465
+ sims = cosine_similarity(embs, embs_lib)
466
+ print(f'Shape of the similarity matrix: {sims.shape}')
467
+
468
+ # Get top-k candidates
469
+ k = 1
470
+ topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]
471
+
472
+ # Load query data for processing
473
+ msdata = MSData.load(in_pth)
474
+ print(f'Available columns: {msdata.columns()}')
475
+
476
+ # Construct results DataFrame
477
+ progress(0.5, desc="Constructing results table...")
478
+ df = []
479
+ cos_sim = su.PeakListModifiedCosine()
480
+ total_spectra = len(topk_cands)
481
+
482
+ for i, topk in enumerate(topk_cands):
483
+ progress(0.5 + 0.4 * (i / total_spectra),
484
+ desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
485
+
486
+ for n, j in enumerate(topk):
487
+ row_data = _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, calculate_modified_cosine)
488
+ df.append(row_data)
489
+
490
+ # Clear cache every 100 spectra to prevent memory buildup
491
+ if (i + 1) % 100 == 0:
492
+ clear_smiles_cache()
493
+
494
+ df = pd.DataFrame(df)
495
+
496
+ # Process and clean results
497
+ progress(0.9, desc="Post-processing results...")
498
+ df, csv_path = _process_results_dataframe(df, in_pth, calculate_modified_cosine)
499
+
500
+ progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
501
+
502
+ return df, csv_path
503
 
504
+ finally:
505
+ # Clean up temporary library file
506
+ if temp_lib_path.exists():
507
+ temp_lib_path.unlink()
508
 
509
 
510
+ def predict(lib_pth, in_pth, calculate_modified_cosine=False, progress=gr.Progress(track_tqdm=True)):
511
  """
512
  Main prediction function with error handling
513
 
514
  Args:
515
  lib_pth: Library file path
516
  in_pth: Input file path
517
+ calculate_modified_cosine: Whether to calculate modified cosine similarity
518
  progress: Gradio progress tracker
519
 
520
  Returns:
 
532
  if not Path(lib_pth).exists():
533
  raise gr.Error("Spectral library not found. Please ensure the library file exists.")
534
 
535
+ df, csv_path = _predict_core(lib_pth, in_pth, calculate_modified_cosine, progress)
536
+
537
+ return df, csv_path
538
 
539
  except gr.Error:
540
  # Re-raise Gradio errors as-is
 
606
  label="Examples (click on a file to load as input)",
607
  )
608
 
609
+ # Settings section
610
+ with gr.Accordion("⚙️ Settings", open=False):
611
+ calculate_modified_cosine = gr.Checkbox(
612
+ label="Calculate modified cosine similarity",
613
+ value=False,
614
+ info="Enable to calculate traditional modified cosine similarity scores (slower)"
615
+ )
616
+
617
  # Prediction button
618
  predict_button = gr.Button(value="Run DreaMS", variant="primary")
619
 
620
+ # Results table
621
  gr.Markdown("## Predictions")
622
  df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
623
 
624
  # Results table
625
  df = gr.Dataframe(
626
  headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
627
+ "Library ID", "DreaMS similarity"],
628
+ datatype=["number", "number", "number", "html", "html", "str", "number"],
629
+ col_count=(7, "fixed"),
630
+ column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px"],
631
  max_height=1000,
632
  show_fullscreen_button=True,
633
  show_row_numbers=False,
 
635
  )
636
 
637
  # Connect prediction logic
638
+ inputs = [in_pth, calculate_modified_cosine]
639
  outputs = [df, df_file]
640
+
641
+ # Function to update dataframe headers based on setting
642
+ def update_headers(show_cosine):
643
+ if show_cosine:
644
+ return gr.update(headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
645
+ "Library ID", "DreaMS similarity", "Modified cosine similarity"],
646
+ col_count=(8, "fixed"),
647
+ column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px", "40px"])
648
+ else:
649
+ return gr.update(headers=["Row", "Feature ID", "Precursor m/z", "Molecule", "Spectrum",
650
+ "Library ID", "DreaMS similarity"],
651
+ col_count=(7, "fixed"),
652
+ column_widths=["25px", "25px", "28px", "60px", "60px", "50px", "40px"])
653
+
654
+ # Update headers when setting changes
655
+ calculate_modified_cosine.change(
656
+ fn=update_headers,
657
+ inputs=[calculate_modified_cosine],
658
+ outputs=[df]
659
+ )
660
+
661
  predict_func = partial(predict, LIBRARY_PATH)
662
  predict_button.click(predict_func, inputs=inputs, outputs=outputs, show_progress="first")
663