# app.py — uploaded by 24Sureshkumar
# Commit f20a187 (verified): "Update app.py"
# File size: 5.73 kB
import io
import os
import tempfile
import time
import urllib.request

import clip # from OpenAI CLIP repo
import openai
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
from rouge_score import rouge_scorer
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
device = "cuda" if torch.cuda.is_available() else "cpu"
openai.api_key = os.getenv("OPENAI_API_KEY")  # Set this from env

MBART_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"

# Streamlit re-executes this script from the top on every widget interaction.
# Loading the models through st.cache_resource builds them once per server
# process instead of once per rerun. show_spinner=False disables the default
# "Running ..." spinner so no UI element is rendered by these loaders.


@st.cache_resource(show_spinner=False)
def _load_translator():
    """Load the mBART-50 translation model (moved to ``device``) and its tokenizer."""
    model = MBartForConditionalGeneration.from_pretrained(MBART_CHECKPOINT).to(device)
    tokenizer = MBart50TokenizerFast.from_pretrained(MBART_CHECKPOINT)
    # The app only translates from Tamil, so the source language is fixed here.
    tokenizer.src_lang = "ta_IN"
    return model, tokenizer


@st.cache_resource(show_spinner=False)
def _load_generator():
    """Load GPT-2 (moved to ``device``, in eval mode) and its tokenizer."""
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return model, tokenizer


@st.cache_resource(show_spinner=False)
def _load_clip():
    """Load the CLIP ViT-B/32 model and its image preprocessing transform."""
    return clip.load("ViT-B/32", device=device)


# Load MBart
translator_model, translator_tokenizer = _load_translator()
# GPT-2
gen_model, gen_tokenizer = _load_generator()
# CLIP
clip_model, clip_preprocess = _load_clip()
# ---- Translation ----
def translate_tamil_to_english(text, reference=None):
    """Translate Tamil text to English with mBART-50 and optionally score it.

    Args:
        text: Tamil source text.
        reference: Optional reference English translation. ROUGE-L is only
            computed when it contains non-whitespace characters.

    Returns:
        A tuple ``(translated, duration, rouge_l)``:
        the decoded English string, the seconds spent on tokenizing,
        generating and decoding (rounded to 2 decimals), and the ROUGE-L
        F-measure rounded to 4 decimals, or ``None`` when no usable
        reference was supplied.
    """
    start = time.time()
    # truncation=True caps the input at the tokenizer's configured maximum
    # length so an overly long text is shortened instead of being passed to
    # the model as-is.
    inputs = translator_tokenizer(
        text, return_tensors="pt", truncation=True
    ).to(device)
    outputs = translator_model.generate(
        **inputs,
        forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"],
    )
    translated = translator_tokenizer.batch_decode(
        outputs, skip_special_tokens=True
    )[0]
    # The reported duration covers translation only, not ROUGE scoring.
    duration = round(time.time() - start, 2)

    rouge_l = None
    # A reference made only of whitespace is truthy but carries no text;
    # scoring against it would report a meaningless 0.0, so treat it as absent.
    reference = reference.strip() if reference else ""
    if reference:
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        score = scorer.score(reference.lower(), translated.lower())
        rouge_l = round(score["rougeL"].fmeasure, 4)
    return translated, duration, rouge_l
# ---- Creative Text ----
def generate_creative_text(prompt, max_length=100):
    """Continue ``prompt`` with sampled GPT-2 text and compute simple metrics.

    Args:
        prompt: Text to continue.
        max_length: Value forwarded to ``generate`` as ``max_length``.

    Returns:
        A tuple ``(text, duration, token_count, rep_rate, perplexity)``:
        the decoded output, the seconds spent generating and decoding
        (rounded to 2 decimals), the number of whitespace-separated tokens,
        the count of identical adjacent token pairs divided by the token
        count (0 when there are fewer than two tokens, rounded to 4
        decimals), and ``exp(loss)`` of GPT-2 on its own output (rounded to
        4 decimals).
    """
    start = time.time()
    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, which is passed explicitly below together with a pad token id so
    # generate() does not have to infer either of them.
    encoded = gen_tokenizer(prompt, return_tensors="pt").to(device)
    # NOTE(review): max_length presumably counts the prompt tokens as well, so
    # a prompt at or above max_length may leave no room for new tokens — confirm.
    output = gen_model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        do_sample=True,
        top_k=50,
        temperature=0.9,
        pad_token_id=gen_tokenizer.eos_token_id,
    )
    text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
    # Timing stops here; the metric computations below are not included.
    duration = round(time.time() - start, 2)

    tokens = text.split()
    adjacent_repeats = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:]))
    rep_rate = adjacent_repeats / len(tokens) if len(tokens) > 1 else 0

    with torch.no_grad():
        # Re-encode the generated text and use it as its own labels to obtain
        # the language-model loss.
        input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
        outputs = gen_model(input_ids, labels=input_ids)
        loss = outputs.loss
        perplexity = torch.exp(loss).item()

    return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)
