24Sureshkumar committed (verified)
Commit f20a187 · Parent(s): 3d2a9f9

Update app.py

Files changed (1):
  app.py +32 -40
app.py CHANGED
@@ -1,20 +1,22 @@
 import streamlit as st
 import torch
-from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
-from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
-from diffusers import StableDiffusionPipeline
-from rouge_score import rouge_scorer
-from PIL import Image
-import tempfile
+import openai
+import requests  # needed below to download the generated image from its URL
 import os
 import time
-import torch.nn.functional as F
+from PIL import Image
+import tempfile
 import clip  # from OpenAI CLIP repo
+import torch.nn.functional as F
+from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
+from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
+from rouge_score import rouge_scorer
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
+openai.api_key = os.getenv("OPENAI_API_KEY")  # Set this from env
 
-# Load MBart model
+# Load MBart
 translator_model = MBartForConditionalGeneration.from_pretrained(
     "facebook/mbart-large-50-many-to-many-mmt"
 ).to(device)
@@ -23,33 +25,15 @@ translator_tokenizer = MBart50TokenizerFast.from_pretrained(
 )
 translator_tokenizer.src_lang = "ta_IN"
 
-# Load GPT-2
+# GPT-2
 gen_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
 gen_model.eval()
 gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
-# Try loading SD-2.1, fallback to lightweight
-try:
-    pipe = StableDiffusionPipeline.from_pretrained(
-        "stabilityai/stable-diffusion-2-1",
-        torch_dtype=torch.float32,
-        use_auth_token=os.getenv("HF_TOKEN")
-    ).to(device)
-    pipe.safety_checker = None
-    model_loaded = "stabilityai/stable-diffusion-2-1"
-except Exception as e:
-    st.warning("⚠️ SD-2.1 failed. Using lightweight fallback model.")
-    pipe = StableDiffusionPipeline.from_pretrained(
-        "OFA-Sys/small-stable-diffusion-v0",
-        torch_dtype=torch.float32
-    ).to(device)
-    pipe.safety_checker = None
-    model_loaded = "OFA-Sys/small-stable-diffusion-v0"
-
-# Load CLIP for image-text similarity
+# CLIP
 clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
 
-# Translation function
+# ---- Translation ----
 def translate_tamil_to_english(text, reference=None):
     start = time.time()
     inputs = translator_tokenizer(text, return_tensors="pt").to(device)
@@ -68,7 +52,7 @@ def translate_tamil_to_english(text, reference=None):
 
     return translated, duration, rouge_l
 
-# Creative text generator with evaluation
+# ---- Creative Text ----
 def generate_creative_text(prompt, max_length=100):
     start = time.time()
     input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
@@ -85,7 +69,6 @@ def generate_creative_text(prompt, max_length=100):
     tokens = text.split()
     rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
 
-    # Calculate perplexity
     with torch.no_grad():
         input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
         outputs = gen_model(input_ids, labels=input_ids)
@@ -94,18 +77,27 @@ def generate_creative_text(prompt, max_length=100):
 
     return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)
 
-# Generate image and CLIP similarity
+# ---- Image Generation using DALL·E 3 ----
 def generate_image(prompt):
     try:
         start = time.time()
-        result = pipe(prompt)
-        image = result.images[0].resize((256, 256))
+        response = openai.images.generate(
+            model="dall-e-3",
+            prompt=prompt,
+            size="1024x1024",  # DALL·E 3 accepts 1024x1024, 1024x1792 or 1792x1024; 512x512 is rejected
+            quality="standard",
+            n=1
+        )
+        image_url = response.data[0].url
+        image_data = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))
+
+        # Save locally
         tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
-        image.save(tmp_file.name)
+        image_data.save(tmp_file.name)
         duration = round(time.time() - start, 2)
 
-        # Compute CLIP similarity
-        image_input = clip_preprocess(Image.open(tmp_file.name)).unsqueeze(0).to(device)
+        # CLIP similarity
+        image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
         text_input = clip.tokenize([prompt]).to(device)
         with torch.no_grad():
             image_features = clip_model.encode_image(image_input)
@@ -116,7 +108,7 @@ def generate_image(prompt):
     except Exception as e:
         return None, None, f"Image generation failed: {str(e)}"
 
-# Streamlit UI
+# ---- UI ----
 st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
 st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")
 
@@ -139,7 +131,7 @@ if st.button("🚀 Generate Output"):
     image_path, img_time, clip_score = generate_image(english_text)
 
     if image_path:
-        st.success(f"🖼️ Image generated in {img_time}s using `{model_loaded}`")
+        st.success(f"🖼️ Image generated in {img_time}s using OpenAI DALL·E 3")
         st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
         st.markdown(f"🔍 **CLIP Text-Image Similarity:** `{clip_score}`")
     else:
@@ -153,4 +145,4 @@ if st.button("🚀 Generate Output"):
     st.markdown(f"📌 Tokens: `{tokens}`, 🔁 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")
 
 st.markdown("---")
-st.caption("Built by Sureshkumar R | MBart + GPT-2 + Stable Diffusion")
+st.caption("Built by Sureshkumar R | MBart + GPT-2 + OpenAI DALL·E 3")
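One operational note on the new generate_image: openai.images.generate returns a short-lived URL, so the image has to be downloaded right away (hence the requests import above). If the extra HTTP round trip is unwanted, the Images API can return the pixels inline via response_format="b64_json"; a minimal sketch using the same module-level openai client the app uses (generate_image_b64 is an illustrative name, not part of the commit):

import base64
import io

import openai
from PIL import Image

def generate_image_b64(prompt: str) -> Image.Image:
    # Ask for base64-encoded image data instead of a URL, so no second
    # request (and no expiring link) is involved.
    response = openai.images.generate(
        model="dall-e-3",
        prompt=prompt,
        size="1024x1024",
        response_format="b64_json",
        n=1,
    )
    raw = base64.b64decode(response.data[0].b64_json)
    return Image.open(io.BytesIO(raw)).resize((256, 256))

Either way the Space needs OPENAI_API_KEY set in its environment; when the key is missing or the request is rejected, generate_image returns the error string and the UI falls through to its else branch.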
 
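The image hunk ends at clip_model.encode_image(image_input); the similarity score itself comes from unchanged context lines the diff does not show. Given that the imports keep torch.nn.functional as F, the hidden step is almost certainly a cosine similarity between the two feature vectors; a minimal sketch (function name and rounding are illustrative, not from the commit):

import torch
import torch.nn.functional as F

def clip_cosine_similarity(image_features: torch.Tensor,
                           text_features: torch.Tensor) -> float:
    # CLIP is trained so that matching image/text pairs score high under
    # cosine similarity; both tensors have shape (1, 512) for ViT-B/32.
    return round(F.cosine_similarity(image_features, text_features).item(), 4)

For ViT-B/32, a well-matched prompt/image pair typically lands somewhere around 0.25-0.35, so the reported number is best read as a relative signal rather than a percentage.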