24Sureshkumar committed
Commit ab46a3f · verified · 1 Parent(s): b7ac1a7

Update app.py

Files changed (1): app.py (+44, −74)
app.py CHANGED
@@ -1,74 +1,44 @@
- # app.py
- import app_ui
- import os
- import time
- import tempfile
- import torch
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoTokenizer, AutoModelForCausalLM
- from diffusers import DiffusionPipeline
- from PIL import Image
- from rouge_score import rouge_scorer
-
- # Device setup
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Hugging Face Token (required for image pipeline)
- hf_token = os.getenv("HF_TOKEN", "your_token_here")  # Replace with your token or set as environment variable
-
- # Initialize translator (Tamil to English)
- translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
- translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
- translator_tokenizer.src_lang = "ta_IN"
-
- # Initialize text generator
- gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
- gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
-
- # Initialize Stable Diffusion image pipeline
- pipe = DiffusionPipeline.from_pretrained(
-     "stabilityai/stable-diffusion-2-1",
-     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     use_auth_token=hf_token
- ).to(device)
- pipe.safety_checker = None  # Optional: disable safety checks
-
- def translate_tamil_to_english(text, reference=None):
-     start = time.time()
-     inputs = translator_tokenizer(text, return_tensors="pt").to(device)
-     outputs = translator_model.generate(
-         **inputs,
-         forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
-     )
-     translation = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-     duration = round(time.time() - start, 2)
-
-     rouge_l = None
-     if reference:
-         scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
-         scores = scorer.score(reference.lower(), translation.lower())
-         rouge_l = round(scores['rougeL'].fmeasure, 4)
-
-     return translation, duration, rouge_l
-
- def generate_image(prompt):
-     try:
-         start = time.time()
-         out = pipe(prompt)
-         img = out.images[0].resize((256, 256))
-         tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
-         img.save(tmp.name)
-         return tmp.name, round(time.time() - start, 2)
-     except Exception as e:
-         return None, f"Image generation failed: {e}"
-
- def generate_creative_text(prompt, max_length=100):
-     start = time.time()
-     input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
-     out = gen_model.generate(
-         input_ids, max_length=max_length, do_sample=True, top_k=50, temperature=0.9
-     )
-     text = gen_tokenizer.decode(out[0], skip_special_tokens=True)
-     duration = round(time.time() - start, 2)
-     tokens = text.split()
-     repetition = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens)
-     return text, duration, len(tokens), round(repetition, 4)
+ import gradio as gr
+ from core import translate_tamil_to_english, generate_image, generate_creative_text
+
+ def full_pipeline(tamil_text, reference_text=None):
+     if not tamil_text.strip():
+         return "Please enter Tamil text.", None, None, None, None, None, None, None, None
+
+     # Step 1: Translate Tamil to English
+     en_text, t_time, rouge_score = translate_tamil_to_english(tamil_text, reference_text)
+
+     # Step 2: Generate Image from English translation
+     image_path, img_time = generate_image(en_text)
+
+     # Step 3: Generate creative text from English translation
+     gen_text, gen_time, tokens, rep = generate_creative_text(en_text)
+
+     return en_text, t_time, rouge_score, image_path, img_time, gen_text, gen_time, tokens, rep
+
+ # Gradio Interface
+ demo = gr.Interface(
+     fn=full_pipeline,
+     inputs=[
+         gr.Textbox(label="✍️ Enter Tamil Text", lines=5, placeholder="எனது கனவு வீட்டை வர்ணிக்க..."),
+         gr.Textbox(label="📘 (Optional) Reference English Translation", lines=2)
+     ],
+     outputs=[
+         gr.Textbox(label="📝 English Translation"),
+         gr.Number(label="⏱️ Translation Time (s)"),
+         gr.Number(label="📊 ROUGE-L Score"),
+         gr.Image(label="🎨 Generated Image (256x256)"),
+         gr.Number(label="🖼️ Image Generation Time (s)"),
+         gr.Textbox(label="💡 Creative English Text"),
+         gr.Number(label="🕒 Text Generation Time (s)"),
+         gr.Number(label="🔢 Number of Tokens"),
+         gr.Number(label="♻️ Repetition Rate")
+     ],
+     title="🌐 Tamil to English Translator + Image & Text Generator",
+     description="Translate Tamil to English using MBart50 → Generate AI Image using StabilityAI → Generate Creative Text using GPT-2",
+     theme="soft",
+     allow_flagging="never"
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
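
The rewritten app.py imports its three helpers from a core module, but core.py is not part of this commit. The listing below is a minimal sketch of what such a module could look like, reconstructed from the functions deleted above; it is an assumption, not the author's file. The small departures from the removed code are flagged in comments: the token= keyword (the old code used the deprecated use_auth_token), an explicit pad_token_id for GPT-2, and a guard against dividing by zero in the repetition rate.

# core.py -- hypothetical sketch, not part of this commit
import os
import time
import tempfile

import torch
from transformers import (
    MBartForConditionalGeneration, MBart50TokenizerFast,
    AutoTokenizer, AutoModelForCausalLM,
)
from diffusers import DiffusionPipeline
from rouge_score import rouge_scorer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Tamil -> English translator (MBart50), as in the removed app.py
translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
translator_tokenizer.src_lang = "ta_IN"

# Creative text generator (GPT-2)
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Stable Diffusion 2.1 pipeline; `token` replaces the deprecated `use_auth_token`
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=os.getenv("HF_TOKEN"),
).to(device)

def translate_tamil_to_english(text, reference=None):
    """Return (translation, seconds, ROUGE-L F1 or None) -- the shape app.py unpacks."""
    start = time.time()
    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
    outputs = translator_model.generate(
        **inputs, forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
    )
    translation = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    rouge_l = None
    if reference:
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        rouge_l = round(scorer.score(reference.lower(), translation.lower())["rougeL"].fmeasure, 4)
    return translation, round(time.time() - start, 2), rouge_l

def generate_image(prompt):
    """Return (path to a 256x256 PNG or None, seconds or an error message)."""
    try:
        start = time.time()
        img = pipe(prompt).images[0].resize((256, 256))
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        img.save(tmp.name)
        return tmp.name, round(time.time() - start, 2)
    except Exception as e:
        return None, f"Image generation failed: {e}"

def generate_creative_text(prompt, max_length=100):
    """Return (text, seconds, word count, adjacent-word repetition rate)."""
    start = time.time()
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    out = gen_model.generate(
        input_ids, max_length=max_length, do_sample=True, top_k=50,
        temperature=0.9, pad_token_id=gen_tokenizer.eos_token_id,  # silences the missing-pad warning
    )
    text = gen_tokenizer.decode(out[0], skip_special_tokens=True)
    words = text.split()
    repetition = sum(a == b for a, b in zip(words, words[1:])) / max(len(words), 1)  # avoid division by zero
    return text, round(time.time() - start, 2), len(words), round(repetition, 4)

With a core.py along these lines next to the new app.py, running python app.py builds the Gradio interface and demo.launch() serves it on the local URL it prints.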