Shad0ws ijktech-jk committed
Commit 194d0c8 · 0 Parent(s)

Duplicate from ijktech/matcher

Co-authored-by: James Kelly <[email protected]>

Files changed (4)
  1. .gitattributes +34 -0
  2. README.md +15 -0
  3. app.py +61 -0
  4. requirements.txt +2 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ title: Matcher
+ emoji: 🏢
+ colorFrom: green
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 3.24.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ python_version: 3.10.8
+ duplicated_from: ijktech/matcher
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,61 @@
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ import torch.nn.functional as F
+
+
+ # Mean Pooling - Take attention mask into account for correct averaging
+ def mean_pooling(model_output, attention_mask):
+     token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+ class Matcher:
+
+     def __init__(self):
+         self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+         self.model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+
+     def _encoder(self, text: list[str]):
+         encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
+         with torch.no_grad():
+             model_output = self.model(**encoded_input)
+         sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+         sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+         return sentence_embeddings
+
+     def __call__(self, textA: list[str], textB: list[str]):
+         embeddings_a = self._encoder(textA)
+         embeddings_b = self._encoder(textB)
+         sim = embeddings_a @ embeddings_b.T  # cosine-similarity matrix (embeddings are L2-normalised)
+         match_inds = torch.argmax(sim, dim=1)  # index of the best target for each query
+         match_conf = torch.max(sim, dim=1).values  # similarity score of that best match
+         return match_inds.tolist(), match_conf.tolist()
+
+
+ def run_match(source_text, destination_text):
+     matcher = Matcher()
+     sources = source_text.split("\n")  # one item per line in each textbox
+     destinations = destination_text.split("\n")
+     match_inds, match_conf = matcher(sources, destinations)
+     matches = [f"{sources[i]} -> {destinations[match_inds[i]]} ({match_conf[i]:.2f})"
+                for i in range(len(sources))]
+     return "\n".join(matches)
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             source_text = gr.Textbox(lines=10, label="Query Text",
+                                      value="diavola with extra chillies\nseafood\nmargherita")
+         with gr.Column():
+             dest_text = gr.Textbox(lines=10, label="Target Text",
+                                    value="cheese pizza\nhot and spicy pizza\ntuna, prawn and onion pizza")
+         with gr.Column():
+             matches = gr.Textbox(lines=10, label="Matches")
+     with gr.Row():
+         match_btn = gr.Button("Match")
+     match_btn.click(fn=run_match, inputs=[source_text, dest_text], outputs=matches)
+
+ demo.launch()
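Because the embeddings returned by Matcher._encoder are L2-normalised, the matrix product in __call__ is a cosine-similarity matrix, and the argmax over dim 1 picks the closest target for each query. Below is a minimal, self-contained sketch of just that matching step, using random stand-in vectors in place of real sentence embeddings (384 is the embedding width of all-MiniLM-L6-v2); it is an illustration, not part of the commit.

import torch
import torch.nn.functional as F

# Stand-ins for the outputs of Matcher._encoder: 3 query and 4 target
# embeddings, L2-normalised so the dot product equals cosine similarity.
embeddings_a = F.normalize(torch.randn(3, 384), p=2, dim=1)
embeddings_b = F.normalize(torch.randn(4, 384), p=2, dim=1)

sim = embeddings_a @ embeddings_b.T        # (3, 4) cosine-similarity matrix
match_inds = torch.argmax(sim, dim=1)      # best target index per query
match_conf = torch.max(sim, dim=1).values  # similarity of that best match
print(match_inds.tolist(), match_conf.tolist())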
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==2.0.0
+ transformers==4.25.1