import streamlit as st
import torch
import os
import time
import tempfile
from PIL import Image
import torch.nn.functional as F
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    AutoTokenizer,
    AutoModelForCausalLM,
    CLIPProcessor,
    CLIPModel,
)
from diffusers import StableDiffusionPipeline
from rouge_score import rouge_scorer
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load MBart tokenizer and model
translator_model = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
).to(device)
translator_tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
translator_tokenizer.src_lang = "ta_IN"
# Load GPT-2
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
gen_model.eval()
# Load Stable Diffusion
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    token=os.getenv("HF_TOKEN"),
    torch_dtype=torch.float32,
).to(device)
pipe.safety_checker = None  # disable the NSFW safety checker
# Load CLIP
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
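# NOTE: Streamlit re-runs this entire script on every widget interaction, so the
# model loads above are repeated on each rerun. A minimal sketch of the usual
# fix, assuming a Streamlit release with st.cache_resource (1.18+); the helper
# name load_clip_cached is illustrative and is not called by the app below:
@st.cache_resource
def load_clip_cached():
    # Loaded once per process, then reused across reruns
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor
# Usage (in place of the two direct loads above):
# clip_model, clip_processor = load_clip_cached()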
# ---------------- Functions ---------------- #
def translate_tamil_to_english(text, reference=None):
    start = time.time()
    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
    # Force English as the target language of the many-to-many model
    outputs = translator_model.generate(
        **inputs,
        forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"],
    )
    translated = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    duration = round(time.time() - start, 2)
    # ROUGE-L against the optional reference translation
    rouge_l = None
    if reference:
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        score = scorer.score(reference.lower(), translated.lower())
        rouge_l = round(score["rougeL"].fmeasure, 4)
    return translated, duration, rouge_l
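# Illustrative usage (the Tamil input and its translation are examples only;
# actual output depends on the checkpoint):
# english, seconds, rouge = translate_tamil_to_english("வணக்கம் உலகம்")
# rouge is None here because no reference translation was passed.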
def generate_creative_text(prompt, max_length=100):
    start = time.time()
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    output = gen_model.generate(
        input_ids,
        max_length=max_length,
        do_sample=True,
        top_k=50,
        temperature=0.9,
        pad_token_id=gen_tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
    duration = round(time.time() - start, 2)
    tokens = text.split()
    # Fraction of immediately repeated words (guard against empty output)
    repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / max(len(tokens), 1)
    # Perplexity = exp(mean token-level cross-entropy) of GPT-2 on its own output
    with torch.no_grad():
        input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
        outputs = gen_model(input_ids, labels=input_ids)
        loss = outputs.loss
        perplexity = torch.exp(loss).item()
    return text, duration, len(tokens), round(repetition_rate, 4), round(perplexity, 4)
def generate_image(prompt):
    try:
        start = time.time()
        result = pipe(prompt)
        image = result.images[0].resize((256, 256))
        # Persist to a temp file so Streamlit can re-read it for display
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        image.save(tmp_file.name)
        tmp_file.close()
        duration = round(time.time() - start, 2)
        return tmp_file.name, duration, image
    except Exception as e:
        # On failure, the third element carries the error string instead of an image
        return None, 0, f"Image generation failed: {str(e)}"
def evaluate_clip_similarity(text, image):
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # CLIP's image-text logit: cosine similarity scaled by the learned logit_scale
    logits_per_image = outputs.logits_per_image
    similarity_score = logits_per_image[0][0].item()
    return round(similarity_score, 4)
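# The score above is CLIP's scaled logit: cosine similarity multiplied by the
# model's learned logit_scale (roughly 100 for this checkpoint), so values
# around 20-35 are common for well-matched pairs. A minimal sketch of
# recovering the raw cosine similarity in [-1, 1] instead; the helper name
# clip_cosine_similarity is illustrative:
def clip_cosine_similarity(text, image):
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = clip_model(**inputs)
        # Undo the logit scaling applied inside the model's forward pass
        cosine = outputs.logits_per_image[0][0] / clip_model.logit_scale.exp()
    return round(cosine.item(), 4)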
# ---------------- Streamlit UI ---------------- #
st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
st.title("🧠 Tamil → English + 🎨 Creative Text + AI Image")
tamil_input = st.text_area("✍️ Enter Tamil text here", height=150)
reference_input = st.text_input("πŸ“˜ Optional: Reference English translation for ROUGE")
if st.button("🚀 Generate Output"):
    if not tamil_input.strip():
        st.warning("Please enter Tamil text.")
    else:
        with st.spinner("🔄 Translating Tamil to English..."):
            english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)
        st.success(f"✅ Translated in {t_time} seconds")
        st.markdown(f"**📝 English Translation:** `{english_text}`")
        if rouge_l is not None:
            st.markdown(f"📊 **ROUGE-L Score:** `{rouge_l}`")
        else:
            st.info("ℹ️ ROUGE-L not calculated. Reference not provided.")

        with st.spinner("🎨 Generating image..."):
            image_path, img_time, image_obj = generate_image(english_text)
        if isinstance(image_obj, Image.Image):
            st.success(f"🖼️ Image generated in {img_time} seconds")
            st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
            with st.spinner("🔎 Evaluating CLIP similarity..."):
                clip_score = evaluate_clip_similarity(english_text, image_obj)
            st.markdown(f"🔍 **CLIP Text-Image Similarity:** `{clip_score}`")
        else:
            st.error(image_obj)

        with st.spinner("💡 Generating creative text..."):
            creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)
        st.success(f"✨ Creative text generated in {c_time} seconds")
        st.markdown(f"**🧠 Creative Output:** `{creative}`")
        st.markdown(f"📌 Tokens: `{tokens}`, 🔁 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")
st.markdown("---")
st.caption("Built by Sureshkumar R using MBart, GPT-2, Stable Diffusion 2.1, and CLIP on Hugging Face πŸ€—")
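# To run locally (assuming the usual dependencies are installed: streamlit,
# torch, transformers, diffusers, rouge-score, pillow):
#   streamlit run app.py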