|
|
|
import app_ui |
|
import os |
|
import time |
|
import tempfile |
|
import torch |
|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoTokenizer, AutoModelForCausalLM |
|
from diffusers import DiffusionPipeline |
|
from PIL import Image |
|
from rouge_score import rouge_scorer |
|
|
|
|
|
# Run everything on GPU when available; all models below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face access token for model downloads; falls back to a
# placeholder string when the HF_TOKEN environment variable is unset.
hf_token = os.getenv("HF_TOKEN", "your_token_here")

# Tamil -> English translation: mBART-50 many-to-many checkpoint.
translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)

# Source language is fixed to Tamil for every call to the translator.
translator_tokenizer.src_lang = "ta_IN"

# Creative text generation: base GPT-2 (no fine-tuned checkpoint visible here).
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")

gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Text-to-image: Stable Diffusion 2.1; half precision only when a GPU is present.
# NOTE(review): `use_auth_token` is deprecated in recent diffusers releases
# (renamed to `token`) — confirm against the installed diffusers version.
pipe = DiffusionPipeline.from_pretrained(

    "stabilityai/stable-diffusion-2-1",

    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,

    use_auth_token=hf_token

).to(device)

# Disables the NSFW safety checker — generated images are returned unfiltered.
pipe.safety_checker = None
|
|
|
def translate_tamil_to_english(text, reference=None, target_lang="en_XX"):
    """Translate Tamil text with mBART-50, optionally scoring against a reference.

    Args:
        text: Source string; the module-level tokenizer has ``src_lang``
            set to ``"ta_IN"``, so input is expected to be Tamil.
        reference: Optional reference translation. When given, a ROUGE-L
            F1 score between reference and hypothesis is computed
            (both lowercased, with stemming).
        target_lang: mBART-50 language code forced as the first generated
            token. Defaults to ``"en_XX"``, preserving the original
            English-only behavior.

    Returns:
        Tuple ``(translation, duration, rouge_l)`` where ``duration`` is
        wall-clock seconds rounded to 2 decimals and ``rouge_l`` is
        ``None`` when no reference was supplied.
    """
    start = time.time()
    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
    # Inference only: disable autograd to avoid building a gradient graph
    # (the original version kept it, wasting memory on every call).
    with torch.no_grad():
        outputs = translator_model.generate(
            **inputs,
            forced_bos_token_id=translator_tokenizer.lang_code_to_id[target_lang],
        )
    translation = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    duration = round(time.time() - start, 2)

    rouge_l = None
    if reference:
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        scores = scorer.score(reference.lower(), translation.lower())
        rouge_l = round(scores['rougeL'].fmeasure, 4)

    return translation, duration, rouge_l
|
|
|
def generate_image(prompt):
    """Render ``prompt`` with the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt passed directly to the diffusion pipeline.

    Returns:
        ``(path, seconds)`` on success — path of a 256x256 PNG written to
        a temp file, and elapsed generation time rounded to 2 decimals.
        ``(None, message)`` on failure, where ``message`` is an error
        string (callers must check for ``None`` rather than catch).
    """
    try:
        start = time.time()
        out = pipe(prompt)
        img = out.images[0].resize((256, 256))
        # delete=False keeps the file on disk after this function returns.
        # Close the handle before PIL re-opens the path: the original left
        # the descriptor open (fd leak, and a hard failure on Windows,
        # where an open NamedTemporaryFile cannot be reopened by name).
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        tmp.close()
        img.save(tmp.name)
        return tmp.name, round(time.time() - start, 2)
    except Exception as e:
        # Best-effort by design: swallow any pipeline/IO error and report
        # it in the second slot instead of raising.
        return None, f"Image generation failed: {e}"
|
|
|
def generate_creative_text(prompt, max_length=100):
    """Sample a GPT-2 continuation of ``prompt`` and report simple stats.

    Args:
        prompt: Seed text for generation.
        max_length: Total sequence length cap (prompt + continuation),
            passed through to ``generate``.

    Returns:
        Tuple ``(text, duration, n_tokens, repetition)`` where ``text``
        includes the prompt, ``duration`` is seconds rounded to 2
        decimals, ``n_tokens`` is the whitespace token count, and
        ``repetition`` is the fraction of adjacent duplicate words.
    """
    start = time.time()
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Sampling only — no gradients needed.
    with torch.no_grad():
        out = gen_model.generate(
            input_ids, max_length=max_length, do_sample=True, top_k=50, temperature=0.9
        )
    text = gen_tokenizer.decode(out[0], skip_special_tokens=True)
    duration = round(time.time() - start, 2)
    tokens = text.split()
    # Adjacent-duplicate ratio as a crude repetition metric. Guard the
    # empty-output case: the original divided by len(tokens) and raised
    # ZeroDivisionError when decoding produced no words.
    if tokens:
        repetition = sum(a == b for a, b in zip(tokens, tokens[1:])) / len(tokens)
    else:
        repetition = 0.0
    return text, duration, len(tokens), round(repetition, 4)
|
|