24Sureshkumar committed on
Commit
f837ee9
·
verified ·
1 Parent(s): ee418a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -47
app.py CHANGED
@@ -1,35 +1,50 @@
1
  import streamlit as st
2
  import torch
3
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from diffusers import StableDiffusionPipeline
6
  from rouge_score import rouge_scorer
7
  from PIL import Image
8
- import clip
9
  import tempfile
10
  import os
11
- import math
12
  import time
 
 
 
13
 
14
- # Device setup
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # Translation model
18
- translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
19
- translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
 
 
 
 
20
  translator_tokenizer.src_lang = "ta_IN"
21
 
22
- # GPT-2 for creative text
23
- gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
 
24
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
25
 
26
- # Stable Diffusion v1.4
27
- pipe = StableDiffusionPipeline.from_pretrained(
28
- "stabilityai/stable-diffusion-1-4",
29
- torch_dtype=torch.float32,
30
- use_auth_token=os.getenv("HF_TOKEN") # set this on Hugging Face Spaces
31
- ).to(device)
32
- pipe.safety_checker = None # Optional
 
 
 
 
 
 
 
 
 
 
33
 
34
  # Load CLIP for image-text similarity
35
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
@@ -53,88 +68,89 @@ def translate_tamil_to_english(text, reference=None):
53
 
54
  return translated, duration, rouge_l
55
 
56
- # Text generation with repetition & perplexity
57
  def generate_creative_text(prompt, max_length=100):
58
  start = time.time()
59
  input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
60
  output = gen_model.generate(
61
- input_ids, max_length=max_length,
62
- do_sample=True, top_k=50, temperature=0.9
 
 
 
63
  )
64
  text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
65
  duration = round(time.time() - start, 2)
66
 
67
  tokens = text.split()
68
- repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens)
69
 
70
- # Perplexity
71
  with torch.no_grad():
 
72
  outputs = gen_model(input_ids, labels=input_ids)
73
  loss = outputs.loss
74
- perplexity = math.exp(loss.item())
75
 
76
- return text, duration, len(tokens), round(repetition_rate, 4), round(perplexity, 3)
77
 
78
- # Image generation + CLIP similarity
79
  def generate_image(prompt):
80
  try:
81
  start = time.time()
82
  result = pipe(prompt)
83
  image = result.images[0].resize((256, 256))
84
- duration = round(time.time() - start, 2)
85
-
86
- # Save image
87
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
88
  image.save(tmp_file.name)
 
89
 
90
- # CLIP similarity
91
- image_input = clip_preprocess(image).unsqueeze(0).to(device)
92
- text_input = clip.tokenize(prompt).to(device)
93
  with torch.no_grad():
94
  image_features = clip_model.encode_image(image_input)
95
  text_features = clip_model.encode_text(text_input)
96
- similarity = torch.cosine_similarity(image_features, text_features).item()
97
 
98
  return tmp_file.name, duration, round(similarity, 4)
99
  except Exception as e:
100
- return None, 0, f"Image generation failed: {str(e)}"
101
 
102
  # Streamlit UI
103
  st.set_page_config(page_title="Tamil β†’ English + AI Art", layout="centered")
104
- st.title("🧠 Tamil β†’ English + 🎨 Creative Text + AI Image")
105
 
106
- tamil_input = st.text_area("✍️ Enter Tamil text here", height=150)
107
  reference_input = st.text_input("πŸ“˜ Optional: Reference English translation for ROUGE")
108
 
109
  if st.button("πŸš€ Generate Output"):
110
  if not tamil_input.strip():
111
  st.warning("Please enter Tamil text.")
112
  else:
113
- with st.spinner("πŸ”„ Translating Tamil to English..."):
114
  english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)
115
 
116
- st.success(f"βœ… Translated in {t_time} seconds")
117
  st.markdown(f"**πŸ“ English Translation:** `{english_text}`")
118
  if rouge_l is not None:
119
- st.markdown(f"πŸ“Š **ROUGE-L Score:** `{rouge_l}`")
120
 
121
  with st.spinner("πŸ–ΌοΈ Generating image..."):
122
- image_path, img_time, similarity = generate_image(english_text)
123
 
124
- if isinstance(similarity, float):
125
- st.success(f"πŸ–ΌοΈ Image generated in {img_time} seconds")
126
  st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
127
- st.markdown(f"🎯 **CLIP Text-Image Similarity:** `{similarity}`")
128
  else:
129
- st.error(similarity)
130
 
131
  with st.spinner("πŸ’‘ Generating creative text..."):
132
- creative, c_time, tokens, rep_rate, perplexity = generate_creative_text(english_text)
133
 
134
- st.success(f"✨ Creative text generated in {c_time} seconds")
135
  st.markdown(f"**🧠 Creative Output:** `{creative}`")
136
- st.markdown(f"πŸ“Œ Tokens: `{tokens}`, πŸ” Repetition Rate: `{rep_rate}`")
137
- st.markdown(f"πŸ“‰ Perplexity: `{perplexity}`")
138
 
139
  st.markdown("---")
140
- st.caption("Built by Sureshkumar R using MBart, GPT-2 & Stable Diffusion on Hugging Face")
 
1
  import streamlit as st
2
  import torch
3
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
5
  from diffusers import StableDiffusionPipeline
6
  from rouge_score import rouge_scorer
7
  from PIL import Image
 
8
  import tempfile
9
  import os
 
10
  import time
