|
import os |
|
import torch |
|
import gradio as gr |
|
from huggingface_hub import login |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForSeq2SeqLM, |
|
GPT2LMHeadModel, |
|
GPT2Tokenizer |
|
) |
|
from diffusers import StableDiffusionPipeline |
|
|
|
|
|
# Optional Hugging Face authentication: when HUGGINGFACE_TOKEN is present in
# the environment, log in with it so later Hub downloads can use the token.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if hf_token:
    login(token=hf_token)
|
|
|
|
|
# --- Text models ----------------------------------------------------------
# Seq2seq checkpoint used by the app to translate the user's Tamil input into
# English.
# NOTE(review): the repo name reads "english-tamil"; confirm this checkpoint
# really translates Tamil -> English, as tam_to_image_pipeline assumes.
trans_checkpoint = "Hemanth-thunder/english-tamil-mt"
trans_tokenizer, trans_model = (
    AutoTokenizer.from_pretrained(trans_checkpoint),
    AutoModelForSeq2SeqLM.from_pretrained(trans_checkpoint),
)

# GPT-2 continues the English translation into a longer image prompt.
# nn.Module.eval() returns the module itself, so it can be chained here.
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt_model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
|
|
|
|
|
# --- Image model ----------------------------------------------------------
# Run on the GPU in half precision when one is available; otherwise fall back
# to CPU with float32 weights.
device = "cuda" if torch.cuda.is_available() else "cpu"

# "runwayml/stable-diffusion-v1-5" is no longer available on the Hugging Face
# Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the maintained mirror
# of the same weights.
# No token is passed here: `use_auth_token` is deprecated in diffusers, and
# when HUGGINGFACE_TOKEN is set the earlier `login()` call has already stored
# the credential that Hub downloads pick up.
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
|
|
|
|
|
def _translate_tamil_to_english(tamil_text):
    """Translate *tamil_text* to English with the module's seq2seq model.

    The input is truncated to the tokenizer's maximum length and the
    generated translation is capped at 128 tokens.
    """
    inputs = trans_tokenizer(tamil_text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        translated_ids = trans_model.generate(**inputs, max_length=128)
    return trans_tokenizer.decode(translated_ids[0], skip_special_tokens=True)


def _expand_description(english_text, max_length=60, reserve_tokens=20):
    """Continue *english_text* with GPT-2 and return prompt + continuation.

    Args:
        english_text: Non-empty English prompt for GPT-2.
        max_length: Total token budget (prompt included) for ordinary prompts.
        reserve_tokens: Minimum number of new tokens left available when the
            prompt alone already approaches ``max_length``.

    Returns:
        The decoded GPT-2 output, which begins with the prompt text.
    """
    # Calling the tokenizer (rather than .encode) also yields an attention
    # mask, which avoids generate()'s missing-attention-mask warning.
    encoded = gpt_tokenizer(english_text, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[-1]
    # `max_length` counts the prompt. A long translation (up to 128 tokens)
    # could meet or exceed a fixed budget of 60, and depending on the
    # transformers version generate() then raises or returns no new text.
    # So always leave room for at least `reserve_tokens` new tokens.
    budget = max(max_length, prompt_len + reserve_tokens)
    with torch.no_grad():
        gpt_output = gpt_model.generate(
            **encoded,
            max_length=budget,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
    return gpt_tokenizer.decode(gpt_output[0], skip_special_tokens=True)


def tam_to_image_pipeline(tamil_text):
    """Turn Tamil text into an image: Tamil -> English -> GPT-2 prompt -> SD.

    Args:
        tamil_text: Text entered by the user; expected to be Tamil.

    Returns:
        A ``(english_text, generated_text, image)`` tuple, where:

        * ``english_text`` is the translation;
        * ``generated_text`` is the GPT-2 continuation used as the image
          prompt;
        * ``image`` is the first image returned by the Stable Diffusion
          pipeline.

        Blank input (``None``, empty or whitespace-only) returns
        ``("", "", None)`` without running any model. If the translation
        comes back empty, the result is ``(english_text, "", None)``, because
        GPT-2 generation needs at least one prompt token.
    """
    if not tamil_text or not tamil_text.strip():
        return "", "", None

    english_text = _translate_tamil_to_english(tamil_text)
    if not english_text.strip():
        return english_text, "", None

    generated_text = _expand_description(english_text)
    image = sd_pipe(generated_text).images[0]
    return english_text, generated_text, image
|
|
|
|
|
# --- Web UI ---------------------------------------------------------------
# One Tamil text box in; the translation, the GPT-2 description and the
# generated image out (matching tam_to_image_pipeline's 3-tuple).
interface = gr.Interface(
    fn=tam_to_image_pipeline,
    inputs=gr.Textbox(label="Enter Tamil Text"),
    outputs=[
        gr.Textbox(label="Translated English Text"),
        gr.Textbox(label="Generated Description"),
        gr.Image(label="Generated Image"),
    ],
    # These strings previously contained mis-encoded characters ("β" in
    # place of arrows, plus an unrecoverable leading emoji); restored here.
    title="Tamil → Image Generator",
    description="Tamil to English (M2M100) → GPT-2 → Image via Stable Diffusion",
)

# Start the web server only when this file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()
|
|