# 3dembed / app.py
# Last commit: e856ebd "Update app.py" by Sergidev (2.41 kB).
import logging

import gradio as gr
import plotly.graph_objects as go
import torch
from transformers import AutoTokenizer, AutoModel
# Hugging Face model id used for both the tokenizer and the model.
model_name = "mistralai/Mistral-7B-v0.1"

# The tokenizer is loaded eagerly at import time. The model itself starts out
# as None and is only loaded by get_embedding() on its first call.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None

# get_embedding() tokenizes with padding=True, which needs a pad token; when
# the tokenizer does not define one, fall back to its EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def get_embedding(text):
    """Return the mean-pooled last-layer hidden state for ``text``.

    The model is loaded lazily (float16 weights, ``device_map="auto"``) on the
    first call and cached in the module-level ``model`` global. Inputs longer
    than 512 tokens are truncated.

    Args:
        text: A string, or a list of strings to embed together as one batch.

    Returns:
        A float32 CPU tensor of pooled hidden states, passed through
        ``squeeze()``. For a single string this is 1-D (one value per hidden
        unit); for several strings it is 2-D (one row per string).
    """
    global model
    if model is None:
        model = AutoModel.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.float16
        )

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Accumulate in float32 so the sum over many tokens is not done in fp16.
    hidden = outputs.last_hidden_state.float()

    # Average only over real tokens. When several texts of different lengths
    # are batched, padding positions (attention_mask == 0) would otherwise be
    # included in the mean. With device_map="auto" the last hidden state may
    # live on a different device than the inputs, so the mask is moved to it.
    mask = inputs.get("attention_mask")
    if mask is None:
        mask = torch.ones(hidden.shape[:2])
    mask = mask.unsqueeze(-1).to(device=hidden.device, dtype=hidden.dtype)

    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # guard against an all-masked row
    return (summed / counts).squeeze().cpu()
def reduce_to_3d(embedding):
    """Return the first three components of ``embedding``.

    This is a plain truncation, not a projection such as PCA: every component
    after the third is discarded, and the result has the same sequence/tensor
    type as the input.
    """
    first_three = embedding[0:3]
    return first_three
def compare_embeddings(text_input):
    """Embed each non-empty line of ``text_input`` and plot it as a 3D vector.

    Each line is embedded with get_embedding() and cut down to three
    components by reduce_to_3d(). It is then drawn as a segment from the
    origin, cycling through a fixed eight-colour palette.

    Args:
        text_input: Multi-line string from the Gradio textbox, one text per
            line. Blank lines and surrounding whitespace are ignored; ``None``
            is treated as empty.

    Returns:
        A plotly ``go.Figure`` for the ``gr.Plot`` output.

    Raises:
        gr.Error: if no non-empty line was given, or if embedding or plotting
            fails. The message is shown in the UI. A plain string cannot be
            rendered by the Plot output, so errors are not returned as one.
    """
    texts = [t.strip() for t in (text_input or "").splitlines() if t.strip()]
    if not texts:
        raise gr.Error("Please enter at least one non-empty line of text.")

    try:
        embeddings = [get_embedding(text) for text in texts]
        embeddings_3d = [reduce_to_3d(emb) for emb in embeddings]

        fig = go.Figure()
        # Origin marker (black) that every text vector starts from.
        fig.add_trace(go.Scatter3d(
            x=[0], y=[0], z=[0], mode='markers', name='Origin',
            marker=dict(size=5, color='black'),
        ))

        # One origin-to-point segment per text; colours repeat after eight.
        colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
        for i, emb in enumerate(embeddings_3d):
            color = colors[i % len(colors)]
            fig.add_trace(go.Scatter3d(
                x=[0, emb[0].item()],
                y=[0, emb[1].item()],
                z=[0, emb[2].item()],
                mode='lines+markers',
                name=f'Text {i+1}',
                line=dict(color=color),
                marker=dict(color=color),
            ))

        fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
        return fig
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show a
        # readable message in the browser.
        logging.getLogger(__name__).exception("Failed to compare embeddings")
        raise gr.Error(f"An error occurred: {str(e)}") from e
# Single multi-line textbox; compare_embeddings() embeds each non-empty line
# and returns a figure for the Plot output.
_text_box = gr.Textbox(
    label="Input Texts",
    lines=5,
    placeholder="Enter multiple texts, each on a new line",
)

iface = gr.Interface(
    fn=compare_embeddings,
    inputs=[_text_box],
    outputs=gr.Plot(),
    title="3D Embedding Comparison",
    description="Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.",
    # NOTE(review): newer Gradio releases rename this option to
    # flagging_mode — confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    iface.launch()