Omartificial-Intelligence-Space committed on
Commit
7159e40
·
verified ·
1 Parent(s): 12d9c1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -81
app.py CHANGED
@@ -12,17 +12,52 @@ from transformers import AutoTokenizer, AutoModel
12
  # Initialize the embedder at module level
13
  embedder = None
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  @spaces.GPU(duration=120)
16
  def initialize_embedder(embedding_dim=768):
 
 
 
 
 
 
 
 
 
 
 
17
  global embedder
18
  if embedder is None:
19
- # Check for GPU support and configure appropriately
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
- print(f"Initializing embedder on device: {device}")
22
-
23
- embedder = QwenEmbedder(embedding_dim=embedding_dim)
24
- embedder.model = embedder.model.to(device)
25
- return embedder
 
 
 
 
 
 
 
 
 
26
 
27
  # Check for GPU support and configure appropriately
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -49,40 +84,6 @@ def tokenize(tokenizer, input_texts, eod_id, max_length):
49
  batch_dict = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
50
  return batch_dict
51
 
52
- class QwenEmbedder:
53
- def __init__(self, embedding_dim=768):
54
- self.tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
55
- self.model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')
56
- # Uncomment below for better performance if GPU available
57
- # self.model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B',
58
- # attn_implementation="flash_attention_2",
59
- # torch_dtype=torch.float16
60
- # ).cuda()
61
- self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
62
- self.max_length = 8192
63
- self.embedding_dim = embedding_dim
64
- self.projection = torch.nn.Linear(768, embedding_dim) if embedding_dim != 768 else None
65
-
66
- def get_embeddings(self, texts: List[str], with_instruction: bool = False) -> Tensor:
67
- if with_instruction:
68
- task = 'Process and understand the following text'
69
- texts = [get_detailed_instruct(task, text) for text in texts]
70
-
71
- batch_dict = tokenize(self.tokenizer, texts, self.eod_id, self.max_length)
72
- batch_dict.to(self.model.device)
73
-
74
- with torch.no_grad():
75
- outputs = self.model(**batch_dict)
76
- embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
77
-
78
- # Project to desired dimension if needed
79
- if self.projection is not None:
80
- embeddings = self.projection(embeddings)
81
-
82
- embeddings = F.normalize(embeddings, p=2, dim=1)
83
-
84
- return embeddings
85
-
86
  def compute_similarity(embedder: QwenEmbedder, text1: str, text2: str) -> float:
87
  embeddings = embedder.get_embeddings([text1, text2])
88
  similarity = torch.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()
@@ -259,12 +260,6 @@ def extract_concepts(embedder: QwenEmbedder, text: str, concept_type: str) -> Li
259
 
260
  return [(concept, round(score, 3)) for concept, score in results]
261
 
262
- # Add a function to reinitialize embedder with new dimension
263
- def reinitialize_embedder(dim: int) -> QwenEmbedder:
264
- global embedder
265
- embedder = QwenEmbedder(embedding_dim=dim)
266
- return "Embedder reinitialized with dimension: " + str(dim)
267
-
268
  # Update the CSS to improve feature visibility
269
  custom_css = """
270
  :root {
@@ -454,16 +449,9 @@ button.secondary {
454
 
455
  # Create the Gradio interface
456
  def create_demo():
457
- global embedder
458
- # Initialize embedder with GPU support
459
- embedder = initialize_embedder()
460
-
461
  demo = gr.Blocks(title="Advanced Text Processing with Qwen", css=custom_css, theme=gr.themes.Soft())
462
 
463
  with demo:
464
- # Store embedder in state
465
- state = gr.State(embedder)
466
-
467
  with gr.Row():
468
  # Sidebar
469
  with gr.Column(scale=1, elem_classes="sidebar"):
@@ -601,8 +589,8 @@ def create_demo():
601
  similarity_score = gr.Number(label="Similarity Score")
602
 
603
  similarity_btn.click(
604
- fn=lambda t1, t2, s: compute_similarity(s.value, t1, t2),
605
- inputs=[text1, text2, state],
606
  outputs=similarity_score
607
  )
608
 
@@ -644,8 +632,8 @@ def create_demo():
644
  )
645
 
646
  rerank_btn.click(
647
- fn=lambda q, d, s: rerank_documents(s.value, q, d),
648
- inputs=[query_text, documents_text, state],
649
  outputs=rerank_results
650
  )
651
 
@@ -679,8 +667,8 @@ def create_demo():
679
  )
680
 
681
  process_btn.click(
682
- fn=lambda t, s: process_batch_embeddings(s.value, t),
683
- inputs=[batch_texts, state],
684
  outputs=[similarity_matrix]
685
  )
686
 
@@ -751,8 +739,8 @@ def create_demo():
751
  """)
