24Sureshkumar committed on
Commit
ab52a13
·
verified ·
1 Parent(s): 3fd89d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -20
app.py CHANGED
@@ -1,38 +1,61 @@
 
1
  import streamlit as st
2
  import torch
3
- import torch.nn.functional as F
4
  import os
5
  import time
6
  import tempfile
7
  from PIL import Image
8
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
9
- from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPProcessor, CLIPModel
 
 
 
 
 
 
 
 
10
  from diffusers import StableDiffusionPipeline
11
  from rouge_score import rouge_scorer
12
 
13
- # --- Device Setup ---
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- # --- Load Models ---
17
- translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
18
- translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
 
 
 
 
19
  translator_tokenizer.src_lang = "ta_IN"
20
 
 
21
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
22
  gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
23
  gen_model.eval()
24
 
25
- pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-1-5").to(device)
 
 
 
 
 
26
  pipe.safety_checker = None
27
 
 
28
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
29
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
30
 
31
- # --- Functions ---
 
32
  def translate_tamil_to_english(text, reference=None):
33
  start = time.time()
34
  inputs = translator_tokenizer(text, return_tensors="pt").to(device)
35
- outputs = translator_model.generate(**inputs, forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"])
 
 
 
36
  translated = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
37
  duration = round(time.time() - start, 2)
38
 
@@ -52,7 +75,7 @@ def generate_creative_text(prompt, max_length=100):
52
  duration = round(time.time() - start, 2)
53
 
54
  tokens = text.split()
55
- repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
56
 
57
  with torch.no_grad():
58
  input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
@@ -79,15 +102,15 @@ def evaluate_clip_similarity(text, image):
79
  with torch.no_grad():
80
  outputs = clip_model(**inputs)
81
  logits_per_image = outputs.logits_per_image
82
- probs = F.softmax(logits_per_image, dim=1)
83
  similarity_score = logits_per_image[0][0].item()
84
  return round(similarity_score, 4)
85
 
86
- # --- Streamlit UI ---
 
87
  st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
88
  st.title("🧠 Tamil → English + 🎨 Creative Text + AI Image")
89
 
90
- tamil_input = st.text_area("✍️ Enter Tamil text", height=150)
91
  reference_input = st.text_input("📘 Optional: Reference English translation for ROUGE")
92
 
93
  if st.button("🚀 Generate Output"):
@@ -97,16 +120,18 @@ if st.button("🚀 Generate Output"):
97
  with st.spinner("🔄 Translating Tamil to English..."):
98
  english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)
99
 
100
- st.success(f"✅ Translated in {t_time}s")
101
  st.markdown(f"**📝 English Translation:** `{english_text}`")
102
  if rouge_l is not None:
103
- st.markdown(f"📊 ROUGE-L Score: `{rouge_l}`")
 
 
104
 
105
- with st.spinner("🖼️ Generating image from text..."):
106
  image_path, img_time, image_obj = generate_image(english_text)
107
 
108
  if isinstance(image_obj, Image.Image):
109
- st.success(f"🖼️ Image generated in {img_time}s")
110
  st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
111
 
112
  with st.spinner("🔎 Evaluating CLIP similarity..."):
@@ -118,9 +143,9 @@ if st.button("🚀 Generate Output"):
118
  with st.spinner("💡 Generating creative text..."):
119
  creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)
120
 
121
- st.success(f"✨ Creative text in {c_time}s")
122
  st.markdown(f"**🧠 Creative Output:** `{creative}`")
123
  st.markdown(f"📌 Tokens: `{tokens}`, 🔁 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")
124
 
125
  st.markdown("---")
126
- st.caption("Built by Sureshkumar R | MBart + GPT-2 + Stable Diffusion + CLIP")
 
1
+ %%writefile app.py
2
  import streamlit as st
3
  import torch
 
4
  import os
5
  import time
6
  import tempfile
7
  from PIL import Image
