import torch
from transformers import MBartForConditionalGeneration, MBart50Tokenizer, AutoTokenizer, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline
from PIL import Image
import tempfile
import streamlit as st
# Use the GPU if one is available; otherwise fall back to CPU (the Hugging Face Spaces free tier is CPU-only)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load translation model
translator_tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
translator_tokenizer.src_lang = "ta_IN"
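# mBART-50 is many-to-many: src_lang marks the input as Tamil, while the target
# language is chosen per call via forced_bos_token_id (see translate_tamil_to_english below)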
# Load text generation model
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
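# GPT-2 expands the English translation into a longer "creative" passage that is only
# displayed in the UI; the image further down is generated from the plain translation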
# Load image generation model
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    torch_dtype=torch.float32,
    safety_checker=None
).to(device)
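# Note: safety_checker=None disables the built-in NSFW filter, and float32 keeps the
# pipeline CPU-compatible; on CPU a single image can take several minutes to render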
def translate_tamil_to_english(text):
    # Tokenize the Tamil input and force English (en_XX) as the first generated token
    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
    output = translator_model.generate(
        **inputs,
        forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
    )
    translated = translator_tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return translated
def generate_creative_text(prompt, max_length=100):
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Sampled decoding; pad_token_id silences GPT-2's missing-pad-token warning
    output = gen_model.generate(
        input_ids, max_length=max_length, do_sample=True, top_k=50,
        temperature=0.9, pad_token_id=gen_tokenizer.eos_token_id
    )
    return gen_tokenizer.decode(output[0], skip_special_tokens=True)
def generate_image(prompt):
    image = pipe(prompt).images[0]
    # delete=False keeps the PNG on disk so Streamlit can read it after this function returns
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    image.save(temp_file.name)
    temp_file.close()
    return temp_file.name
# Streamlit UI
st.set_page_config(page_title="Tamil β†’ English + AI", layout="centered")
st.title("🌐 Tamil to English + AI Image Generator")
tamil_input = st.text_area("✍️ Enter Tamil Text", height=150)
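# On click: translate Tamil to English, expand the translation with GPT-2 for display,
# then generate an image from the English translation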
if st.button("🚀 Generate"):
    if not tamil_input.strip():
        st.warning("Please enter Tamil text.")
    else:
        with st.spinner("Translating..."):
            translated = translate_tamil_to_english(tamil_input)
        st.success("✅ Translated!")
        st.markdown(f"**English:** `{translated}`")

        with st.spinner("Generating creative text..."):
            creative_text = generate_creative_text(translated)
        st.success("✅ Creative text generated!")
        st.markdown(f"**Creative Prompt:** `{creative_text}`")

        with st.spinner("Generating image..."):
            image_path = generate_image(translated)
        st.success("✅ Image generated!")
        st.image(Image.open(image_path), caption="🖼️ AI Generated Image", use_column_width=True)
st.markdown("---")
st.markdown("πŸ”§ Powered by MBart, GPT2 & Stable Diffusion - Deployed on Hugging Face πŸ€—")