24Sureshkumar committed on
Commit 9c3ea11 · verified · 1 Parent(s): 831a7b4

Update app.py

Files changed (1)
  1. app.py +20 -38
app.py CHANGED
@@ -1,52 +1,38 @@
 import streamlit as st
 import torch
+import torch.nn.functional as F
+import os
+import time
+import tempfile
+from PIL import Image
 from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPProcessor, CLIPModel
 from diffusers import StableDiffusionPipeline
 from rouge_score import rouge_scorer
-from PIL import Image
-import tempfile
-import os
-import time
-from transformers import CLIPProcessor, CLIPModel
-import torch.nn.functional as F
 
-# Set device
+# --- Device Setup ---
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load translation model
-translator_model = MBartForConditionalGeneration.from_pretrained(
-    "facebook/mbart-large-50-many-to-many-mmt"
-).to(device)
-translator_tokenizer = MBart50TokenizerFast.from_pretrained(
-    "facebook/mbart-large-50-many-to-many-mmt"
-)
+# --- Load Models ---
+translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
+translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
 translator_tokenizer.src_lang = "ta_IN"
 
-# Load GPT-2 for creative text
 gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
 gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
 gen_model.eval()
 
-# Load Stable Diffusion 1.5
-pipe = StableDiffusionPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-1-5",
-    torch_dtype=torch.float32,
-).to(device)
-pipe.safety_checker = None  # Optional: disable safety filter
+pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-1-5").to(device)
+pipe.safety_checker = None
 
-# Load CLIP model
 clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
 clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
-# --- Translation ---
+# --- Functions ---
 def translate_tamil_to_english(text, reference=None):
     start = time.time()
     inputs = translator_tokenizer(text, return_tensors="pt").to(device)
-    outputs = translator_model.generate(
-        **inputs,
-        forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
-    )
+    outputs = translator_model.generate(**inputs, forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"])
     translated = translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
     duration = round(time.time() - start, 2)
 
@@ -58,7 +44,6 @@ def translate_tamil_to_english(text, reference=None):
 
     return translated, duration, rouge_l
 
-# --- GPT-2 Creative Generation ---
 def generate_creative_text(prompt, max_length=100):
     start = time.time()
     input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
@@ -67,9 +52,8 @@ def generate_creative_text(prompt, max_length=100):
     duration = round(time.time() - start, 2)
 
     tokens = text.split()
-    repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens)
+    repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
 
-    # Perplexity
     with torch.no_grad():
         input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
         outputs = gen_model(input_ids, labels=input_ids)
@@ -78,7 +62,6 @@ def generate_creative_text(prompt, max_length=100):
 
     return text, duration, len(tokens), round(repetition_rate, 4), round(perplexity, 4)
 
-# --- Stable Diffusion Image Generation ---
 def generate_image(prompt):
     try:
         start = time.time()
@@ -91,7 +74,6 @@ def generate_image(prompt):
     except Exception as e:
         return None, 0, f"Image generation failed: {str(e)}"
 
-# --- CLIP Similarity ---
 def evaluate_clip_similarity(text, image):
     inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
     with torch.no_grad():
@@ -112,7 +94,7 @@ if st.button("🚀 Generate Output"):
     if not tamil_input.strip():
         st.warning("Please enter Tamil text.")
     else:
-        with st.spinner("🔄 Translating..."):
+        with st.spinner("🔄 Translating Tamil to English..."):
            english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)
 
        st.success(f"✅ Translated in {t_time}s")
@@ -120,7 +102,7 @@ if st.button("🚀 Generate Output"):
        if rouge_l is not None:
            st.markdown(f"📊 ROUGE-L Score: `{rouge_l}`")
 
-        with st.spinner("🎨 Generating image..."):
+        with st.spinner("🖼️ Generating image from text..."):
            image_path, img_time, image_obj = generate_image(english_text)
 
        if isinstance(image_obj, Image.Image):
@@ -129,16 +111,16 @@ if st.button("🚀 Generate Output"):
 
            with st.spinner("🔎 Evaluating CLIP similarity..."):
                clip_score = evaluate_clip_similarity(english_text, image_obj)
-                st.markdown(f"🔍 CLIP Text-Image Similarity: `{clip_score}`")
+                st.markdown(f"🔍 **CLIP Text-Image Similarity:** `{clip_score}`")
        else:
            st.error(image_obj)
 
        with st.spinner("💡 Generating creative text..."):
            creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)
 
-        st.success(f"✨ Creative text generated in {c_time}s")
+        st.success(f"✨ Creative text in {c_time}s")
        st.markdown(f"**🧠 Creative Output:** `{creative}`")
        st.markdown(f"📌 Tokens: `{tokens}`, 🔁 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")
 
st.markdown("---")
-st.caption("Built by Sureshkumar R using MBart, GPT-2, Stable Diffusion 1.5, and CLIP (Open Source)")
+st.caption("Built by Sureshkumar R | MBart + GPT-2 + Stable Diffusion + CLIP")