import io
import os
import tempfile
import time
import urllib.request

import clip  # from OpenAI CLIP repo
import openai
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
from rouge_score import rouge_scorer
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel

device = "cuda" if torch.cuda.is_available() else "cpu"
openai.api_key = os.getenv("OPENAI_API_KEY")  # Set this from env

MBART_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"
# DALL·E 3 only accepts 1024x1024, 1792x1024 or 1024x1792; "512x512" is rejected.
DALLE3_SIZE = "1024x1024"
# Size the downloaded image is resized to before saving / scoring.
DISPLAY_SIZE = (256, 256)
# Seconds to wait when downloading the generated image.
DOWNLOAD_TIMEOUT_S = 60


# ---- Model loading ----
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders are cached with st.cache_resource: each model is loaded once per
# process instead of on every rerun. show_spinner=False ensures they emit no
# Streamlit element before st.set_page_config is called further down.
@st.cache_resource(show_spinner=False)
def _load_translator():
    """Load the MBart-50 many-to-many model and a tokenizer configured for Tamil input."""
    model = MBartForConditionalGeneration.from_pretrained(MBART_CHECKPOINT).to(device)
    model.eval()
    tokenizer = MBart50TokenizerFast.from_pretrained(MBART_CHECKPOINT)
    tokenizer.src_lang = "ta_IN"
    return model, tokenizer


@st.cache_resource(show_spinner=False)
def _load_generator():
    """Load GPT-2 in eval mode together with its tokenizer."""
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return model, tokenizer


@st.cache_resource(show_spinner=False)
def _load_clip():
    """Load CLIP ViT-B/32 and its matching image preprocessing transform."""
    return clip.load("ViT-B/32", device=device)


# Load MBart
translator_model, translator_tokenizer = _load_translator()
# GPT-2
gen_model, gen_tokenizer = _load_generator()
# CLIP
clip_model, clip_preprocess = _load_clip()


# ---- Translation ----
def translate_tamil_to_english(text, reference=None):
    """Translate Tamil *text* to English with MBart-50.

    Args:
        text: Tamil source text.
        reference: Optional English reference translation. When it contains
            non-whitespace text, the ROUGE-L F-measure of the translation
            against it is computed (case-insensitive, stemmed).

    Returns:
        ``(translated, duration, rouge_l)``: the English translation, the
        wall-clock seconds spent translating (rounded to 2 dp), and the
        ROUGE-L F-measure rounded to 4 dp, or ``None`` without a reference.
    """
    start = time.time()
    # truncation=True clips over-long inputs to the model's maximum source
    # length instead of letting them fail inside the encoder.
    inputs = translator_tokenizer(text, return_tensors="pt", truncation=True).to(device)
    outputs = translator_model.generate(
        **inputs,
        # Force English as the first generated token; the language code is a
        # special token, so resolve its id directly.
        forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids("en_XX"),
    )
    translated = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    duration = round(time.time() - start, 2)

    rouge_l = None
    # Ignore whitespace-only references: scoring against "" is meaningless.
    if reference and reference.strip():
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        score = scorer.score(reference.lower(), translated.lower())
        rouge_l = round(score["rougeL"].fmeasure, 4)
    return translated, duration, rouge_l


# ---- Creative Text ----
def generate_creative_text(prompt, max_length=100):
    """Continue *prompt* with sampled GPT-2 text and compute simple quality stats.

    Args:
        prompt: English seed text.
        max_length: Total token budget passed to ``generate`` (prompt included).

    Returns:
        ``(text, duration, n_words, repetition_rate, perplexity)`` where
        ``n_words`` counts whitespace-separated words, ``repetition_rate`` is
        the fraction of words identical to the preceding word (4 dp), and
        ``perplexity`` is GPT-2's own perplexity on the generated text (4 dp).
    """
    start = time.time()
    # Tokenize with an attention mask so generate() does not have to guess it.
    encoded = gen_tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = gen_model.generate(
            **encoded,
            max_length=max_length,
            do_sample=True,
            top_k=50,
            temperature=0.9,
            # GPT-2 has no pad token; reuse EOS explicitly.
            pad_token_id=gen_tokenizer.eos_token_id,
        )
    text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
    duration = round(time.time() - start, 2)

    tokens = text.split()
    rep_rate = (
        sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens)
        if len(tokens) > 1
        else 0
    )

    # Perplexity of the generated text under the same model.
    with torch.no_grad():
        ppl_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
        loss = gen_model(ppl_ids, labels=ppl_ids).loss
    perplexity = torch.exp(loss).item()

    return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)


# ---- Image Generation using DALL·E 3 ----
def generate_image(prompt):
    """Generate an image for *prompt* with DALL·E 3 and score it with CLIP.

    The image is downloaded, resized to ``DISPLAY_SIZE`` and saved as a PNG
    temp file that is NOT deleted automatically (the caller displays it).

    Returns:
        ``(path, duration, clip_similarity)`` on success, where
        ``clip_similarity`` is the cosine similarity of CLIP image/text
        embeddings rounded to 4 dp. On any failure returns
        ``(None, None, error_message)`` so the UI can show the message.
    """
    try:
        start = time.time()
        response = openai.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size=DALLE3_SIZE,
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url

        # Download the whole image (bounded by a timeout) and decode it in memory.
        with urllib.request.urlopen(image_url, timeout=DOWNLOAD_TIMEOUT_S) as resp:
            image_bytes = resp.read()
        image_data = Image.open(io.BytesIO(image_bytes)).convert("RGB").resize(DISPLAY_SIZE)

        # Save locally. The with-block closes the handle (no leak); delete=False
        # keeps the file on disk for the caller.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
            image_data.save(tmp_file, format="PNG")
        duration = round(time.time() - start, 2)

        # CLIP similarity
        image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
        # truncate=True: clip.tokenize raises on text longer than CLIP's
        # 77-token context, which long translations can easily exceed.
        text_input = clip.tokenize([prompt], truncate=True).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            text_features = clip_model.encode_text(text_input)
        similarity = F.cosine_similarity(image_features, text_features).item()

        return tmp_file.name, duration, round(similarity, 4)
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return None, None, f"Image generation failed: {str(e)}"


# ---- UI ----
st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")

tamil_input = st.text_area("✍️ Enter Tamil text", height=150)
reference_input = st.text_input("📘 Optional: Reference English translation for ROUGE")

if st.button("🚀 Generate Output"):
    if not tamil_input.strip():
        st.warning("Please enter Tamil text.")
    else:
        with st.spinner("🔄 Translating..."):
            english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)
        st.success(f"✅ Translated in {t_time}s")
        st.markdown(f"**📝 English Translation:** `{english_text}`")
        if rouge_l is not None:
            st.markdown(f"📊 ROUGE-L Score: `{rouge_l}`")

        with st.spinner("🖼️ Generating image..."):
            image_path, img_time, clip_score = generate_image(english_text)
        if image_path:
            st.success(f"🖼️ Image generated in {img_time}s using OpenAI DALL·E 3")
            st.image(image_path, caption="AI-Generated Image", use_column_width=True)
            st.markdown(f"🔍 **CLIP Text-Image Similarity:** `{clip_score}`")
        else:
            # On failure generate_image returns the error message in the score slot.
            st.error(clip_score)

        with st.spinner("💡 Generating creative text..."):
            creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)
        st.success(f"✨ Creative text in {c_time}s")
        st.markdown(f"**🧠 Creative Output:** `{creative}`")
        st.markdown(f"📌 Tokens: `{tokens}`, 🔁 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")

st.markdown("---")
st.caption("Built by Sureshkumar R | MBart + GPT-2 + OpenAI DALL·E 3")