|
import logging
import os
import tempfile
import time

import clip
import streamlit as st
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from PIL import Image
from rouge_score import rouge_scorer
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
|
|
# Run every model on the GPU when one is visible, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

_logger = logging.getLogger(__name__)

TRANSLATOR_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"
GENERATOR_CHECKPOINT = "gpt2"
PRIMARY_SD_CHECKPOINT = "stabilityai/stable-diffusion-2-1"
FALLBACK_SD_CHECKPOINT = "OFA-Sys/small-stable-diffusion-v0"
CLIP_CHECKPOINT = "ViT-B/32"


def _cache_resource(func):
    """Memoize a zero-argument model loader across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    uncached loaders would reload every model on each button click.
    ``st.cache_resource`` (Streamlit >= 1.18) keeps one shared instance per
    server process; on older Streamlit the loader is returned uncached, which
    matches the previous behaviour. ``show_spinner=False`` keeps the loaders
    from emitting any Streamlit element before ``st.set_page_config`` runs.
    """
    cache = getattr(st, "cache_resource", None)
    if cache is None:
        return func
    return cache(show_spinner=False)(func)


@_cache_resource
def _load_translator():
    """Load mBART-50 many-to-many and its tokenizer configured for Tamil input."""
    model = MBartForConditionalGeneration.from_pretrained(TRANSLATOR_CHECKPOINT).to(device)
    tokenizer = MBart50TokenizerFast.from_pretrained(TRANSLATOR_CHECKPOINT)
    tokenizer.src_lang = "ta_IN"
    return model, tokenizer


@_cache_resource
def _load_generator():
    """Load GPT-2 in eval mode together with its tokenizer."""
    model = GPT2LMHeadModel.from_pretrained(GENERATOR_CHECKPOINT).to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(GENERATOR_CHECKPOINT)
    return model, tokenizer


@_cache_resource
def _load_diffusion():
    """Load Stable Diffusion 2.1, falling back to a small SD model on any error.

    Returns:
        ``(pipeline, checkpoint_name)`` so the UI can report which checkpoint
        actually produced the image. The safety checker is disabled on both.
    """
    try:
        pipeline = StableDiffusionPipeline.from_pretrained(
            PRIMARY_SD_CHECKPOINT,
            torch_dtype=torch.float32,
            use_auth_token=os.getenv("HF_TOKEN"),
        ).to(device)
        checkpoint = PRIMARY_SD_CHECKPOINT
    except Exception:
        # Logged rather than shown with st.warning: a Streamlit element emitted
        # here would precede st.set_page_config, which older Streamlit rejects.
        # The UI still reports the checkpoint actually used.
        _logger.warning(
            "Could not load %s; falling back to %s.",
            PRIMARY_SD_CHECKPOINT,
            FALLBACK_SD_CHECKPOINT,
            exc_info=True,
        )
        pipeline = StableDiffusionPipeline.from_pretrained(
            FALLBACK_SD_CHECKPOINT,
            torch_dtype=torch.float32,
        ).to(device)
        checkpoint = FALLBACK_SD_CHECKPOINT
    pipeline.safety_checker = None
    return pipeline, checkpoint


@_cache_resource
def _load_clip():
    """Load the CLIP model and its matching image preprocessing transform."""
    return clip.load(CLIP_CHECKPOINT, device=device)


# Module-level handles used by the helper functions and the UI below.
translator_model, translator_tokenizer = _load_translator()
gen_model, gen_tokenizer = _load_generator()
pipe, model_loaded = _load_diffusion()
clip_model, clip_preprocess = _load_clip()
|
|
|
|
|
def translate_tamil_to_english(text, reference=None):
    """Translate Tamil *text* to English with mBART-50 and optionally score it.

    Args:
        text: Tamil source text.
        reference: Optional English reference translation. When it contains
            any non-whitespace text, the output is scored against it with
            ROUGE-L (stemmed, case-insensitive).

    Returns:
        ``(translated, duration, rouge_l)``. *duration* is the translation
        wall-clock time in seconds, rounded to 2 dp. *rouge_l* is the ROUGE-L
        F-measure rounded to 4 dp, or ``None`` when no usable reference was
        given.
    """
    start = time.perf_counter()
    # Truncate to the tokenizer's model_max_length so over-long input cannot
    # overflow mBART's position embeddings.
    inputs = translator_tokenizer(text, return_tensors="pt", truncation=True).to(device)
    # Language codes are vocabulary tokens. Looking the code up directly works
    # whether or not the installed tokenizer still exposes `lang_code_to_id`.
    english_bos_id = translator_tokenizer.convert_tokens_to_ids("en_XX")
    outputs = translator_model.generate(**inputs, forced_bos_token_id=english_bos_id)
    translated = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    duration = round(time.perf_counter() - start, 2)

    rouge_l = None
    # A whitespace-only reference would only yield a meaningless 0.0 score.
    if reference and reference.strip():
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        score = scorer.score(reference.lower(), translated.lower())
        rouge_l = round(score["rougeL"].fmeasure, 4)

    return translated, duration, rouge_l