8
+ import torch.nn.functional as F
9
+
10
+ from transformers import (
11
+ MBartForConditionalGeneration,
12
+ MBart50TokenizerFast,
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ CLIPProcessor,
16
+ CLIPModel,
17
+ )
18
  from diffusers import StableDiffusionPipeline
19
  from rouge_score import rouge_scorer
20
 
21
+ # Set device
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
24
+ # Load MBart tokenizer and model
25
+ translator_model = MBartForConditionalGeneration.from_pretrained(
26
+ "facebook/mbart-large-50-many-to-many-mmt"
27
+ ).to(device)
28
+ translator_tokenizer = MBart50TokenizerFast.from_pretrained(
29
+ "facebook/mbart-large-50-many-to-many-mmt"
30
+ )
31
  translator_tokenizer.src_lang = "ta_IN"
32
 
33
+ # Load GPT-2
34
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
35
  gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
36
  gen_model.eval()
37
 
38
+ # Load Stable Diffusion
39
+ pipe = StableDiffusionPipeline.from_pretrained(
40
+ "stabilityai/stable-diffusion-2-1",
41
+ token=os.getenv("HF_TOKEN"),
42
+ torch_dtype=torch.float32,
43
+ ).to(device)
44
  pipe.safety_checker = None
45
 
46
+ # Load CLIP
47
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
48
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
49
 
50
+ # ---------------- Functions ---------------- #
51
+
52
  def translate_tamil_to_english(text, reference=None):
53
  start = time.time()
54
  inputs = translator_tokenizer(text, return_tensors="pt").to(device)
55
+ outputs = translator_model.generate(
56
+ **inputs,
57
+ forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
58
+ )
59
  translated = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
60
  duration = round(time.time() - start, 2)
61
 
 
75
  duration = round(time.time() - start, 2)
76
 
77
  tokens = text.split()
78
+ repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens)
79
 
80
  with torch.no_grad():
81
  input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
 
102
  with torch.no_grad():
103
  outputs = clip_model(**inputs)
104
  logits_per_image = outputs.logits_per_image
 
105
  similarity_score = logits_per_image[0][0].item()
106
  return round(similarity_score, 4)
107
 
108
+ # ---------------- Streamlit UI ---------------- #
109
+
110
  st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
111
  st.title("🧠 Tamil → English + 🎨 Creative Text + AI Image")
112
 
113
+ tamil_input = st.text_area("✍️ Enter Tamil text here", height=150)
114
  reference_input = st.text_input("📘 Optional: Reference English translation for ROUGE")
115
 
116
  if st.button("🚀 Generate Output"):
 
120
  with st.spinner("🔄 Translating Tamil to English..."):
121
  english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)
122
 
123
+ st.success(f"✅ Translated in {t_time} seconds")
124
  st.markdown(f"**📝 English Translation:** `{english_text}`")
125
  if rouge_l is not None:
126
+ st.markdown(f"📊 **ROUGE-L Score:** `{rouge_l}`")
127
+ else:
128
+ st.info("ℹ️ ROUGE-L not calculated. Reference not provided.")
129
 
130
+ with st.spinner("🎨 Generating image..."):
131
  image_path, img_time, image_obj = generate_image(english_text)
132
 
133
  if isinstance(image_obj, Image.Image):
134
+ st.success(f"🖼️ Image generated in {img_time} seconds")
135
  st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
136
 
137
  with st.spinner("🔎 Evaluating CLIP similarity..."):
 
143
  with st.spinner("💡 Generating creative text..."):
144
  creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)
145
 
146
+ st.success(f"✨ Creative text generated in {c_time} seconds")
147
  st.markdown(f"**🧠 Creative Output:** `{creative}`")
148
  st.markdown(f"📌 Tokens: `{tokens}`, 🔁 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")
149
 
150
  st.markdown("---")
151
+ st.caption("Built by Sureshkumar R using MBart, GPT-2, Stable Diffusion 2.1, and CLIP on Hugging Face 🤗")