752
 
753
  retrieve_btn.click(
754
- fn=lambda p, q, d, s: process_retrieval(s.value, p, q, d),
755
- inputs=[task_prompt, queries_text, documents_text, state],
756
  outputs=[retrieval_matrix]
757
  )
758
 
@@ -807,8 +795,8 @@ def create_demo():
807
  """)
808
 
809
  match_btn.click(
810
- fn=lambda a, e, s: process_cross_lingual(s.value, a, e)["similarity"],
811
- inputs=[arabic_text, english_text, state],
812
  outputs=[cross_lingual_score]
813
  )
814
 
@@ -850,8 +838,8 @@ def create_demo():
850
  )
851
 
852
  classify_btn.click(
853
- fn=lambda t, c, s: classify_text(s.value, t, c),
854
- inputs=[input_text, categories_text, state],
855
  outputs=classification_results
856
  )
857
 
@@ -899,8 +887,8 @@ def create_demo():
899
  )
900
 
901
  cluster_btn.click(
902
- fn=lambda d, n, s: cluster_documents(s.value, d, n),
903
- inputs=[cluster_docs, num_clusters, state],
904
  outputs=clustering_results
905
  )
906
 
@@ -932,8 +920,8 @@ def create_demo():
932
  sentiment_scores = gr.Json(label="Detailed Scores")
933
 
934
  analyze_btn.click(
935
- fn=lambda t, s: analyze_sentiment(s.value, t),
936
- inputs=[sentiment_text, state],
937
  outputs=[sentiment_label, sentiment_scores]
938
  )
939
 
@@ -972,29 +960,30 @@ def create_demo():
972
  )
973
 
974
  extract_btn.click(
975
- fn=lambda t, c, s: extract_concepts(s.value, t, c),
976
- inputs=[concept_text, concept_type, state],
977
  outputs=concept_results
978
  )
979
 
980
- # Fix dimension update functionality
981
- def update_embedder_dim(dim, state):
 
 
982
  try:
983
- new_embedder = initialize_embedder(embedding_dim=dim)
984
- state.value = new_embedder
985
- return state, f"Successfully updated embedding dimension to {dim}"
986
  except Exception as e:
987
- return state, f"Error updating dimension: {str(e)}"
988
 
989
  update_dim_btn.click(
990
  fn=update_embedder_dim,
991
- inputs=[embedding_dim, state],
992
- outputs=[state, dim_status]
993
  )
994
 
995
  return demo
996
 
997
  if __name__ == "__main__":
998
  demo = create_demo()
999
- demo.queue() # Enable queuing for better handling of GPU resources
1000
  demo.launch()
 
12
  # Initialize the embedder at module level
13
  embedder = None
14
 
15
+ class QwenEmbedder:
16
+ def __init__(self, embedding_dim=768):
17
+ self.tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
18
+ self.model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')
19
+ self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
20
+ self.max_length = 8192
21
+ self.embedding_dim = embedding_dim
22
+ self.projection = torch.nn.Linear(768, embedding_dim) if embedding_dim != 768 else None
23
+
24
+ def to_device(self, device):
25
+ self.model = self.model.to(device)
26
+ if self.projection is not None:
27
+ self.projection = self.projection.to(device)
28
+ return self
29
+
30
  @spaces.GPU(duration=120)
31
  def initialize_embedder(embedding_dim=768):
32
+ # Initialize device inside the GPU worker
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ print(f"Initializing embedder on device: {device}")
35
+
36
+ # Create and move model to device
37
+ model = QwenEmbedder(embedding_dim=embedding_dim)
38
+ return model.to_device(device)
39
+
40
+ @spaces.GPU(duration=120)
41
+ def process_with_embedder(fn_name, *args):
42
+ """Generic handler for embedder operations"""
43
  global embedder
44
  if embedder is None:
45
+ embedder = initialize_embedder()
46
+
47
+ # Map function names to actual functions
48
+ fn_map = {
49
+ 'compute_similarity': compute_similarity,
50
+ 'rerank_documents': rerank_documents,
51
+ 'process_batch_embeddings': process_batch_embeddings,
52
+ 'process_retrieval': process_retrieval,
53
+ 'process_cross_lingual': process_cross_lingual,
54
+ 'classify_text': classify_text,
55
+ 'cluster_documents': cluster_documents,
56
+ 'analyze_sentiment': analyze_sentiment,
57
+ 'extract_concepts': extract_concepts
58
+ }
59
+
60
+ return fn_map[fn_name](embedder, *args)
61
 
62
  # Check for GPU support and configure appropriately
63
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
84
  batch_dict = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
85
  return batch_dict
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def compute_similarity(embedder: QwenEmbedder, text1: str, text2: str) -> float:
88
  embeddings = embedder.get_embeddings([text1, text2])
89
  similarity = torch.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()
 
260
 
261
  return [(concept, round(score, 3)) for concept, score in results]
262
 
 
 
 
 
 
 
263
  # Update the CSS to improve feature visibility
264
  custom_css = """
