# Source: Hugging Face file view of app.py by 24Sureshkumar
# (commit 2d8ff37, "Update app.py", 2.29 kB).
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from diffusers import StableDiffusionPipeline
import torch
@st.cache_resource
def load_all_models():
    """Load every model the app needs, once per Streamlit server process.

    ``st.cache_resource`` memoises the return value, so the large downloads
    and weight loads happen on the first call only; later button clicks reuse
    the same objects.

    Returns:
        A 4-tuple ``(trans_tokenizer, trans_model, text_gen, img_pipe)``:
        the IndicTrans2 Indic-to-English tokenizer and seq2seq model, a GPT-2
        ``text-generation`` pipeline, and a Stable Diffusion v1.5 pipeline
        placed on CUDA when available, otherwise on CPU.
    """
    # IndicTrans2 Indic-to-English model; it ships custom modelling/tokenizer
    # code on the Hub, hence trust_remote_code=True.
    trans_model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
    trans_tokenizer = AutoTokenizer.from_pretrained(trans_model_id, trust_remote_code=True)
    trans_model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_id, trust_remote_code=True)

    # English continuation model (any causal LM checkpoint would work here).
    text_gen = pipeline("text-generation", model="gpt2")

    # Half precision only pays off (and is only reliably supported) on GPU;
    # many CPU kernels have no float16 implementation, so fall back to
    # float32 when running on CPU.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    img_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        # `variant` selects the *.fp16 weight files; it supersedes the
        # deprecated `revision="fp16"` branch convention.
        variant="fp16" if use_cuda else None,
    ).to(device)

    return trans_tokenizer, trans_model, text_gen, img_pipe
def translate_text(text, tokenizer, model, src_lang="tam_Taml", tgt_lang="eng_Latn", max_length=128):
    """Translate ``text`` with an IndicTrans2 tokenizer/model pair.

    IndicTrans2 does not take T5-style ``"translate X to Y: ..."`` prompts.
    Per its model card, each input sentence must be prefixed with the source
    and target FLORES-200 language tags, e.g. ``"tam_Taml eng_Latn <text>"``.
    The reference pipeline also normalises text with IndicTransToolkit's
    ``IndicProcessor``. That package is not a dependency of this app, so only
    the language tags are added here.

    Args:
        text: Source-language text (Tamil by default).
        tokenizer: IndicTrans2 tokenizer (loaded with trust_remote_code).
        model: Matching IndicTrans2 seq2seq model.
        src_lang: Source language tag; defaults to Tamil in Tamil script.
        tgt_lang: Target language tag; defaults to English in Latin script.
        max_length: Upper bound on generated sequence length.

    Returns:
        The translated string, or ``""`` if ``text`` is empty or whitespace.
    """
    import contextlib

    text = text.strip()
    if not text:
        return ""

    input_text = f"{src_lang} {tgt_lang} {text}"
    inputs = tokenizer(input_text, return_tensors="pt", padding=True)
    outputs = model.generate(**inputs, max_length=max_length)

    # IndicTrans2 keeps separate source/target vocabularies; generated ids
    # must be decoded in target mode. Fall back to a no-op context on
    # tokenizers / transformers versions without `as_target_tokenizer`.
    to_target = getattr(tokenizer, "as_target_tokenizer", None)
    ctx = to_target() if callable(to_target) else contextlib.nullcontext()
    with ctx:
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
def main():
    """Render the Streamlit page and run the translate -> continue -> draw chain.

    When "Generate" is clicked with non-empty input, the function runs three
    steps and displays each intermediate result:

    1. Translate the Tamil text to English with IndicTrans2.
    2. Extend the English text with GPT-2.
    3. Use the extended text as the Stable Diffusion prompt.
    """
    st.title("Multimodal Tamil to Image Generator 🚀")
    st.markdown("Enter Tamil text, we translate it to English, continue the sentence, and generate an image!")
    user_input = st.text_area("Enter Tamil text:", "")

    if not st.button("Generate"):
        return
    # Don't load and run three large models for an empty prompt.
    if not user_input.strip():
        st.warning("Please enter some Tamil text first.")
        return

    with st.spinner("Loading models..."):
        # load_all_models is wrapped in st.cache_resource, so only the first
        # click pays the loading cost.
        tokenizer, model, text_gen, img_pipe = load_all_models()

    with st.spinner("Translating to English..."):
        english_text = translate_text(user_input, tokenizer, model)
    st.subheader("Translated English:")
    st.write(english_text)

    with st.spinner("Generating continuation..."):
        # max_new_tokens bounds only the newly generated tokens. max_length
        # would also count the prompt, leaving no room to continue a long
        # translation.
        result = text_gen(english_text, max_new_tokens=50, do_sample=True)
        continuation = result[0]["generated_text"]
    st.subheader("Generated Text:")
    st.write(continuation)

    with st.spinner("Generating Image..."):
        image = img_pipe(continuation).images[0]
    st.subheader("Generated Image:")
    st.image(image)
if __name__ == "__main__":
    # Script entry point: build the page when this file is executed directly
    # (e.g. `streamlit run app.py`), but not when it is imported as a module.
    main()