11
+ import torch.nn.functional as F
12
+ import clip # from OpenAI CLIP repo
13
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
14
 
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
+ # Load MBart model
18
+ translator_model = MBartForConditionalGeneration.from_pretrained(
19
+ "facebook/mbart-large-50-many-to-many-mmt"
20
+ ).to(device)
21
+ translator_tokenizer = MBart50TokenizerFast.from_pretrained(
22
+ "facebook/mbart-large-50-many-to-many-mmt"
23
+ )
24
  translator_tokenizer.src_lang = "ta_IN"
25
 
26
+ # Load GPT-2
27
+ gen_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
28
+ gen_model.eval()
29
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
30
 
31
+ # Try loading SD-2.1, fallback to lightweight
32
+ try:
33
+ pipe = StableDiffusionPipeline.from_pretrained(
34
+ "stabilityai/stable-diffusion-2-1",
35
+ torch_dtype=torch.float32,
36
+ use_auth_token=os.getenv("HF_TOKEN")
37
+ ).to(device)
38
+ pipe.safety_checker = None
39
+ model_loaded = "stabilityai/stable-diffusion-2-1"
40
+ except Exception as e:
41
+ st.warning("⚠️ SD-2.1 failed. Using lightweight fallback model.")
42
+ pipe = StableDiffusionPipeline.from_pretrained(
43
+ "OFA-Sys/small-stable-diffusion-v0",
44
+ torch_dtype=torch.float32
45
+ ).to(device)
46
+ pipe.safety_checker = None
47
+ model_loaded = "OFA-Sys/small-stable-diffusion-v0"
48
 
49
  # Load CLIP for image-text similarity
50
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
 
68
 
69
  return translated, duration, rouge_l
70
 
71
+ # Creative text generator with evaluation
72
  def generate_creative_text(prompt, max_length=100):
73
  start = time.time()
74
  input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
75
  output = gen_model.generate(
76
+ input_ids,
77
+ max_length=max_length,
78
+ do_sample=True,
79
+ top_k=50,
80
+ temperature=0.9
81
  )
82
  text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
83
  duration = round(time.time() - start, 2)
84
 
85
  tokens = text.split()
86
+ rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
87
 
88
+ # Calculate perplexity
89
  with torch.no_grad():
90
+ input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
91
  outputs = gen_model(input_ids, labels=input_ids)
92
  loss = outputs.loss
93
+ perplexity = torch.exp(loss).item()
94
 
95
+ return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)
96
 
97
+ # Generate image and CLIP similarity
98
  def generate_image(prompt):
99
  try:
100
  start = time.time()
101
  result = pipe(prompt)
102
  image = result.images[0].resize((256, 256))
 
 
 
103
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
104
  image.save(tmp_file.name)
105
+ duration = round(time.time() - start, 2)
106
 
107
+ # Compute CLIP similarity
108
+ image_input = clip_preprocess(Image.open(tmp_file.name)).unsqueeze(0).to(device)
109
+ text_input = clip.tokenize([prompt]).to(device)
110
  with torch.no_grad():
111
  image_features = clip_model.encode_image(image_input)
112
  text_features = clip_model.encode_text(text_input)
113
+ similarity = F.cosine_similarity(image_features, text_features).item()
114
 
115
  return tmp_file.name, duration, round(similarity, 4)
116
  except Exception as e:
117
+ return None, None, f"Image generation failed: {str(e)}"
118
 
119
  # Streamlit UI
120
  st.set_page_config(page_title="Tamil β†’ English + AI Art", layout="centered")
121
+ st.title("🧠 Tamil β†’ English + 🎨 Creative Text + πŸ–ΌοΈ AI Image")
122
 
123
+ tamil_input = st.text_area("✍️ Enter Tamil text", height=150)
124
  reference_input = st.text_input("πŸ“˜ Optional: Reference English translation for ROUGE")
125
 
126
  if st.button("πŸš€ Generate Output"):
127
  if not tamil_input.strip():
128
  st.warning("Please enter Tamil text.")
129
  else:
130
+ with st.spinner("πŸ”„ Translating..."):
131
  english_text, t_time, rouge_l = translate_tamil_to_english(tamil_input, reference_input)
132
 
133
+ st.success(f"βœ… Translated in {t_time}s")
134
  st.markdown(f"**πŸ“ English Translation:** `{english_text}`")
135
  if rouge_l is not None:
136
+ st.markdown(f"πŸ“Š ROUGE-L Score: `{rouge_l}`")
137
 
138
  with st.spinner("πŸ–ΌοΈ Generating image..."):
139
+ image_path, img_time, clip_score = generate_image(english_text)
140
 
141
+ if image_path:
142
+ st.success(f"πŸ–ΌοΈ Image generated in {img_time}s using `{model_loaded}`")
143
  st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
144
+ st.markdown(f"πŸ” **CLIP Text-Image Similarity:** `{clip_score}`")
145
  else:
146
+ st.error(clip_score)
147
 
148
  with st.spinner("πŸ’‘ Generating creative text..."):
149
+ creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)
150
 
151
+ st.success(f"✨ Creative text in {c_time}s")
152
  st.markdown(f"**🧠 Creative Output:** `{creative}`")
153
+ st.markdown(f"πŸ“Œ Tokens: `{tokens}`, πŸ” Repetition Rate: `{rep_rate}`, πŸ“‰ Perplexity: `{ppl}`")
 
154
 
155
  st.markdown("---")
156
+ st.caption("Built by Sureshkumar R | MBart + GPT-2 + Stable Diffusion + CLIP")