3dembed / app.py
Sergidev's picture
Update app.py
1c026a2 verified
raw
history blame
2.22 kB
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModel
import plotly.graph_objects as go
import numpy as np
# Handle to the transformer. It stays None until get_embedding loads it on first use.
model = None

model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The tokenizer is later called with padding=True, which requires a pad token.
# Fall back to the EOS token when the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
@spaces.GPU
def get_embedding(text):
    """Return a mean-pooled embedding of ``text`` from the model's last hidden layer.

    On first call the model is loaded onto CUDA and cached in the module-level
    ``model`` global. The last hidden state is averaged over the sequence axis.
    Each position is weighted by the tokenizer's attention mask, so padding
    tokens do not skew the average. Padding is added when ``text`` is a list of
    strings of different lengths, because the tokenizer is called with
    ``padding=True``.

    Args:
        text: A string (or list of strings) accepted by the tokenizer. Input is
            truncated to 512 tokens.

    Returns:
        The pooled embedding as a NumPy array, with singleton dimensions
        squeezed away. For one string this is a 1-D vector.
    """
    global model
    if model is None:
        model = AutoModel.from_pretrained(model_name).cuda()
        # Keep the embedding matrix sized to the tokenizer's vocabulary.
        model.resize_token_embeddings(len(tokenizer))

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to('cuda')
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        pooled = hidden.mean(dim=1)
    else:
        # Masked mean: sum only real-token positions, then divide by how many
        # there are. The clamp guards against an all-zero mask.
        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
    return pooled.squeeze().cpu().numpy()
def reduce_to_3d(embedding):
    """Return the first three components of ``embedding``.

    This is plain truncation, not a projection such as PCA. The three plotted
    axes are simply the leading three entries of the vector. If the input has
    fewer than three entries, all of them are returned.
    """
    leading_three = slice(None, 3)
    return embedding[leading_three]
@spaces.GPU
def compare_embeddings(text_input):
    """Embed each non-blank line of ``text_input`` and plot the vectors in 3D.

    Each line is embedded with ``get_embedding`` and truncated to three
    coordinates with ``reduce_to_3d``. Each result is then drawn as a segment
    from the origin. Blank or whitespace-only lines, such as a trailing newline,
    are skipped rather than embedded as empty strings.

    Args:
        text_input: Newline-separated texts. ``None`` is treated as empty.

    Returns:
        A Plotly figure containing an origin marker plus one trace per
        non-blank line.
    """
    # splitlines() also handles '\r\n', so a stray '\r' is never embedded.
    texts = [line for line in (text_input or "").splitlines() if line.strip()]
    embeddings_3d = [reduce_to_3d(get_embedding(text)) for text in texts]

    fig = go.Figure()
    # Add origin point (black)
    fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', name='Origin',
                               marker=dict(size=5, color='black')))

    # Add lines and points for each text embedding, cycling through the palette.
    colors = ['red', 'blue', 'green', 'purple', 'orange', 'cyan', 'magenta', 'yellow']
    for i, emb in enumerate(embeddings_3d):
        color = colors[i % len(colors)]
        fig.add_trace(go.Scatter3d(x=[0, emb[0]], y=[0, emb[1]], z=[0, emb[2]],
                                   mode='lines+markers', name=f'Text {i+1}',
                                   line=dict(color=color), marker=dict(color=color)))

    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
    return fig
# UI wiring: one multi-line textbox in, one Plotly figure out.
_texts_box = gr.Textbox(
    label="Input Texts",
    lines=5,
    placeholder="Enter multiple texts, each on a new line",
)

iface = gr.Interface(
    fn=compare_embeddings,
    inputs=[_texts_box],
    outputs=gr.Plot(),
    title="3D Embedding Comparison",
    description="Compare the embeddings of multiple strings visualized in 3D space using Mistral 7B.",
)
iface.launch()