Sergidev commited on
Commit
e856ebd
·
verified ·
1 Parent(s): 6a4abaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -27
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import gradio as gr
2
- import spaces
3
  import torch
4
  from transformers import AutoTokenizer, AutoModel
5
  import plotly.graph_objects as go
6
- import numpy as np
7
 
8
  model_name = "mistralai/Mistral-7B-v0.1"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -13,44 +11,44 @@ model = None
13
  if tokenizer.pad_token is None:
14
  tokenizer.pad_token = tokenizer.eos_token
15
 
16
- @spaces.GPU
17
  def get_embedding(text):
18
  global model
19
  if model is None:
20
- model = AutoModel.from_pretrained(model_name).cuda()
21
- model.resize_token_embeddings(len(tokenizer))
22
 
23
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
- return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
27
 
28
  def reduce_to_3d(embedding):
29
  return embedding[:3]
30
 
31
- @spaces.GPU
32
  def compare_embeddings(text_input):
33
- texts = [t.strip() for t in text_input.split('\n') if t.strip()]
34
- embeddings = [get_embedding(text) for text in texts]
35
- embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
36
-
37
- fig = go.Figure()
 
38
 
39
- # Add origin point (black)
40
- fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', name='Origin',
41
- marker=dict(size=5, color='black')))
42
 
43
- # Add lines and points for each text embedding
44
- colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
45
- for i, emb in enumerate(embeddings_3d):
46
- color = colors[i % len(colors)]
47
- fig.add_trace(go.Scatter3d(x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
48
- mode='lines+markers', name=f'Text {i+1}',
49
- line=dict(color=color), marker=dict(color=color)))
50
 
51
- fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
52
-
53
- return fig
 
 
54
 
55
  iface = gr.Interface(
56
  fn=compare_embeddings,
@@ -63,4 +61,5 @@ iface = gr.Interface(
63
  allow_flagging="never"
64
  )
65
 
66
- iface.launch()
 
 
1
  import gradio as gr
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModel
4
  import plotly.graph_objects as go
 
5
 
6
  model_name = "mistralai/Mistral-7B-v0.1"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
11
  if tokenizer.pad_token is None:
12
  tokenizer.pad_token = tokenizer.eos_token
13
 
 
14
  def get_embedding(text):
15
  global model
16
  if model is None:
17
+ model = AutoModel.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
 
18
 
19
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
20
  with torch.no_grad():
21
  outputs = model(**inputs)
22
+ return outputs.last_hidden_state.mean(dim=1).squeeze().cpu()
23
 
24
  def reduce_to_3d(embedding):
25
  return embedding[:3]
26
 
 
27
  def compare_embeddings(text_input):
28
+ try:
29
+ texts = [t.strip() for t in text_input.split('\n') if t.strip()]
30
+ embeddings = [get_embedding(text) for text in texts]
31
+ embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]
32
+
33
+ fig = go.Figure()
34
 
35
+ # Add origin point (black)
36
+ fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', name='Origin',
37
+ marker=dict(size=5, color='black')))
38
 
39
+ # Add lines and points for each text embedding
40
+ colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
41
+ for i, emb in enumerate(embeddings_3d):
42
+ color = colors[i % len(colors)]
43
+ fig.add_trace(go.Scatter3d(x=[0, emb[0].item()], y=[0, emb[1].item()], z=[0, emb[2].item()],
44
+ mode='lines+markers', name=f'Text {i+1}',
45
+ line=dict(color=color), marker=dict(color=color)))
46
 
47
+ fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
48
+
49
+ return fig
50
+ except Exception as e:
51
+ return f"An error occurred: {str(e)}"
52
 
53
  iface = gr.Interface(
54
  fn=compare_embeddings,
 
61
  allow_flagging="never"
62
  )
63
 
64
+ if __name__ == "__main__":
65
+ iface.launch()