Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -12,17 +12,52 @@ from transformers import AutoTokenizer, AutoModel
|
|
12 |
# Initialize the embedder at module level
|
13 |
embedder = None
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
@spaces.GPU(duration=120)
|
16 |
def initialize_embedder(embedding_dim=768):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
global embedder
|
18 |
if embedder is None:
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# Check for GPU support and configure appropriately
|
28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -49,40 +84,6 @@ def tokenize(tokenizer, input_texts, eod_id, max_length):
|
|
49 |
batch_dict = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
|
50 |
return batch_dict
|
51 |
|
52 |
-
class QwenEmbedder:
    """Wrapper around the Qwen/Qwen3-Embedding-0.6B tokenizer and encoder.

    Produces L2-normalised text embeddings using last-token pooling, with an
    optional linear projection to a caller-chosen dimension.
    """

    MODEL_NAME = 'Qwen/Qwen3-Embedding-0.6B'

    def __init__(self, embedding_dim=768):
        """Load the tokenizer and model and set up the optional projection.

        Args:
            embedding_dim: Requested output dimension. The default (768) and
                a value equal to the encoder's hidden size both mean "no
                projection"; embeddings then keep the encoder's native width.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, padding_side='left')
        # The model is loaded with default settings (no explicit device or
        # dtype); callers decide where it ends up.
        self.model = AutoModel.from_pretrained(self.MODEL_NAME)
        self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        self.max_length = 8192
        self.embedding_dim = embedding_dim
        # Size the projection from the loaded model instead of assuming 768,
        # so it always matches the width of the pooled hidden states.
        self.projection = self._make_projection(self.model.config.hidden_size, embedding_dim)

    @staticmethod
    def _make_projection(hidden_size, embedding_dim):
        """Return the projection layer for ``embedding_dim``, or ``None``.

        ``None`` is returned when no projection is needed: either the caller
        kept the default of 768 or asked for exactly the encoder's width.
        NOTE(review): 768 is kept as a "no projection" sentinel to preserve
        the original default behaviour - confirm this is still intended.
        """
        if embedding_dim in (768, hidden_size):
            return None
        # This layer is freshly initialised; nothing in this class loads
        # trained weights for it.
        return torch.nn.Linear(hidden_size, embedding_dim)

    def get_embeddings(self, texts: List[str], with_instruction: bool = False) -> Tensor:
        """Embed ``texts`` and return one L2-normalised row per input.

        Args:
            texts: Input strings to embed.
            with_instruction: When true, each text is wrapped with a generic
                task instruction via ``get_detailed_instruct`` first.

        Returns:
            A 2-D tensor of normalised embeddings, one row per text.
        """
        if with_instruction:
            task = 'Process and understand the following text'
            texts = [get_detailed_instruct(task, text) for text in texts]

        batch_dict = tokenize(self.tokenizer, texts, self.eod_id, self.max_length)
        batch_dict = batch_dict.to(self.model.device)

        # Everything, including the projection, runs without autograd so the
        # result never carries gradient state into downstream numpy/sklearn code.
        with torch.no_grad():
            outputs = self.model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

            # Project to desired dimension if needed
            if self.projection is not None:
                embeddings = self.projection(embeddings)

            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings
|
85 |
-
|
86 |
def compute_similarity(embedder: QwenEmbedder, text1: str, text2: str) -> float:
|
87 |
embeddings = embedder.get_embeddings([text1, text2])
|
88 |
similarity = torch.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()
|
@@ -259,12 +260,6 @@ def extract_concepts(embedder: QwenEmbedder, text: str, concept_type: str) -> Li
|
|
259 |
|
260 |
return [(concept, round(score, 3)) for concept, score in results]
|
261 |
|
262 |
-
# Add a function to reinitialize embedder with new dimension
|
263 |
-
def reinitialize_embedder(dim: int) -> str:
    """Replace the module-level embedder with one built for ``dim``.

    Args:
        dim: Embedding dimension passed to ``QwenEmbedder``.

    Returns:
        A human-readable status message (not the embedder itself).
    """
    global embedder
    embedder = QwenEmbedder(embedding_dim=dim)
    return f"Embedder reinitialized with dimension: {dim}"
|
267 |
-
|
268 |
# Update the CSS to improve feature visibility
|
269 |
custom_css = """
|
270 |
:root {
|
@@ -454,16 +449,9 @@ button.secondary {
|
|
454 |
|
455 |
# Create the Gradio interface
|
456 |
def create_demo():
|
457 |
-
global embedder
|
458 |
-
# Initialize embedder with GPU support
|
459 |
-
embedder = initialize_embedder()
|
460 |
-
|
461 |
demo = gr.Blocks(title="Advanced Text Processing with Qwen", css=custom_css, theme=gr.themes.Soft())
|
462 |
|
463 |
with demo:
|
464 |
-
# Store embedder in state
|
465 |
-
state = gr.State(embedder)
|
466 |
-
|
467 |
with gr.Row():
|
468 |
# Sidebar
|
469 |
with gr.Column(scale=1, elem_classes="sidebar"):
|
@@ -601,8 +589,8 @@ def create_demo():
|
|
601 |
similarity_score = gr.Number(label="Similarity Score")
|
602 |
|
603 |
similarity_btn.click(
|
604 |
-
fn=lambda t1, t2
|
605 |
-
inputs=[text1, text2
|
606 |
outputs=similarity_score
|
607 |
)
|
608 |
|
@@ -644,8 +632,8 @@ def create_demo():
|
|
644 |
)
|
645 |
|
646 |
rerank_btn.click(
|
647 |
-
fn=lambda q, d
|
648 |
-
inputs=[query_text, documents_text
|
649 |
outputs=rerank_results
|
650 |
)
|
651 |
|
@@ -679,8 +667,8 @@ def create_demo():
|
|
679 |
)
|
680 |
|
681 |
process_btn.click(
|
682 |
-
fn=lambda t
|
683 |
-
inputs=[batch_texts
|
684 |
outputs=[similarity_matrix]
|
685 |
)
|
686 |
|
@@ -751,8 +739,8 @@ def create_demo():
|
|
751 |
""")
|
752 |
|
753 |
retrieve_btn.click(
|
754 |
-
fn=lambda p, q, d
|
755 |
-
inputs=[task_prompt, queries_text, documents_text
|
756 |
outputs=[retrieval_matrix]
|
757 |
)
|
758 |
|
@@ -807,8 +795,8 @@ def create_demo():
|
|
807 |
""")
|
808 |
|
809 |
match_btn.click(
|
810 |
-
fn=lambda a, e
|
811 |
-
inputs=[arabic_text, english_text
|
812 |
outputs=[cross_lingual_score]
|
813 |
)
|
814 |
|
@@ -850,8 +838,8 @@ def create_demo():
|
|
850 |
)
|
851 |
|
852 |
classify_btn.click(
|
853 |
-
fn=lambda t, c
|
854 |
-
inputs=[input_text, categories_text
|
855 |
outputs=classification_results
|
856 |
)
|
857 |
|
@@ -899,8 +887,8 @@ def create_demo():
|
|
899 |
)
|
900 |
|
901 |
cluster_btn.click(
|
902 |
-
fn=lambda d, n
|
903 |
-
inputs=[cluster_docs, num_clusters
|
904 |
outputs=clustering_results
|
905 |
)
|
906 |
|
@@ -932,8 +920,8 @@ def create_demo():
|
|
932 |
sentiment_scores = gr.Json(label="Detailed Scores")
|
933 |
|
934 |
analyze_btn.click(
|
935 |
-
fn=lambda t
|
936 |
-
inputs=[sentiment_text
|
937 |
outputs=[sentiment_label, sentiment_scores]
|
938 |
)
|
939 |
|
@@ -972,29 +960,30 @@ def create_demo():
|
|
972 |
)
|
973 |
|
974 |
extract_btn.click(
|
975 |
-
fn=lambda t, c
|
976 |
-
inputs=[concept_text, concept_type
|
977 |
outputs=concept_results
|
978 |
)
|
979 |
|
980 |
-
#
|
981 |
-
|
|
|
|
|
982 |
try:
|
983 |
-
|
984 |
-
|
985 |
-
return state, f"Successfully updated embedding dimension to {dim}"
|
986 |
except Exception as e:
|
987 |
-
return
|
988 |
|
989 |
update_dim_btn.click(
|
990 |
fn=update_embedder_dim,
|
991 |
-
inputs=[embedding_dim
|
992 |
-
outputs=
|
993 |
)
|
994 |
|
995 |
return demo
|
996 |
|
997 |
if __name__ == "__main__":
|
998 |
demo = create_demo()
|
999 |
-
demo.queue()
|
1000 |
demo.launch()
|
|
|
12 |
# Initialize the embedder at module level
|
13 |
embedder = None
|
14 |
|
15 |
+
class QwenEmbedder:
    """Wrapper around the Qwen/Qwen3-Embedding-0.6B tokenizer and encoder.

    Produces L2-normalised text embeddings using last-token pooling, with an
    optional linear projection to a caller-chosen dimension.
    """

    MODEL_NAME = 'Qwen/Qwen3-Embedding-0.6B'

    def __init__(self, embedding_dim=768):
        """Load the tokenizer and model and set up the optional projection.

        Args:
            embedding_dim: Requested output dimension. The default (768) and
                a value equal to the encoder's hidden size both mean "no
                projection"; embeddings then keep the encoder's native width.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME, padding_side='left')
        self.model = AutoModel.from_pretrained(self.MODEL_NAME)
        self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        self.max_length = 8192
        self.embedding_dim = embedding_dim
        # Size the projection from the loaded model instead of assuming 768,
        # so it always matches the width of the pooled hidden states.
        self.projection = self._make_projection(self.model.config.hidden_size, embedding_dim)

    @staticmethod
    def _make_projection(hidden_size, embedding_dim):
        """Return the projection layer for ``embedding_dim``, or ``None``.

        ``None`` is returned when no projection is needed: either the caller
        kept the default of 768 or asked for exactly the encoder's width.
        NOTE(review): 768 is kept as a "no projection" sentinel to preserve
        the original default behaviour - confirm this is still intended.
        """
        if embedding_dim in (768, hidden_size):
            return None
        # This layer is freshly initialised; nothing in this class loads
        # trained weights for it.
        return torch.nn.Linear(hidden_size, embedding_dim)

    def to_device(self, device):
        """Move the model (and projection, if any) to ``device``; return self."""
        self.model = self.model.to(device)
        if self.projection is not None:
            self.projection = self.projection.to(device)
        return self

    def get_embeddings(self, texts: List[str], with_instruction: bool = False) -> Tensor:
        """Embed ``texts`` and return one L2-normalised row per input.

        The module-level helpers (``compute_similarity`` and friends) call
        this method on the embedder they are given, so it must exist here.

        Args:
            texts: Input strings to embed.
            with_instruction: When true, each text is wrapped with a generic
                task instruction via ``get_detailed_instruct`` first.

        Returns:
            A 2-D tensor of normalised embeddings, one row per text.
        """
        if with_instruction:
            task = 'Process and understand the following text'
            texts = [get_detailed_instruct(task, text) for text in texts]

        batch_dict = tokenize(self.tokenizer, texts, self.eod_id, self.max_length)
        batch_dict = batch_dict.to(self.model.device)

        # Everything, including the projection, runs without autograd so the
        # result never carries gradient state into downstream numpy/sklearn code.
        with torch.no_grad():
            outputs = self.model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

            # Project to desired dimension if needed
            if self.projection is not None:
                embeddings = self.projection(embeddings)

            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings
|
29 |
+
|
30 |
@spaces.GPU(duration=120)
def initialize_embedder(embedding_dim=768):
    """Create a ``QwenEmbedder`` and place it on CUDA when available, else CPU.

    Args:
        embedding_dim: Forwarded to ``QwenEmbedder`` unchanged.

    Returns:
        The embedder after ``to_device`` has moved it to the chosen device.
    """
    # The device is resolved at call time, inside the decorated function,
    # rather than once at import time.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Initializing embedder on device: {device}")

    # Build the embedder, then hand back the instance returned by to_device.
    return QwenEmbedder(embedding_dim=embedding_dim).to_device(device)
|
39 |
+
|
40 |
+
@spaces.GPU(duration=120)
def process_with_embedder(fn_name, *args):
    """Run one of the embedder-backed operations by name.

    The module-level ``embedder`` is created on first use and then passed as
    the first argument to the selected function, followed by ``*args``.

    Args:
        fn_name: Key identifying the operation (e.g. ``'compute_similarity'``).
        *args: Positional arguments forwarded after the embedder.

    Returns:
        Whatever the selected operation returns.

    Raises:
        KeyError: If ``fn_name`` is not a known operation.
    """
    global embedder

    # Map function names to actual functions
    fn_map = {
        'compute_similarity': compute_similarity,
        'rerank_documents': rerank_documents,
        'process_batch_embeddings': process_batch_embeddings,
        'process_retrieval': process_retrieval,
        'process_cross_lingual': process_cross_lingual,
        'classify_text': classify_text,
        'cluster_documents': cluster_documents,
        'analyze_sentiment': analyze_sentiment,
        'extract_concepts': extract_concepts,
    }

    # Resolve the handler before the (expensive) model load so a bad name
    # fails fast and with a message that lists the valid choices.
    try:
        handler = fn_map[fn_name]
    except KeyError:
        known = ", ".join(sorted(fn_map))
        raise KeyError(
            f"Unknown embedder operation {fn_name!r}; expected one of: {known}"
        ) from None

    if embedder is None:
        embedder = initialize_embedder()

    return handler(embedder, *args)
|
61 |
|
62 |
# Check for GPU support and configure appropriately
|
63 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
84 |
batch_dict = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
|
85 |
return batch_dict
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
def compute_similarity(embedder: QwenEmbedder, text1: str, text2: str) -> float:
|
88 |
embeddings = embedder.get_embeddings([text1, text2])
|
89 |
similarity = torch.cosine_similarity(embeddings[0:1], embeddings[1:2]).item()
|
|
|
260 |
|
261 |
return [(concept, round(score, 3)) for concept, score in results]
|
262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
# Update the CSS to improve feature visibility
|
264 |
custom_css = """
|
265 |
:root {
|
|
|
449 |
|
450 |
# Create the Gradio interface
|
451 |
def create_demo():
|
|
|
|
|
|
|
|
|
452 |
demo = gr.Blocks(title="Advanced Text Processing with Qwen", css=custom_css, theme=gr.themes.Soft())
|
453 |
|
454 |
with demo:
|
|
|
|
|
|
|
455 |
with gr.Row():
|
456 |
# Sidebar
|
457 |
with gr.Column(scale=1, elem_classes="sidebar"):
|
|
|
589 |
similarity_score = gr.Number(label="Similarity Score")
|
590 |
|
591 |
similarity_btn.click(
|
592 |
+
fn=lambda t1, t2: process_with_embedder('compute_similarity', t1, t2),
|
593 |
+
inputs=[text1, text2],
|
594 |
outputs=similarity_score
|
595 |
)
|
596 |
|
|
|
632 |
)
|
633 |
|
634 |
rerank_btn.click(
|
635 |
+
fn=lambda q, d: process_with_embedder('rerank_documents', q, d),
|
636 |
+
inputs=[query_text, documents_text],
|
637 |
outputs=rerank_results
|
638 |
)
|
639 |
|
|
|
667 |
)
|
668 |
|
669 |
process_btn.click(
|
670 |
+
fn=lambda t: process_with_embedder('process_batch_embeddings', t),
|
671 |
+
inputs=[batch_texts],
|
672 |
outputs=[similarity_matrix]
|
673 |
)
|
674 |
|
|
|
739 |
""")
|
740 |
|
741 |
retrieve_btn.click(
|
742 |
+
fn=lambda p, q, d: process_with_embedder('process_retrieval', p, q, d),
|
743 |
+
inputs=[task_prompt, queries_text, documents_text],
|
744 |
outputs=[retrieval_matrix]
|
745 |
)
|
746 |
|
|
|
795 |
""")
|
796 |
|
797 |
match_btn.click(
|
798 |
+
fn=lambda a, e: process_with_embedder('process_cross_lingual', a, e),
|
799 |
+
inputs=[arabic_text, english_text],
|
800 |
outputs=[cross_lingual_score]
|
801 |
)
|
802 |
|
|
|
838 |
)
|
839 |
|
840 |
classify_btn.click(
|
841 |
+
fn=lambda t, c: process_with_embedder('classify_text', t, c),
|
842 |
+
inputs=[input_text, categories_text],
|
843 |
outputs=classification_results
|
844 |
)
|
845 |
|
|
|
887 |
)
|
888 |
|
889 |
cluster_btn.click(
|
890 |
+
fn=lambda d, n: process_with_embedder('cluster_documents', d, n),
|
891 |
+
inputs=[cluster_docs, num_clusters],
|
892 |
outputs=clustering_results
|
893 |
)
|
894 |
|
|
|
920 |
sentiment_scores = gr.Json(label="Detailed Scores")
|
921 |
|
922 |
analyze_btn.click(
|
923 |
+
fn=lambda t: process_with_embedder('analyze_sentiment', t),
|
924 |
+
inputs=[sentiment_text],
|
925 |
outputs=[sentiment_label, sentiment_scores]
|
926 |
)
|
927 |
|
|
|
960 |
)
|
961 |
|
962 |
extract_btn.click(
|
963 |
+
fn=lambda t, c: process_with_embedder('extract_concepts', t, c),
|
964 |
+
inputs=[concept_text, concept_type],
|
965 |
outputs=concept_results
|
966 |
)
|
967 |
|
968 |
+
# Update dimension handler
|
969 |
+
@spaces.GPU(duration=120)
def update_embedder_dim(dim):
    """Rebuild the module-level embedder with a new embedding dimension.

    Args:
        dim: Requested dimension as delivered by the UI component.

    Returns:
        A status string describing success or the error that occurred; the
        handler never raises so the message can be shown in the UI.
    """
    global embedder
    try:
        # Numeric UI widgets may deliver floats (e.g. 256.0), but the layer
        # size passed to torch.nn.Linear must be an int - normalise first.
        dim = int(dim)
        if dim <= 0:
            raise ValueError("embedding dimension must be a positive integer")
        embedder = initialize_embedder(embedding_dim=dim)
        return f"Successfully updated embedding dimension to {dim}"
    except Exception as e:
        # Deliberately broad: any failure is reported back as text.
        return f"Error updating dimension: {str(e)}"
|
977 |
|
978 |
update_dim_btn.click(
|
979 |
fn=update_embedder_dim,
|
980 |
+
inputs=[embedding_dim],
|
981 |
+
outputs=dim_status
|
982 |
)
|
983 |
|
984 |
return demo
|
985 |
|
986 |
if __name__ == "__main__":
|
987 |
demo = create_demo()
|
988 |
+
demo.queue()
|
989 |
demo.launch()
|