Omartificial-Intelligence-Space commited on
Commit
cb11c04
·
verified ·
1 Parent(s): 187ab5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -45
app.py CHANGED
@@ -5,46 +5,45 @@ import numpy as np
5
  import plotly.express as px
6
  import pandas as pd
7
  import spaces
8
- from typing import List, Tuple
9
  from torch import Tensor
10
  from transformers import AutoTokenizer, AutoModel
 
 
11
 
12
  # Initialize the embedder at module level
13
  embedder = None
14
 
 
 
 
 
 
15
  class QwenEmbedder:
16
- def __init__(self, embedding_dim=768):
17
- self.tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
18
- self.model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B')
19
- self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
20
- self.max_length = 8192
21
  self.embedding_dim = embedding_dim
22
- self.projection = torch.nn.Linear(768, embedding_dim) if embedding_dim != 768 else None
23
-
24
- def to_device(self, device):
25
- self.model = self.model.to(device)
26
- if self.projection is not None:
27
- self.projection = self.projection.to(device)
28
- return self
29
 
30
- def get_embeddings(self, texts: List[str], with_instruction: bool = False) -> Tensor:
 
 
 
 
 
 
 
31
  if with_instruction:
32
- task = 'Process and understand the following text'
33
- texts = [get_detailed_instruct(task, text) for text in texts]
34
 
35
- batch_dict = tokenize(self.tokenizer, texts, self.eod_id, self.max_length)
36
- batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
37
 
38
- with torch.no_grad():
39
- outputs = self.model(**batch_dict)
40
- embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
41
-
42
- # Project to desired dimension if needed
43
- if self.projection is not None:
44
- embeddings = self.projection(embeddings)
45
-
46
- embeddings = F.normalize(embeddings, p=2, dim=1)
47
 
 
 
48
  return embeddings
49
 
50
  @spaces.GPU(duration=120)
@@ -280,6 +279,86 @@ def extract_concepts(embedder: QwenEmbedder, text: str, concept_type: str) -> Li
280
 
281
  return [(concept, round(score, 3)) for concept, score in results]
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  # Update the CSS to improve feature visibility
284
  custom_css = """
285
  :root {
@@ -490,6 +569,11 @@ def create_demo():
490
  Configure the embedding model parameters below.
491
  """)
492
 
 
 
 
 
 
493
  embedding_dim = gr.Slider(
494
  minimum=32,
495
  maximum=1024,
@@ -609,8 +693,8 @@ def create_demo():
609
  similarity_score = gr.Number(label="Similarity Score")
610
 
611
  similarity_btn.click(
612
- fn=lambda t1, t2: process_with_embedder('compute_similarity', t1, t2),
613
- inputs=[text1, text2],
614
  outputs=similarity_score
615
  )
616
 
@@ -652,8 +736,8 @@ def create_demo():
652
  )
653
 
654
  rerank_btn.click(
655
- fn=lambda q, d: process_with_embedder('rerank_documents', q, d),
656
- inputs=[query_text, documents_text],
657
  outputs=rerank_results
658
  )
659
 
@@ -687,8 +771,8 @@ def create_demo():
687
  )
688
 
689
  process_btn.click(
690
- fn=lambda t: process_with_embedder('process_batch_embeddings', t),
691
- inputs=[batch_texts],
692
  outputs=[similarity_matrix]
693
  )
694
 
@@ -759,8 +843,8 @@ def create_demo():
759
  """)
760
 
761
  retrieve_btn.click(
762
- fn=lambda p, q, d: process_with_embedder('process_retrieval', p, q, d),
763
- inputs=[task_prompt, queries_text, documents_text],
764
  outputs=[retrieval_matrix]
765
  )
766
 
@@ -815,8 +899,8 @@ def create_demo():
815
  """)
816
 
817
  match_btn.click(
818
- fn=lambda a, e: process_with_embedder('process_cross_lingual', a, e),
819
- inputs=[arabic_text, english_text],
820
  outputs=[cross_lingual_score]
821
  )
822
 
@@ -858,8 +942,8 @@ def create_demo():
858
  )
859
 
860
  classify_btn.click(
861
- fn=lambda t, c: process_with_embedder('classify_text', t, c),
862
- inputs=[input_text, categories_text],
863
  outputs=classification_results
864
  )
865
 
@@ -907,8 +991,8 @@ def create_demo():
907
  )
908
 
909
  cluster_btn.click(
910
- fn=lambda d, n: process_with_embedder('cluster_documents', d, n),
911
- inputs=[cluster_docs, num_clusters],
912
  outputs=clustering_results
913
  )
914
 
@@ -940,8 +1024,8 @@ def create_demo():
940
  sentiment_scores = gr.Json(label="Detailed Scores")
941
 