|
|
|
|
|
def generate_creative_text(prompt, max_length=100, *, new_token_floor=20):
    """Continue *prompt* with sampled GPT-2 text and report simple metrics.

    Args:
        prompt: Seed text for GPT-2 (the app passes the English translation).
            An empty prompt is replaced by GPT-2's BOS token, because sampling
            needs at least one context token.
        max_length: Total GPT-2 token budget, prompt included.
        new_token_floor: Minimum continuation budget. A prompt that already
            uses up *max_length* still gets this many new tokens to sample.
            Sampling may stop earlier at end-of-text.

    Returns:
        ``(text, duration, n_words, repetition_rate, perplexity)``:
        - *text*: the decoded prompt plus continuation.
        - *duration*: sampling time in seconds (2 dp).
        - *n_words*: the whitespace-separated word count of *text*.
        - *repetition_rate*: immediately repeated words divided by *n_words*
          (4 dp).
        - *perplexity*: GPT-2's perplexity on *text* (4 dp), or NaN when
          *text* has fewer than two GPT-2 tokens.
    """
    start = time.perf_counter()
    encoded = gen_tokenizer(prompt or gen_tokenizer.bos_token, return_tensors="pt").to(device)
    prompt_len = encoded["input_ids"].shape[-1]
    output = gen_model.generate(
        **encoded,  # input_ids plus attention_mask
        # max_length counts prompt tokens, so a long translation would leave no
        # room to continue; convert it to a new-token budget with a floor.
        max_new_tokens=max(max_length - prompt_len, new_token_floor),
        do_sample=True,
        top_k=50,
        temperature=0.9,
        # GPT-2 has no pad token; pass EOS explicitly instead of letting
        # generate() substitute it and warn on every call.
        pad_token_id=gen_tokenizer.eos_token_id,
    )
    text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
    duration = round(time.perf_counter() - start, 2)

    words = text.split()
    if len(words) > 1:
        repeats = sum(current == following for current, following in zip(words, words[1:]))
        rep_rate = repeats / len(words)
    else:
        rep_rate = 0

    text_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
    if text_ids.shape[-1] < 2:
        # The LM loss needs at least one next-token target; perplexity is undefined.
        perplexity = float("nan")
    else:
        with torch.no_grad():
            loss = gen_model(text_ids, labels=text_ids).loss
        perplexity = torch.exp(loss).item()

    return text, duration, len(words), round(rep_rate, 4), round(perplexity, 4)
|
|
|
|
|
def generate_image(prompt):
    """Render *prompt* with the diffusion pipeline and score it with CLIP.

    The image is resized to 256x256 and written to a new ``.png`` temporary
    file. The caller owns that file and should delete it once displayed.

    Returns:
        On success, ``(path, duration, similarity)``. *path* is the PNG file,
        *duration* is the generation-plus-save time in seconds (2 dp), and
        *similarity* is the CLIP image/text cosine similarity (4 dp).
        On any failure, ``(None, None, message)`` with a human-readable error.
    """
    path = None
    try:
        start = time.perf_counter()
        image = pipe(prompt).images[0].resize((256, 256))
        # Close the handle before PIL writes to the path: an open
        # NamedTemporaryFile leaks a descriptor and cannot be reopened on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            path = tmp.name
        image.save(path)
        duration = round(time.perf_counter() - start, 2)

        # Score the in-memory image. PNG is lossless, so this matches re-reading
        # the file, without leaking another file handle.
        image_input = clip_preprocess(image).unsqueeze(0).to(device)
        # CLIP's text context is 77 tokens; without truncate=True a longer
        # translation raises and a successfully generated image is reported
        # as a failure.
        text_input = clip.tokenize([prompt], truncate=True).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            text_features = clip_model.encode_text(text_input)
            similarity = F.cosine_similarity(image_features, text_features).item()

        return path, duration, round(similarity, 4)
    except Exception as e:
        # Top-level boundary for the UI: report the error instead of crashing,
        # and remove any temp file so failed attempts do not accumulate on disk.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
        return None, None, f"Image generation failed: {str(e)}"
|
|
|
|
|
# NOTE(review): the emoji in the UI strings below were restored from text that
# had been mis-decoded through a Greek code page; glyphs whose original bytes
# were not recoverable are best guesses — adjust to taste.
st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")

st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")

tamil_input = st.text_area("✍️ Enter Tamil text", height=150)
reference_input = st.text_input("📘 Optional: Reference English translation for ROUGE")

# Streamlit reruns this script on every interaction; the pipeline only runs
# on the rerun triggered by the button click.
if st.button("🚀 Generate Output"):
    if not tamil_input.strip():
        st.warning("Please enter Tamil text.")
    else:
        with st.spinner("🔄 Translating..."):
            english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)

        st.success(f"✅ Translated in {t_time}s")
        st.markdown(f"**📝 English Translation:** `{english_text}`")
        if rouge_l is not None:
            st.markdown(f"📊 ROUGE-L Score: `{rouge_l}`")

        with st.spinner("🖼️ Generating image..."):
            image_path, img_time, clip_score = generate_image(english_text)

        if image_path:
            st.success(f"🖼️ Image generated in {img_time}s using `{model_loaded}`")
            try:
                # The context manager closes the file handle. st.image encodes
                # the pixels during the call, so the temp file can be deleted
                # afterwards instead of piling up on every click.
                with Image.open(image_path) as generated:
                    st.image(generated, caption="AI-Generated Image", use_column_width=True)
            finally:
                try:
                    os.remove(image_path)
                except OSError:
                    pass
            st.markdown(f"🔗 **CLIP Text-Image Similarity:** `{clip_score}`")
        else:
            # On failure generate_image returns the error message in the third slot.
            st.error(clip_score)

        with st.spinner("💡 Generating creative text..."):
            creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)

        st.success(f"✨ Creative text in {c_time}s")
        st.markdown(f"**🧠 Creative Output:** `{creative}`")
        st.markdown(f"📏 Tokens: `{tokens}`, 🔁 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")

st.markdown("---")
st.caption("Built by Sureshkumar R | MBart + GPT-2 + Stable Diffusion + CLIP")
|
|