24Sureshkumar commited on
Commit
96ec789
·
verified ·
1 Parent(s): c8f0e90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -45
app.py CHANGED
@@ -1,56 +1,73 @@
1
- import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
- from diffusers import StableDiffusionPipeline
 
4
  import torch
 
 
 
 
5
 
6
- # Cache models for faster loading
7
- @st.cache_resource
8
- def load_all_models():
9
- # Translation model
10
- translation_model = AutoModelForSeq2SeqLM.from_pretrained(
11
- "ai4bharat/indictrans2-indic-en-dist-200M", trust_remote_code=True
12
- )
13
- translation_tokenizer = AutoTokenizer.from_pretrained(
14
- "ai4bharat/indictrans2-indic-en-dist-200M", trust_remote_code=True
15
- )
16
- translation_pipeline = pipeline(
17
- "text2text-generation", model=translation_model, tokenizer=translation_tokenizer
18
- )
19
 
20
- # Image generation model (Stable Diffusion)
21
- img_pipe = StableDiffusionPipeline.from_pretrained(
22
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
23
- )
24
- img_pipe = img_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
25
-
26
- return translation_pipeline, img_pipe
27
 
28
- def main():
29
- st.title("πŸ“˜ Tamil to English Translator & Image Generator")
 
 
30
 
31
- tamil_text = st.text_area("πŸ“ Enter Tamil text (word or sentence)", height=100)
 
 
32
 
33
- if st.button("πŸ”„ Translate & Generate Image"):
34
- if not tamil_text.strip():
35
- st.warning("Please enter some Tamil text.")
36
- return
 
 
 
37
 
38
- try:
39
- translation_pipeline, img_pipe = load_all_models()
40
-
41
- # Prepare translation input
42
- formatted_input = "<2en><|ta|>" + tamil_text.strip()
43
- translated = translation_pipeline(formatted_input, max_length=256)[0]["generated_text"]
 
 
 
44
 
45
- st.success("βœ… English Translation:")
46
- st.write(translated)
 
 
 
47
 
48
- with st.spinner("πŸ–ΌοΈ Generating image..."):
49
- image = img_pipe(translated).images[0]
50
- st.image(image, caption="πŸ–ΌοΈ Generated from English text")
51
 
52
- except Exception as e:
53
- st.error(f"❌ Error: {str(e)}")
 
 
 
 
 
 
 
 
54
 
55
- if __name__ == "__main__":
56
- main()
 
 
 
 
 
 
 
 
 
 
1
# app.py
#
# Tamil → English translation (mBART-50), GPT-2 creative text generation,
# and Stable Diffusion image generation. All models are loaded once at
# import time and shared by the functions below.
import os
import time
import tempfile
import torch
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from diffusers import DiffusionPipeline
from PIL import Image
from rouge_score import rouge_scorer

# Device setup: prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face token (required for the image pipeline).
# Read from the environment; never commit a real token to source control.
hf_token = os.getenv("HF_TOKEN", "your_token_here")  # Replace with your token or set as environment variable

# Tamil → English translator (mBART-50 many-to-many).
translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
translator_tokenizer.src_lang = "ta_IN"  # inputs are Tamil

# GPT-2 for open-ended text generation.
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Stable Diffusion 2.1 image pipeline.
# fp16 only on CUDA — half precision is not supported on CPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=hf_token,  # `use_auth_token` is deprecated/removed in recent diffusers
).to(device)
pipe.safety_checker = None  # Optional: disable safety checks
33
 
34
def translate_tamil_to_english(text, reference=None):
    """Translate Tamil text to English with the shared mBART-50 model.

    Args:
        text: Tamil input string.
        reference: Optional English reference translation. When provided,
            a ROUGE-L F1 score of the model output against it is computed.

    Returns:
        Tuple ``(translation, duration, rouge_l)`` — the English text, the
        wall-clock seconds taken (rounded to 2 dp), and the ROUGE-L F1
        (rounded to 4 dp) or ``None`` when no reference was supplied.
    """
    start = time.time()
    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = translator_model.generate(
            **inputs,
            # Force English as the target language. `lang_code_to_id` was
            # removed from recent transformers tokenizers;
            # convert_tokens_to_ids("en_XX") yields the same token id and
            # is the supported replacement.
            forced_bos_token_id=translator_tokenizer.convert_tokens_to_ids("en_XX"),
        )
    translation = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    duration = round(time.time() - start, 2)

    rouge_l = None
    if reference:
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        scores = scorer.score(reference.lower(), translation.lower())
        rouge_l = round(scores['rougeL'].fmeasure, 4)

    return translation, duration, rouge_l
 
 
51
 
52
def generate_image(prompt):
    """Render *prompt* with Stable Diffusion and save a 256×256 PNG.

    Args:
        prompt: English text prompt for the image pipeline.

    Returns:
        ``(path, duration)`` on success, or ``(None, error_message)`` on
        failure — callers must check the first element before treating the
        second as a number.
    """
    try:
        start = time.time()
        out = pipe(prompt)
        img = out.images[0].resize((256, 256))
        # The original left the NamedTemporaryFile handle open (fd leak, and
        # re-opening an open temp file fails on Windows). Close it first and
        # write to the surviving path; delete=False means the caller owns
        # cleanup of the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            path = tmp.name
        img.save(path)
        return path, round(time.time() - start, 2)
    except Exception as e:
        return None, f"Image generation failed: {e}"
62
 
63
def generate_creative_text(prompt, max_length=100):
    """Continue *prompt* with sampled GPT-2 text.

    Args:
        prompt: Seed text to continue.
        max_length: Maximum total token length of the generated sequence.

    Returns:
        Tuple ``(text, duration, token_count, repetition)`` — the generated
        text, seconds taken (2 dp), the whitespace-token count, and the
        fraction of adjacent duplicate words (4 dp).
    """
    start = time.time()
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        out = gen_model.generate(
            input_ids, max_length=max_length, do_sample=True, top_k=50, temperature=0.9
        )
    text = gen_tokenizer.decode(out[0], skip_special_tokens=True)
    duration = round(time.time() - start, 2)
    tokens = text.split()
    # Adjacent-duplicate ratio as a crude repetition metric. Guard the
    # empty-output case: the original divided by len(tokens) unconditionally
    # and raised ZeroDivisionError when decoding produced no words.
    if tokens:
        repetition = sum(a == b for a, b in zip(tokens, tokens[1:])) / len(tokens)
    else:
        repetition = 0.0
    return text, duration, len(tokens), round(repetition, 4)