942
  analyze_btn.click(
943
- fn=lambda t: process_with_embedder('analyze_sentiment', t),
944
- inputs=[sentiment_text],
945
  outputs=[sentiment_label, sentiment_scores]
946
  )
947
 
@@ -980,8 +1064,8 @@ def create_demo():
980
  )
981
 
982
  extract_btn.click(
983
- fn=lambda t, c: process_with_embedder('extract_concepts', t, c),
984
- inputs=[concept_text, concept_type],
985
  outputs=concept_results
986
  )
987
 
 
5
  import plotly.express as px
6
  import pandas as pd
7
  import spaces
8
+ from typing import List, Tuple, Dict
9
  from torch import Tensor
10
  from transformers import AutoTokenizer, AutoModel
11
+ from sentence_transformers import SentenceTransformer
12
+ import json
13
 
14
  # Initialize the embedder at module level
15
  embedder = None
16
 
17
+ AVAILABLE_MODELS = {
18
+ "Qwen Original": "Qwen/Qwen3-Embedding-0.6B",
19
+ "Arabic Fine-tuned": "Omartificial-Intelligence-Space/Semantic-Ar-Qwen-Embed-0.6B"
20
+ }
21
+
22
  class QwenEmbedder:
23
+ def __init__(self, model_name: str = "Qwen/Qwen3-Embedding-0.6B", embedding_dim: int = 768):
24
+ self.model = SentenceTransformer(model_name)
 
 
 
25
  self.embedding_dim = embedding_dim
26
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ self.model.to(self.device)
 
 
 
 
 
28
 
29
+ if embedding_dim != 768:
30
+ # Add projection layer if needed
31
+ self.projection = torch.nn.Linear(768, embedding_dim)
32
+ self.projection.to(self.device)
33
+ else:
34
+ self.projection = None
35
+
36
+ def get_embeddings(self, texts: List[str], with_instruction: bool = False) -> torch.Tensor:
37
  if with_instruction:
38
+ texts = [f"Represent this Arabic text for retrieval: {text}" for text in texts]
 
39
 
40
+ embeddings = self.model.encode(texts, convert_to_tensor=True)
 
41
 
42
+ if self.projection is not None:
43
+ embeddings = self.projection(embeddings)
 
 
 
 
 
 
 
44
 
45
+ # Normalize embeddings
46
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
47
  return embeddings
48
 
49
  @spaces.GPU(duration=120)
 
279
 
280
  return [(concept, round(score, 3)) for concept, score in results]
281
 
282
+ def create_embedder(model_choice: str, embedding_dim: int = 768) -> QwenEmbedder:
283
+ model_name = AVAILABLE_MODELS[model_choice]
284
+ return QwenEmbedder(model_name=model_name, embedding_dim=embedding_dim)
285
+
286
+ def process_similarity(text1: str, text2: str, model_choice: str, embedding_dim: int) -> float:
287
+ embedder = create_embedder(model_choice, embedding_dim)
288
+ embeddings = embedder.get_embeddings([text1, text2])
289
+ similarity = torch.nn.functional.cosine_similarity(embeddings[0].unsqueeze(0), embeddings[1].unsqueeze(0))
290
+ return float(similarity)
291
+
292
+ def process_reranking(query: str, documents: str, model_choice: str, embedding_dim: int) -> Dict:
293
+ embedder = create_embedder(model_choice, embedding_dim)
294
+ documents = [doc.strip() for doc in documents.split('\n') if doc.strip()]
295
+
296
+ query_embedding = embedder.get_embeddings([query], with_instruction=True)
297
+ doc_embeddings = embedder.get_embeddings(documents)
298
+
299
+ similarities = torch.nn.functional.cosine_similarity(query_embedding, doc_embeddings)
300
+
301
+ # Sort documents by similarity
302
+ sorted_indices = torch.argsort(similarities, descending=True)
303
+ results = []
304
+ for idx in sorted_indices:
305
+ results.append({
306
+ 'document': documents[idx],
307
+ 'score': float(similarities[idx])
308
+ })
309
+
310
+ return {'results': results}
311
+
312
+ def process_batch(texts: str, model_choice: str, embedding_dim: int) -> Dict:
313
+ embedder = create_embedder(model_choice, embedding_dim)
314
+ texts = [text.strip() for text in texts.split('\n') if text.strip()]
315
+
316
+ embeddings = embedder.get_embeddings(texts)
317
+ similarity_matrix = torch.nn.functional.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
318
+
319
+ df = pd.DataFrame(similarity_matrix.cpu().numpy(), index=texts, columns=texts)
320
+ return {'similarity_matrix': df.to_dict()}
321
+
322
+ def process_retrieval(prompt: str, queries: str, documents: str, model_choice: str, embedding_dim: int) -> Dict:
323
+ embedder = create_embedder(model_choice, embedding_dim)
324
+
325
+ # Process input strings
326
+ queries = [q.strip() for q in queries.split('\n') if q.strip()]
327
+ documents = [doc.strip() for doc in documents.split('\n') if doc.strip()]
328
+
329
+ # Get embeddings
330
+ prompt_embedding = embedder.get_embeddings([prompt], with_instruction=True)
331
+ query_embeddings = embedder.get_embeddings(queries, with_instruction=True)
332
+ doc_embeddings = embedder.get_embeddings(documents)
333
+
334
+ # Calculate similarities
335
+ query_similarities = torch.nn.functional.cosine_similarity(prompt_embedding, query_embeddings)
336
+ doc_similarities = torch.nn.functional.cosine_similarity(prompt_embedding.repeat(len(documents), 1), doc_embeddings)
337
+
338
+ # Process results
339
+ results = {
340
+ 'relevant_queries': [],
341
+ 'relevant_documents': []
342
+ }
343
+
344
+ # Sort queries
345
+ query_indices = torch.argsort(query_similarities, descending=True)
346
+ for idx in query_indices:
347
+ results['relevant_queries'].append({
348
+ 'query': queries[idx],
349
+ 'similarity': float(query_similarities[idx])
350
+ })
351
+
352
+ # Sort documents
353
+ doc_indices = torch.argsort(doc_similarities, descending=True)
354
+ for idx in doc_indices:
355
+ results['relevant_documents'].append({
356
+ 'document': documents[idx],
357
+ 'similarity': float(doc_similarities[idx])
358
+ })
359
+
360
+ return results
361
+
362
  # Update the CSS to improve feature visibility
