# Hugging Face Spaces status: Runtime error
from functools import lru_cache

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions.

    Args:
        model_output: Encoder output whose first element holds the per-token
            embeddings (presumably shaped (batch, seq, hidden) — the mask
            broadcast below requires the first two dims to match the mask).
        attention_mask: Mask with 1 for real tokens and 0 for padding,
            matching the leading (batch, seq) dimensions of the embeddings.

    Returns:
        The masked mean over dimension 1. A row whose mask is all zeros
        yields zeros rather than NaN, because the token count is clamped to
        a tiny positive value.
    """
    hidden = model_output[0]
    # Broadcast the mask across the hidden dimension so it can weight each value.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(1)
    counts = weights.sum(1).clamp(min=1e-9)
    return summed / counts
class Matcher:
    """Match each query sentence to its most similar target sentence.

    Sentences are embedded with a Hugging Face encoder, mean-pooled over
    non-padding tokens via ``mean_pooling``, L2-normalised, and compared by
    dot product. For unit-length vectors the dot product equals cosine
    similarity.
    """

    def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
        """Load the tokenizer and encoder weights.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained``. The default is the checkpoint this class
                has always used.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def _encoder(self, text: list[str]) -> torch.Tensor:
        """Return one L2-normalised embedding row per string in *text*."""
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():  # inference only; no gradients needed
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # Unit length, so a plain matrix product below is cosine similarity.
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def __call__(self, textA: list[str], textB: list[str]) -> tuple[list[int], list[float]]:
        """Find the best match in *textB* for every sentence in *textA*.

        Args:
            textA: Query sentences.
            textB: Candidate sentences to match against.

        Returns:
            ``(match_inds, match_conf)``: for each query, the index into
            *textB* of the most similar candidate and its cosine similarity.
            Both lists are empty when *textA* is empty.

        Raises:
            ValueError: If *textB* is empty while *textA* is not, since no
                match can exist.
        """
        if not textA:
            return [], []
        if not textB:
            raise ValueError("textB must contain at least one sentence to match against")
        embeddings_a = self._encoder(textA)
        embeddings_b = self._encoder(textB)
        sim = embeddings_a @ embeddings_b.T  # rows: queries, columns: candidates
        # A single reduction gives both the best score and its column index.
        match_conf, match_inds = sim.max(dim=1)
        return match_inds.tolist(), match_conf.tolist()
def run_match(source_text, destination_text): | |
matcher = Matcher() | |
sources = source_text.split("\n") | |
destinations = destination_text.split("\n") | |
match_inds, match_conf = matcher(sources, destinations) | |
matches = [f"{sources[i]} -> {destinations[match_inds[i]]} ({match_conf[i]:.2f})" for i in | |
range(len(sources))] | |
return "\n".join(matches) | |
# UI: three side-by-side text areas (queries, targets, results) plus a button
# that runs ``run_match`` on the first two and writes into the third.
# Note: ``name=`` is not a gr.Textbox / gr.Button parameter, and recent Gradio
# versions raise on unknown keyword arguments. A Button's caption is its
# ``value`` (first positional argument), not ``label``.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_text = gr.Textbox(lines=10, label="Query Text")
        with gr.Column():
            dest_text = gr.Textbox(lines=10, label="Target Text")
        with gr.Column():
            matches = gr.Textbox(lines=10, label="Matches")
    with gr.Row():
        match_btn = gr.Button("Match")
    # Event listeners must be registered inside the Blocks context.
    match_btn.click(fn=run_match, inputs=[source_text, dest_text], outputs=matches)
demo.launch()