# ---- Image Generation using DALL·E 3 ----
def generate_image(prompt):
    """Generate an image for ``prompt`` with DALL·E 3 and score it with CLIP.

    The generated image is downloaded, resized to 256x256, written to a
    temporary PNG file (not deleted automatically) and compared with the
    prompt via cosine similarity of CLIP image and text features.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        On success, ``(path, duration, similarity)``: the temp file path, the
        seconds spent on generation, download and saving (rounded to 2
        decimals), and the cosine similarity rounded to 4 decimals.
        On any failure, ``(None, None, message)`` where ``message`` starts
        with ``"Image generation failed: "``.
    """
    try:
        start = time.time()
        response = openai.images.generate(
            model="dall-e-3",
            prompt=prompt,
            # DALL·E 3 does not offer 512x512; 1024x1024 is its square size.
            # The image is downscaled to 256x256 below anyway.
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url

        # Read the whole payload before decoding so the image does not depend
        # on the HTTP connection staying open.
        with urllib.request.urlopen(image_url, timeout=60) as remote:
            raw_bytes = remote.read()
        image_data = Image.open(io.BytesIO(raw_bytes)).resize((256, 256))

        # Save locally. delete=False keeps the file after the handle closes so
        # the caller can open it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
            image_data.save(tmp_file, format="PNG")
        duration = round(time.time() - start, 2)

        # CLIP similarity
        image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
        # truncate=True shortens prompts that exceed CLIP's context length
        # instead of raising.
        text_input = clip.tokenize([prompt], truncate=True).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            text_features = clip_model.encode_text(text_input)
            similarity = F.cosine_similarity(image_features, text_features).item()

        return tmp_file.name, duration, round(similarity, 4)
    except Exception as e:
        # Callers expect a tuple in every case; the third slot carries the
        # error message when the first is None.
        return None, None, f"Image generation failed: {str(e)}"
# ---- UI ----
# Page chrome and input widgets.
st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")

tamil_input = st.text_area("✍️ Enter Tamil text", height=150)
reference_input = st.text_input("📘 Optional: Reference English translation for ROUGE")

# The pipeline runs only when the button is pressed: translate, then generate
# an image from the translation, then generate creative text from it.
if st.button("🚀 Generate Output"):
    if tamil_input.strip():
        # Step 1: translation (with optional ROUGE-L against the reference).
        with st.spinner("🔄 Translating..."):
            translation_result = translate_tamil_to_english(tamil_input, reference_input)
        english_text, translate_secs, rouge_value = translation_result

        st.success(f"✅ Translated in {translate_secs}s")
        st.markdown(f"**📝 English Translation:** `{english_text}`")
        if rouge_value is not None:
            st.markdown(f"📊 ROUGE-L Score: `{rouge_value}`")

        # Step 2: image generation. On failure the first element is None and
        # the third element holds the error message instead of a score.
        with st.spinner("🖼️ Generating image..."):
            image_result = generate_image(english_text)
        saved_path, image_secs, score_or_error = image_result

        if not saved_path:
            st.error(score_or_error)
        else:
            st.success(f"🖼️ Image generated in {image_secs}s using OpenAI DALL·E 3")
            st.image(Image.open(saved_path), caption="AI-Generated Image", use_column_width=True)
            st.markdown(f"🔍 **CLIP Text-Image Similarity:** `{score_or_error}`")

        # Step 3: creative continuation of the translated text.
        with st.spinner("💡 Generating creative text..."):
            creative_result = generate_creative_text(english_text)
        story, story_secs, token_count, repetition, perplexity_value = creative_result

        st.success(f"✨ Creative text in {story_secs}s")
        st.markdown(f"**🧠 Creative Output:** `{story}`")
        st.markdown(
            f"📌 Tokens: `{token_count}`, 🔁 Repetition Rate: `{repetition}`, 📉 Perplexity: `{perplexity_value}`"
        )
    else:
        st.warning("Please enter Tamil text.")

# Footer, rendered on every run.
st.markdown("---")
st.caption("Built by Sureshkumar R | MBart + GPT-2 + OpenAI DALL·E 3")