363
  custom_css = """
364
  :root {
 
569
  Configure the embedding model parameters below.
570
  """)
571
 
572
+ model_choice = gr.Dropdown(
573
+ choices=list(AVAILABLE_MODELS.keys()),
574
+ value=list(AVAILABLE_MODELS.keys())[0],
575
+ label="Select Model"
576
+ )
577
  embedding_dim = gr.Slider(
578
  minimum=32,
579
  maximum=1024,
 
693
  similarity_score = gr.Number(label="Similarity Score")
694
 
695
  similarity_btn.click(
696
+ fn=lambda t1, t2, m, d: process_with_embedder('compute_similarity', t1, t2, m, d),
697
+ inputs=[text1, text2, model_choice, embedding_dim],
698
  outputs=similarity_score
699
  )
700
 
 
736
  )
737
 
738
  rerank_btn.click(
739
+ fn=lambda q, d, m, e: process_with_embedder('rerank_documents', q, d, m, e),
740
+ inputs=[query_text, documents_text, model_choice, embedding_dim],
741
  outputs=rerank_results
742
  )
743
 
 
771
  )
772
 
773
  process_btn.click(
774
+ fn=lambda t, m, e: process_with_embedder('process_batch_embeddings', t, m, e),
775
+ inputs=[batch_texts, model_choice, embedding_dim],
776
  outputs=[similarity_matrix]
777
  )
778
 
 
843
  """)
844
 
845
  retrieve_btn.click(
846
+ fn=lambda p, q, d, m, e: process_with_embedder('process_retrieval', p, q, d, m, e),
847
+ inputs=[task_prompt, queries_text, documents_text, model_choice, embedding_dim],
848
  outputs=[retrieval_matrix]
849
  )
850
 
 
899
  """)
900
 
901
  match_btn.click(
902
+ fn=lambda a, e, m, e: process_with_embedder('process_cross_lingual', a, e, m, e),
903
+ inputs=[arabic_text, english_text, model_choice, embedding_dim],
904
  outputs=[cross_lingual_score]
905
  )
906
 
 
942
  )
943
 
944
  classify_btn.click(
945
+ fn=lambda t, c, m, e: process_with_embedder('classify_text', t, c, m, e),
946
+ inputs=[input_text, categories_text, model_choice, embedding_dim],
947
  outputs=classification_results
948
  )
949
 
 
991
  )
992
 
993
  cluster_btn.click(
994
+ fn=lambda d, n, m, e: process_with_embedder('cluster_documents', d, n, m, e),
995
+ inputs=[cluster_docs, num_clusters, model_choice, embedding_dim],
996
  outputs=clustering_results
997
  )
998
 
 
1024
  sentiment_scores = gr.Json(label="Detailed Scores")
1025
 
1026
  analyze_btn.click(
1027
+ fn=lambda t, m, e: process_with_embedder('analyze_sentiment', t, m, e),
1028
+ inputs=[sentiment_text, model_choice, embedding_dim],
1029
  outputs=[sentiment_label, sentiment_scores]
1030
  )
1031
 
 
1064
  )
1065
 
1066
  extract_btn.click(
1067
+ fn=lambda t, c, m, e: process_with_embedder('extract_concepts', t, c, m, e),
1068
+ inputs=[concept_text, concept_type, model_choice, embedding_dim],
1069
  outputs=concept_results
1070
  )
1071