IotaCluster committed on
Commit
18c9386
·
verified ·
1 Parent(s): dead1a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -17
app.py CHANGED
@@ -1,23 +1,68 @@
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer
 
 
 
3
 
4
- # Load the multilingual embedding model
5
- model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
6
 
7
- # Define a function to embed text
8
- def embed(text: str):
9
  if not text.strip():
10
  return {"error": "Input text is empty."}
11
- embedding = model.encode([text])[0] # Get the embedding vector
12
- return {"embedding": embedding.tolist()}
13
-
14
- # Launch Gradio interface
15
- demo = gr.Interface(
16
- fn=embed,
17
- inputs=gr.Textbox(lines=3, label="Input Text"),
18
- outputs="json",
19
- title="Multilingual Text Embedder",
20
- description="Uses paraphrase-multilingual-MiniLM-L12-v2 to convert text into embeddings"
21
- )
22
-
23
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer
3
+ from rank_bm25 import BM25Okapi
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import torch
6
 
7
+ # 1. Dense embedding model (HF bi-encoder)
8
+ dense_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
9
 
10
def embed_dense(text: str) -> dict:
    """Encode *text* with the dense sentence-embedding model.

    Args:
        text: Raw input text. ``None`` and whitespace-only strings are
            treated as empty input.

    Returns:
        ``{"dense_embedding": [...]}`` holding the encoded vector as a plain
        list, or ``{"error": "Input text is empty."}`` when there is nothing
        to encode.
    """
    # `not text` also covers None, which would otherwise raise on .strip().
    if not text or not text.strip():
        return {"error": "Input text is empty."}
    # encode() is given a one-element batch; take the only row back out.
    emb = dense_model.encode([text])[0]
    return {"dense_embedding": emb.tolist()}
15
+
16
+ # 2. Sparse embedding model (BM25)
17
+ # Uses rank_bm25 to compute term weights
18
+
19
def embed_sparse(text: str) -> dict:
    """Compute a BM25 weight for every distinct whitespace token in *text*.

    The input is treated as a one-document corpus. Each distinct term is then
    scored against that document on its own, so every term receives its own
    weight.

    Args:
        text: Raw input text. ``None`` and whitespace-only strings are
            treated as empty input.

    Returns:
        ``{"sparse_embedding": {term: weight, ...}}`` with terms in order of
        first appearance, or ``{"error": "Input text is empty."}`` when there
        is nothing to score.
    """
    # `not text` also covers None, which would otherwise raise on .strip().
    if not text or not text.strip():
        return {"error": "Input text is empty."}
    tokens = text.split()
    bm25 = BM25Okapi([tokens])
    # get_scores() returns one score per *document* in the corpus, not one per
    # query term. Zipping it against the token list would therefore label the
    # whole-document score with the first token and drop every other term, so
    # each distinct term is scored as a single-term query instead.
    # NOTE(review): with a single-document corpus every term has the same
    # document frequency, so the weights differ only through term frequency.
    term_weights = {
        term: float(bm25.get_scores([term])[0])
        for term in dict.fromkeys(tokens)  # de-duplicate, keep first-seen order
    }
    return {"sparse_embedding": term_weights}
28
+
29
+ # 3. Late-interaction embedding model (ColBERT)
30
# Tokenizer and encoder are both loaded from the same checkpoint id.
# NOTE(review): AutoModel is a generic loader; confirm whether any
# ColBERT-specific projection layer in this checkpoint is needed downstream.
colbert_tokenizer = AutoTokenizer.from_pretrained('colbert-ir/colbertv2.0', use_fast=True)
colbert_model = AutoModel.from_pretrained('colbert-ir/colbertv2.0')

# The model is only used for inference, so gradient tracking is switched off
# for every parameter in one call.
colbert_model.requires_grad_(False)
36
+
37
+
38
def embed_colbert(text: str) -> dict:
    """Return one embedding vector per token of *text* from the ColBERT encoder.

    The text is tokenized (truncated to at most 64 tokens) and passed through
    the model without gradient tracking. The batch dimension is then dropped
    from the last hidden state.

    Args:
        text: Raw input text. ``None`` and whitespace-only strings are
            treated as empty input.

    Returns:
        ``{"colbert_embeddings": [[...], ...]}`` as a nested list with one
        inner list per token position, or
        ``{"error": "Input text is empty."}`` when there is nothing to encode.
    """
    # `not text` also covers None, which would otherwise raise on .strip().
    if not text or not text.strip():
        return {"error": "Input text is empty."}
    inputs = colbert_tokenizer(text, return_tensors='pt', truncation=True, max_length=64)
    with torch.no_grad():
        outputs = colbert_model(**inputs)
    # last_hidden_state is (1, seq_len, hidden_size); squeeze(0) removes the
    # batch axis so the result is (seq_len, hidden_size).
    embeddings = outputs.last_hidden_state.squeeze(0).tolist()
    return {"colbert_embeddings": embeddings}
47
+
48
+ # Build Gradio interface with tabs for each model
49
# One tab per embedding function. In every tab, submitting the textbox and
# clicking the "Embed" button both run the same function and write its result
# into that tab's JSON panel.
with gr.Blocks(title="Text Embedding Playground") as demo:
    gr.Markdown("# Text Embedding Playground\nChoose a model and input text to get embeddings.")

    with gr.Tab("Dense (MiniLM-L6-v2)"):
        txt1 = gr.Textbox(lines=3, label="Input Text")
        out1 = gr.JSON(label="Embedding")
        txt1.submit(fn=embed_dense, inputs=txt1, outputs=out1)
        btn1 = gr.Button("Embed")
        btn1.click(fn=embed_dense, inputs=txt1, outputs=out1)

    with gr.Tab("Sparse (BM25)"):
        txt2 = gr.Textbox(lines=3, label="Input Text")
        out2 = gr.JSON(label="Term Weights")
        txt2.submit(fn=embed_sparse, inputs=txt2, outputs=out2)
        btn2 = gr.Button("Embed")
        btn2.click(fn=embed_sparse, inputs=txt2, outputs=out2)

    with gr.Tab("Late-Interaction (ColBERT)"):
        txt3 = gr.Textbox(lines=3, label="Input Text")
        out3 = gr.JSON(label="Embeddings per Token")
        txt3.submit(fn=embed_colbert, inputs=txt3, outputs=out3)
        btn3 = gr.Button("Embed")
        btn3.click(fn=embed_colbert, inputs=txt3, outputs=out3)

# Start the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()