import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextStreamer
from huggingface_hub import login
import PyPDF2
import pandas as pd
import torch
import time
# Device setup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Set page configuration
st.set_page_config(
    page_title="Translator Agent",
    page_icon="🚀",
    layout="centered"
)
# Model name
MODEL_NAME = "Helsinki-NLP/opus-mt-en-fr"
# Title with rocket emojis
st.title("🚀 English to French Translator 🚀")
# Configure Avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# Sidebar configuration
with st.sidebar:
    st.header("Authentication 🔒")
    hf_token = st.text_input(
        "Hugging Face Token",
        type="password",
        help="Get your token from https://huggingface.co/settings/tokens"
    )

    st.header("Upload Documents 📂")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file to translate",
        type=["pdf", "xlsx"],
        label_visibility="collapsed"
    )
# Initialize chat history
if "messages" not in st.session_state:
st.session_state.messages = []
# File processing function
@st.cache_data
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""
    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            # extract_text() can return None for image-only pages
            return "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
        return ""  # unsupported file type
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
        return ""
# Model loading function
@st.cache_resource
def load_model(hf_token):
    try:
        if not hf_token:
            st.error("🔐 Authentication required! Please provide a Hugging Face token.")
            return None
        login(token=hf_token)

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)

        # Load the model in float16 on GPU and float32 on CPU
        dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_NAME,
            token=hf_token,
            torch_dtype=dtype,
            device_map="auto"  # automatically places the model on GPU or CPU
        )
        return model, tokenizer
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None
# Generation function for translation (streams tokens to stdout via TextStreamer)
def generate_translation(input_text, model, tokenizer):
    try:
        # Tokenize the input (no prompt template needed for seq2seq translation models)
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = inputs.to(DEVICE)

        # Set up the streamer; TextStreamer prints tokens to stdout as they are generated
        streamer = TextStreamer(tokenizer, skip_special_tokens=True)

        # Generate the translation; transformers does not allow a streamer together
        # with beam search, so greedy decoding (num_beams=1) is used here
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=512,
                num_beams=1,  # beam search (num_beams > 1) raises an error when a streamer is set
                streamer=streamer,
                return_dict_in_generate=True,
                output_scores=True
            )

        # Decode the full output for storage and metrics
        translation = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        return translation, streamer
    except Exception as e:
        raise Exception(f"Generation error: {str(e)}") from e
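
# The TextStreamer above writes tokens to stdout, not to the Streamlit page.
# A minimal sketch, assuming transformers' TextIteratorStreamer, of how live
# streaming into the UI could work instead (generate_translation_live and its
# placeholder argument, an st.empty() container, are hypothetical and are not
# called by the app as written):
from threading import Thread
from transformers import TextIteratorStreamer

def generate_translation_live(input_text, model, tokenizer, placeholder):
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    # generate() blocks, so it runs in a background thread while the main
    # thread consumes tokens from the streamer as they are produced
    thread = Thread(target=model.generate, kwargs=dict(**inputs, max_length=512, streamer=streamer))
    thread.start()
    text = ""
    for new_text in streamer:
        text += new_text
        placeholder.markdown(text)  # repaint the placeholder with each new chunk of text
    thread.join()
    return text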
# Display chat messages
for message in st.session_state.messages:
    try:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:  # fall back to the default avatar if the image fails to load
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
# Chat input handling
if prompt := st.chat_input("Enter text to translate into French..."):
    if not hf_token:
        st.error("🔑 Authentication required!")
        st.stop()

    # Load model if not already loaded
    if "model" not in st.session_state:
        model_data = load_model(hf_token)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()
        st.session_state.model, st.session_state.tokenizer = model_data

    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    # Add the user message to the chat
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Translate the uploaded file if one was provided, otherwise the typed prompt
    file_context = process_file(uploaded_file)
    input_text = file_context if file_context else prompt
    # Generate the translation
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()

                # Placeholder for the rendered translation (the TextStreamer echoes
                # tokens to stdout; the finished text is shown in the UI here)
                response_container = st.empty()
                full_response, _ = generate_translation(input_text, model, tokenizer)
                response_container.markdown(full_response)

                # Calculate performance metrics
                end_time = time.time()
                input_tokens = len(tokenizer(input_text)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / max(end_time - start_time, 1e-6)  # guard against division by zero

                # Calculate costs (hypothetical pricing model, for illustration only)
                input_cost = (input_tokens / 1_000_000) * 5    # $5 per million input tokens
                output_cost = (output_tokens / 1_000_000) * 15  # $15 per million output tokens
                total_cost_usd = input_cost + output_cost
                total_cost_aoa = total_cost_usd * 1160  # approximate USD-to-AOA (Angolan Kwanza) rate

                # Display metrics
                st.caption(
                    f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
                    f"💵 Cost (AOA): {total_cost_aoa:.4f}"
                )

            # Store the full response in chat history
            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            st.error(f"⚡ Translation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")