Update app.py
app.py
CHANGED
@@ -1,9 +1,7 @@
 import gradio as gr
-import spaces
 import torch
 from transformers import AutoTokenizer, AutoModel
 import plotly.graph_objects as go
-import numpy as np

 model_name = "mistralai/Mistral-7B-v0.1"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -13,44 +11,44 @@ model = None
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token

-@spaces.GPU
 def get_embedding(text):
     global model
     if model is None:
-        model = AutoModel.from_pretrained(model_name
-        model.resize_token_embeddings(len(tokenizer))
+        model = AutoModel.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
     with torch.no_grad():
         outputs = model(**inputs)
     return outputs.last_hidden_state.mean(dim=1).squeeze().cpu()

 def reduce_to_3d(embedding):
     return embedding[:3]

-@spaces.GPU
 def compare_embeddings(text_input):
-    … (original body of compare_embeddings not preserved in this view)
+    try:
+        texts = [t.strip() for t in text_input.split('\n') if t.strip()]
+        embeddings = [get_embedding(text) for text in texts]
+        embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
+
+        fig = go.Figure()
+
+        # Add origin point (black)
+        fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', name='Origin',
+                                   marker=dict(size=5, color='black')))
+
+        # Add lines and points for each text embedding
+        colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
+        for i, emb in enumerate(embeddings_3d):
+            color = colors[i % len(colors)]
+            fig.add_trace(go.Scatter3d(x=[0, emb[0].item()], y=[0, emb[1].item()], z=[0, emb[2].item()],
+                                       mode='lines+markers', name=f'Text {i+1}',
+                                       line=dict(color=color), marker=dict(color=color)))
+
+        fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
+
+        return fig
+    except Exception as e:
+        return f"An error occurred: {str(e)}"

 iface = gr.Interface(
     fn=compare_embeddings,
@@ -63,4 +61,5 @@ iface = gr.Interface(
     allow_flagging="never"
 )

-iface.launch()
+if __name__ == "__main__":
+    iface.launch()
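
Taken together, the commit drops the ZeroGPU decorators (the "spaces" import and @spaces.GPU) in favor of lazy-loading the model with device_map="auto" in float16, and routes inputs to model.device instead of a hard-coded device. A minimal local smoke test of the updated flow, as a sketch: it assumes a machine with enough GPU memory for the 7B model and the accelerate package installed, which device_map="auto" requires.

    # Sketch: exercising the updated functions outside Gradio. Assumes the
    # ~14 GB Mistral-7B weights can be downloaded and fit on the available
    # GPU(s); first call triggers the lazy model load in get_embedding.
    emb = get_embedding("The quick brown fox")   # mean-pooled hidden state, on CPU
    print(emb.shape)                             # torch.Size([4096]) for Mistral-7B
    fig = compare_embeddings("first sentence\nsecond sentence")
    fig.show()                                   # interactive 3D Plotly view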
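
One thing the diff leaves untouched: reduce_to_3d keeps only the first three of the 4096 hidden dimensions, which is arbitrary slicing rather than real dimensionality reduction. A PCA-based variant (hypothetical, not part of this commit) could project all embeddings jointly instead:

    import torch

    # Hypothetical alternative to reduce_to_3d (not in this commit): project
    # every embedding onto the three directions of maximal variance via PCA.
    def reduce_all_to_3d_pca(embeddings):
        X = torch.stack([e.float() for e in embeddings])  # (n_texts, 4096), fp32 for PCA
        U, S, V = torch.pca_lowrank(X, q=3)               # centers X internally; needs >= 3 texts
        return (X - X.mean(dim=0)) @ V[:, :3]             # (n_texts, 3) coordinates

This preserves the relative geometry between texts better than truncation, at the cost of coordinates that shift whenever the set of input texts changes.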