# NOTE(review): removed Hugging Face web-UI scrape residue that preceded this
# file ("…'s picture / Update app.py / raw / history blame / 2.92 kB") — it was
# page chrome, not code, and made the module syntactically invalid.
# app.py
import app_ui
import os
import time
import tempfile
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoTokenizer, AutoModelForCausalLM
from diffusers import DiffusionPipeline
from PIL import Image
from rouge_score import rouge_scorer
# Device setup: prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face token (needed to download auth-protected model weights).
hf_token = os.getenv("HF_TOKEN", "your_token_here")  # Replace or set HF_TOKEN env var

# Translator (Tamil -> English) built on mBART-50 many-to-many.
translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
translator_tokenizer.src_lang = "ta_IN"  # source language is Tamil

# Creative-text generator (GPT-2).
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Stable Diffusion image pipeline; fp16 is only meaningful on CUDA.
_pipe_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=_pipe_dtype,
    token=hf_token,  # `use_auth_token` is deprecated in recent diffusers releases
).to(device)
pipe.safety_checker = None  # Optional: disable the NSFW safety checker
def translate_tamil_to_english(text, reference=None):
    """Translate Tamil text to English with the mBART-50 model.

    Returns a tuple ``(translation, seconds_taken, rouge_l)`` where
    ``rouge_l`` is the ROUGE-L F-measure against *reference* (or ``None``
    when no reference translation is supplied).
    """
    t0 = time.time()
    encoded = translator_tokenizer(text, return_tensors="pt").to(device)
    generated = translator_model.generate(
        **encoded,
        # Force English as the target language for the decoder.
        forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"],
    )
    translation = translator_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    elapsed = round(time.time() - t0, 2)

    rouge_l = None
    if reference:
        # Case-insensitive ROUGE-L between reference and hypothesis.
        scores = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True).score(
            reference.lower(), translation.lower()
        )
        rouge_l = round(scores['rougeL'].fmeasure, 4)
    return translation, elapsed, rouge_l
def generate_image(prompt):
    """Generate a 256x256 PNG from *prompt* with Stable Diffusion.

    Returns ``(path_to_png, seconds_taken)`` on success, or
    ``(None, error_message)`` on failure — callers must check for ``None``
    before using the second element as a duration.
    """
    try:
        start = time.time()
        out = pipe(prompt)
        img = out.images[0].resize((256, 256))
        # Fix: the original never closed the NamedTemporaryFile handle,
        # leaking a file descriptor (and save-while-open is fragile on
        # Windows). Reserve the name, close the handle, then save.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            path = tmp.name
        img.save(path)
        return path, round(time.time() - start, 2)
    except Exception as e:
        # Best-effort API: surface the error via the second tuple element.
        return None, f"Image generation failed: {e}"
def generate_creative_text(prompt, max_length=100):
    """Generate creative text with GPT-2.

    Args:
        prompt: seed text to continue.
        max_length: maximum total token length for generation.

    Returns:
        ``(text, seconds_taken, word_count, repetition_rate)`` where
        ``repetition_rate`` is the fraction of adjacent duplicate words
        (sampling uses ``do_sample=True``, so output is nondeterministic).
    """
    start = time.time()
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    out = gen_model.generate(
        input_ids, max_length=max_length, do_sample=True, top_k=50, temperature=0.9
    )
    text = gen_tokenizer.decode(out[0], skip_special_tokens=True)
    duration = round(time.time() - start, 2)
    tokens = text.split()
    # Fix: the original divided by len(tokens) unconditionally, raising
    # ZeroDivisionError when the decoded text is empty.
    if tokens:
        repetition = sum(a == b for a, b in zip(tokens, tokens[1:])) / len(tokens)
    else:
        repetition = 0.0
    return text, duration, len(tokens), round(repetition, 4)