|
import os |
|
import gradio as gr |
|
import torch |
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, pipeline |
|
from diffusers import StableDiffusionPipeline |
|
from PIL import Image |
|
|
|
|
|
# Load the M2M100 checkpoint and its tokenizer for Tamil ("ta") source text.
# Loading is best-effort: any failure is printed and both names are left as
# None so the rest of the module can still be imported.
try:
    translator = M2M100ForConditionalGeneration.from_pretrained(
        "facebook/m2m100_418M"
    )
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
    tokenizer.src_lang = "ta"
except Exception as load_error:
    print(f"Error loading M2M100 model: {load_error}")
    translator = None
    tokenizer = None
|
|
|
|
|
# Build the GPT-2 text-generation pipeline. Best-effort: on failure the error
# is printed and `text_generator` is left as None.
try:
    text_generator = pipeline(task="text-generation", model="gpt2")
except Exception as load_error:
    print(f"Error loading GPT-2 model: {load_error}")
    text_generator = None
|
|
|
|
|
# Hugging Face access token: HF_TOKEN wins, HUGGINGFACE_TOKEN is the fallback.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")

# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load Stable Diffusion v1.5 (float16 on GPU, float32 on CPU) and move it to
# the selected device. Best-effort: on failure the error is printed and
# `pipe` is left as None.
# NOTE(review): `use_auth_token` may be deprecated in favour of `token` in
# newer diffusers releases - confirm against the installed version.
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_auth_token=hf_token,
    ).to(device)

    # Attention slicing is only switched on for the GPU path.
    if device == "cuda":
        pipe.enable_attention_slicing()
except Exception as load_error:
    print(f"Error loading Stable Diffusion pipeline: {load_error}")
    pipe = None
|
|
|
def tamil_to_image(tamil_text):
    """Turn Tamil text into an image via translation, GPT-2 and Stable Diffusion.

    The input is translated to English with M2M100, the translation is
    continued by GPT-2, and the GPT-2 output is used as the Stable Diffusion
    prompt.

    Args:
        tamil_text: Tamil text entered by the user.

    Returns:
        An ``(image, info_text)`` tuple. On success ``image`` is the first
        image produced by the pipeline and ``info_text`` lists the English
        translation and the prompt. On any failure ``image`` is ``None`` and
        ``info_text`` is an error message naming the stage that failed.
    """
    if not tamil_text or not tamil_text.strip():
        return None, "Error: Please enter Tamil text as input."

    # Stage 1: Tamil -> English translation.
    try:
        # The module-level loader sets these names to None when loading
        # fails; report that explicitly instead of a NoneType attribute error.
        if translator is None or tokenizer is None:
            raise RuntimeError("the M2M100 translation model is not loaded")
        tokenizer.src_lang = "ta"
        encoded = tokenizer(tamil_text, return_tensors="pt")
        generated_tokens = translator.generate(
            **encoded, forced_bos_token_id=tokenizer.get_lang_id("en")
        )
        translation = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
    except Exception as e:
        return None, f"Translation error: {e}"

    # Stage 2: let GPT-2 continue the translation.
    try:
        if text_generator is None:
            raise RuntimeError("the GPT-2 text generator is not loaded")
        # max_new_tokens bounds only the continuation. The previous
        # max_length=50 also counted the prompt tokens, so a long translation
        # left no room for any generated text.
        gen = text_generator(translation, max_new_tokens=50, num_return_sequences=1)
        gen_text = gen[0]['generated_text'] if isinstance(gen, list) else gen['generated_text']
    except Exception as e:
        return None, f"Text generation error: {e}"

    # Stage 3: render the prompt with Stable Diffusion.
    try:
        if pipe is None:
            raise RuntimeError("the Stable Diffusion pipeline is not loaded")
        prompt = gen_text
        # Fewer denoising steps on CPU than on GPU.
        steps = 50 if device == "cuda" else 25
        image = pipe(prompt, num_inference_steps=steps).images[0]
    except Exception as e:
        return None, f"Image generation error: {e}"

    info = f"Translated → English: {translation}\nGPT-2 Prompt: {prompt}"
    return image, info
|
|
|
|
|
# Gradio UI: one Tamil text box in; the generated image plus an info text
# box out, both produced by `tamil_to_image`.
iface = gr.Interface(
    title="Tamil Text-to-Image Generator",
    description=(
        "Enter Tamil text; this demo translates it to English, "
        "generates a story prompt with GPT-2, "
        "then creates an image with Stable Diffusion."
    ),
    fn=tamil_to_image,
    inputs=gr.Textbox(
        type="text",
        label="Tamil Input",
        placeholder="Enter Tamil text here",
    ),
    outputs=[
        gr.Image(label="Generated Image", type="pil"),
        gr.Textbox(type="text", label="Output Info"),
    ],
)

# Start the Gradio server as soon as the module is executed.
iface.launch()
|
|