265
  :root {
 
449
 
450
  # Create the Gradio interface
451
  def create_demo():
 
 
 
 
452
  demo = gr.Blocks(title="Advanced Text Processing with Qwen", css=custom_css, theme=gr.themes.Soft())
453
 
454
  with demo:
 
 
 
455
  with gr.Row():
456
  # Sidebar
457
  with gr.Column(scale=1, elem_classes="sidebar"):
 
589
  similarity_score = gr.Number(label="Similarity Score")
590
 
591
  similarity_btn.click(
592
+ fn=lambda t1, t2: process_with_embedder('compute_similarity', t1, t2),
593
+ inputs=[text1, text2],
594
  outputs=similarity_score
595
  )
596
 
 
632
  )
633
 
634
  rerank_btn.click(
635
+ fn=lambda q, d: process_with_embedder('rerank_documents', q, d),
636
+ inputs=[query_text, documents_text],
637
  outputs=rerank_results
638
  )
639
 
 
667
  )
668
 
669
  process_btn.click(
670
+ fn=lambda t: process_with_embedder('process_batch_embeddings', t),
671
+ inputs=[batch_texts],
672
  outputs=[similarity_matrix]
673
  )
674
 
 
739
  """)
740
 
741
  retrieve_btn.click(
742
+ fn=lambda p, q, d: process_with_embedder('process_retrieval', p, q, d),
743
+ inputs=[task_prompt, queries_text, documents_text],
744
  outputs=[retrieval_matrix]
745
  )
746
 
 
795
  """)
796
 
797
  match_btn.click(
798
+ fn=lambda a, e: process_with_embedder('process_cross_lingual', a, e),
799
+ inputs=[arabic_text, english_text],
800
  outputs=[cross_lingual_score]
801
  )
802
 
 
838
  )
839
 
840
  classify_btn.click(
841
+ fn=lambda t, c: process_with_embedder('classify_text', t, c),
842
+ inputs=[input_text, categories_text],
843
  outputs=classification_results
844
  )
845
 
 
887
  )
888
 
889
  cluster_btn.click(
890
+ fn=lambda d, n: process_with_embedder('cluster_documents', d, n),
891
+ inputs=[cluster_docs, num_clusters],
892
  outputs=clustering_results
893
  )
894
 
 
920
  sentiment_scores = gr.Json(label="Detailed Scores")
921
 
922
  analyze_btn.click(
923
+ fn=lambda t: process_with_embedder('analyze_sentiment', t),
924
+ inputs=[sentiment_text],
925
  outputs=[sentiment_label, sentiment_scores]
926
  )
927
 
 
960
  )
961
 
962
  extract_btn.click(
963
+ fn=lambda t, c: process_with_embedder('extract_concepts', t, c),
964
+ inputs=[concept_text, concept_type],
965
  outputs=concept_results
966
  )
967
 
968
+ # Update dimension handler
969
+ @spaces.GPU(duration=120)
970
+ def update_embedder_dim(dim):
971
+ global embedder
972
  try:
973
+ embedder = initialize_embedder(embedding_dim=dim)
974
+ return f"Successfully updated embedding dimension to {dim}"
 
975
  except Exception as e:
976
+ return f"Error updating dimension: {str(e)}"
977
 
978
  update_dim_btn.click(
979
  fn=update_embedder_dim,
980
+ inputs=[embedding_dim],
981
+ outputs=dim_status
982
  )
983
 
984
  return demo
985
 
986
  if __name__ == "__main__":
987
  demo = create_demo()
988
+ demo.queue()
989
  demo.launch()