import gradio as gr
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
import torch

# 1. Dense embedding model (HF bi-encoder)
# dense_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# dense_model = SentenceTransformer('distiluse-base-multilingual-cased-v2')
dense_model = SentenceTransformer('multi-qa-mpnet-base-cos-v1')

def embed_dense(text: str):
    if not text.strip():
        return {"error": "Input text is empty."}
    emb = dense_model.encode([text])[0]
    return {"dense_embedding": emb.tolist()}

# 2. Sparse embedding model (BM25)
# Uses rank_bm25 to compute a weight for each term of the input, treating the
# input itself as a one-document corpus.
def embed_sparse(text: str):
    if not text.strip():
        return {"error": "Input text is empty."}
    tokens = text.split()
    bm25 = BM25Okapi([tokens])
    unique_terms = sorted(set(tokens))
    # get_scores() returns one score per corpus document, so score each unique
    # term as its own single-term query to obtain a per-term weight.
    term_weights = {term: float(bm25.get_scores([term])[0]) for term in unique_terms}
    indices = list(range(len(unique_terms)))
    values = [term_weights[term] for term in unique_terms]
    return {"indices": indices, "values": values, "terms": unique_terms}

# 3. Late-interaction embedding model (ColBERT)
colbert_tokenizer = AutoTokenizer.from_pretrained('colbert-ir/colbertv2.0', use_fast=True)
colbert_model = AutoModel.from_pretrained('colbert-ir/colbertv2.0')
colbert_model.eval()
# Inference only: gradients are never needed, so freeze all parameters
for param in colbert_model.parameters():
    param.requires_grad = False

def embed_colbert(text: str):
    if not text.strip():
        return {"error": "Input text is empty."}
    inputs = colbert_tokenizer(text, return_tensors='pt', truncation=True, max_length=64)
    with torch.no_grad():
        outputs = colbert_model(**inputs)
    # last_hidden_state: (1, seq_len, hidden_size) -- one vector per token.
    # Note: AutoModel loads only the BERT encoder, so these are raw per-token
    # hidden states (any ColBERT projection head in the checkpoint is not applied).
    embeddings = outputs.last_hidden_state.squeeze(0).tolist()
    return {"colbert_embeddings": embeddings}

# Build Gradio interface with tabs for each model
with gr.Blocks(title="Text Embedding Playground") as demo:
    gr.Markdown("# Text Embedding Playground\nChoose a model and input text to get embeddings.")
    with gr.Tab("Dense (multi-qa-mpnet-base-cos-v1)"):
        txt1 = gr.Textbox(lines=3, label="Input Text")
        out1 = gr.JSON(label="Embedding")
        txt1.submit(embed_dense, txt1, out1)
        gr.Button("Embed").click(embed_dense, txt1, out1)
    with gr.Tab("Sparse (BM25)"):
        txt2 = gr.Textbox(lines=3, label="Input Text")
        out2 = gr.JSON(label="Term Weights")
        txt2.submit(embed_sparse, txt2, out2)
        gr.Button("Embed").click(embed_sparse, txt2, out2)
    with gr.Tab("Late-Interaction (ColBERT)"):
        txt3 = gr.Textbox(lines=3, label="Input Text")
        out3 = gr.JSON(label="Embeddings per Token")
        txt3.submit(embed_colbert, txt3, out3)
        gr.Button("Embed").click(embed_colbert, txt3, out3)

if __name__ == "__main__":
    demo.launch(mcp_server=True)
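
# Launching with mcp_server=True also exposes the app's API endpoints as MCP
# tools. A plain HTTP client sketch (assumes the app is running locally on
# Gradio's default port; api_name values are auto-generated, so check
# client.view_api() for the exact names):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   print(client.view_api())
#   print(client.predict("hybrid search with dense and sparse vectors",
#                        api_name="/embed_dense"))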