|
import streamlit as st |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline |
|
from diffusers import StableDiffusionPipeline |
|
import torch |
|
|
|
@st.cache_resource
def load_all_models():
    """Load every model the app needs and cache them across Streamlit reruns.

    Three models are created:

    * the IndicTrans2 Indic->English tokenizer and seq2seq model,
    * a GPT-2 ``text-generation`` pipeline,
    * a Stable Diffusion v1.5 image pipeline, placed on CUDA when available
      and on CPU otherwise.

    Returns:
        tuple: ``(trans_tokenizer, trans_model, text_gen, img_pipe)``.
    """
    # Translation model. ``trust_remote_code=True`` lets the checkpoint's own
    # Python code run locally, so only use it with repositories you trust.
    trans_model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
    trans_tokenizer = AutoTokenizer.from_pretrained(
        trans_model_id, trust_remote_code=True
    )
    trans_model = AutoModelForSeq2SeqLM.from_pretrained(
        trans_model_id, trust_remote_code=True
    )

    # Text continuation model.
    text_gen = pipeline("text-generation", model="gpt2")

    # Image model. Half precision is only requested on CUDA: a float16
    # pipeline moved to the CPU is not expected to run, so the CPU fallback
    # loads the default (float32) weights instead.
    use_cuda = torch.cuda.is_available()
    sd_kwargs = {}
    if use_cuda:
        # NOTE(review): newer diffusers releases document ``variant="fp16"``
        # as the replacement for ``revision="fp16"`` - confirm against the
        # diffusers version this app is pinned to.
        sd_kwargs = {"torch_dtype": torch.float16, "revision": "fp16"}

    img_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        **sd_kwargs,
    ).to("cuda" if use_cuda else "cpu")

    return trans_tokenizer, trans_model, text_gen, img_pipe
|
|
|
def translate_text(text, tokenizer, model):
    """Translate Tamil ``text`` to English with the given tokenizer and model.

    Args:
        text: The Tamil text to translate.
        tokenizer: Callable tokenizer that returns keyword arguments for
            ``model.generate`` and provides ``decode``.
        model: Seq2seq model exposing ``generate``.

    Returns:
        str: The decoded first generated sequence with special tokens
        removed, or ``""`` when ``text`` is empty or only whitespace.
    """
    # Nothing to translate: skip the model call rather than decoding
    # whatever the model produces for a bare prompt.
    if not text or not text.strip():
        return ""

    # NOTE(review): this is a T5-style instruction prefix. The IndicTrans2
    # model card describes a dedicated preprocessing step with language tags
    # instead - confirm this prompt format works with the loaded checkpoint.
    input_text = f"translate Tamil to English: {text}"
    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
    outputs = model.generate(**inputs, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
def main():
    """Render the Streamlit page and run the text-to-image pipeline.

    When the "Generate" button is pressed and the text area is not blank, the
    input is translated to English, continued by the text-generation
    pipeline, and the continued text is used as the prompt for the image
    pipeline. Each intermediate result is shown on the page.
    """
    st.title("Multimodal Tamil to Image Generator π")
    st.markdown("Enter Tamil text, we translate it to English, continue the sentence, and generate an image!")

    user_input = st.text_area("Enter Tamil text:", "")

    if not st.button("Generate"):
        return

    # Do not load and run the models for a blank input.
    if not user_input.strip():
        st.warning("Please enter some Tamil text first.")
        return

    with st.spinner("Loading models..."):
        tokenizer, model, text_gen, img_pipe = load_all_models()

    with st.spinner("Translating to English..."):
        english_text = translate_text(user_input, tokenizer, model)
        st.subheader("Translated English:")
        st.write(english_text)

    with st.spinner("Generating continuation..."):
        # ``max_new_tokens`` bounds only the continuation. ``max_length``
        # also counts the prompt, so a long translation would leave no room
        # for any new tokens.
        generated = text_gen(english_text, max_new_tokens=50, do_sample=True)
        continuation = generated[0]['generated_text']
        st.subheader("Generated Text:")
        st.write(continuation)

    with st.spinner("Generating Image..."):
        image = img_pipe(continuation).images[0]
        st.subheader("Generated Image:")
        st.image(image)
|
|
|
# Script entry point: build the page only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|