24Sureshkumar committed on
Commit 831a7b4 · verified · 1 Parent(s): e1224aa

Update app.py

Files changed (1):
  1. app.py +56 -62
app.py CHANGED
@@ -1,43 +1,45 @@
 import streamlit as st
 import torch
-import openai
-import os
-import time
-import requests
+from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from diffusers import StableDiffusionPipeline
+from rouge_score import rouge_scorer
 from PIL import Image
 import tempfile
-import clip
+import os
+import time
+from transformers import CLIPProcessor, CLIPModel
 import torch.nn.functional as F
-from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
-from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
-from rouge_score import rouge_scorer

 # Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-# OpenAI Key
-openai.api_key = os.getenv("OPENAI_API_KEY")
-
-# ---- Load MBart (Translation) ----
+# Load translation model
 translator_model = MBartForConditionalGeneration.from_pretrained(
     "facebook/mbart-large-50-many-to-many-mmt"
-)
+).to(device)
 translator_tokenizer = MBart50TokenizerFast.from_pretrained(
     "facebook/mbart-large-50-many-to-many-mmt"
 )
-translator_model.to(device)
 translator_tokenizer.src_lang = "ta_IN"

-# ---- GPT-2 ----
-gen_model = GPT2LMHeadModel.from_pretrained("gpt2")
+# Load GPT-2 for creative text
 gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-gen_model.to(device)
+gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
 gen_model.eval()

-# ---- CLIP ----
-clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
+# Load Stable Diffusion 1.5
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float32,
+).to(device)
+pipe.safety_checker = None  # Optional: disable safety filter
+
+# Load CLIP model
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

-# ---- Translation ----
+# --- Translation ---
 def translate_tamil_to_english(text, reference=None):
     start = time.time()
     inputs = translator_tokenizer(text, return_tensors="pt").to(device)
@@ -56,63 +58,52 @@ def translate_tamil_to_english(text, reference=None):

     return translated, duration, rouge_l

-# ---- Creative Text ----
+# --- GPT-2 Creative Generation ---
 def generate_creative_text(prompt, max_length=100):
     start = time.time()
     input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
-    output = gen_model.generate(
-        input_ids,
-        max_length=max_length,
-        do_sample=True,
-        top_k=50,
-        temperature=0.9
-    )
+    output = gen_model.generate(input_ids, max_length=max_length, do_sample=True, top_k=50, temperature=0.9)
     text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
     duration = round(time.time() - start, 2)

     tokens = text.split()
-    rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
+    repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0

+    # Perplexity
     with torch.no_grad():
         input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
         outputs = gen_model(input_ids, labels=input_ids)
         loss = outputs.loss
         perplexity = torch.exp(loss).item()

-    return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)
+    return text, duration, len(tokens), round(repetition_rate, 4), round(perplexity, 4)

-# ---- Image Generation ----
+# --- Stable Diffusion Image Generation ---
 def generate_image(prompt):
     try:
         start = time.time()
-        response = openai.images.generate(
-            model="dall-e-3",
-            prompt=prompt,
-            size="512x512",
-            quality="standard",
-            n=1
-        )
-        image_url = response.data[0].url
-        image_data = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))
-
+        result = pipe(prompt)
+        image = result.images[0].resize((256, 256))
         tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
-        image_data.save(tmp_file.name)
+        image.save(tmp_file.name)
         duration = round(time.time() - start, 2)
-
-        image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
-        text_input = clip.tokenize([prompt]).to(device)
-        with torch.no_grad():
-            image_features = clip_model.encode_image(image_input)
-            text_features = clip_model.encode_text(text_input)
-            similarity = F.cosine_similarity(image_features, text_features).item()
-
-        return tmp_file.name, duration, round(similarity, 4)
+        return tmp_file.name, duration, image
     except Exception as e:
-        return None, None, f"Image generation failed: {str(e)}"
+        return None, 0, f"Image generation failed: {str(e)}"
+
+# --- CLIP Similarity ---
+def evaluate_clip_similarity(text, image):
+    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
+    with torch.no_grad():
+        outputs = clip_model(**inputs)
+    logits_per_image = outputs.logits_per_image
+    probs = F.softmax(logits_per_image, dim=1)
+    similarity_score = logits_per_image[0][0].item()
+    return round(similarity_score, 4)

-# ---- UI ----
+# --- Streamlit UI ---
 st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
-st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")
+st.title("🧠 Tamil → English + 🎨 Creative Text + AI Image")

 tamil_input = st.text_area("✍️ Enter Tamil text", height=150)
 reference_input = st.text_input("📘 Optional: Reference English translation for ROUGE")
@@ -129,22 +120,25 @@ if st.button("🚀 Generate Output"):
     if rouge_l is not None:
         st.markdown(f"📊 ROUGE-L Score: `{rouge_l}`")

-    with st.spinner("🖼️ Generating image..."):
-        image_path, img_time, clip_score = generate_image(english_text)
+    with st.spinner("🎨 Generating image..."):
+        image_path, img_time, image_obj = generate_image(english_text)

-    if image_path:
-        st.success(f"🖼️ Image generated in {img_time}s using OpenAI DALL·E 3")
+    if isinstance(image_obj, Image.Image):
+        st.success(f"🖼️ Image generated in {img_time}s")
         st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
-        st.markdown(f"🔍 **CLIP Text-Image Similarity:** `{clip_score}`")
+
+        with st.spinner("🔎 Evaluating CLIP similarity..."):
+            clip_score = evaluate_clip_similarity(english_text, image_obj)
+        st.markdown(f"🔍 CLIP Text-Image Similarity: `{clip_score}`")
     else:
-        st.error(clip_score)
+        st.error(image_obj)

     with st.spinner("💡 Generating creative text..."):
         creative, c_time, tokens, rep_rate, ppl = generate_creative_text(english_text)

-    st.success(f"✨ Creative text in {c_time}s")
+    st.success(f"✨ Creative text generated in {c_time}s")
     st.markdown(f"**🧠 Creative Output:** `{creative}`")
     st.markdown(f"📌 Tokens: `{tokens}`, 🔍 Repetition Rate: `{rep_rate}`, 📉 Perplexity: `{ppl}`")

 st.markdown("---")
-st.caption("Built by Sureshkumar R | MBart + GPT-2 + OpenAI DALL·E 3")
+st.caption("Built by Sureshkumar R using MBart, GPT-2, Stable Diffusion 1.5, and CLIP (Open Source)")
 
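A note on the new evaluate_clip_similarity: outputs.logits_per_image is the image-text cosine similarity scaled by CLIP's learned logit_scale (roughly 100 for this checkpoint), so the reported score is unbounded rather than in [-1, 1], and the softmax in the probs line runs over a single candidate, so it is always 1.0 and goes unused. A minimal sketch of a normalized alternative, assuming the same openai/clip-vit-base-patch32 checkpoint (the helper name clip_cosine_similarity is mine, not the commit's):

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_cosine_similarity(text: str, image: Image.Image) -> float:
    """Cosine similarity in [-1, 1] between a caption and an image."""
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_emb = clip_model.get_text_features(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        image_emb = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
    # F.cosine_similarity normalizes both embeddings; matching pairs typically
    # land around 0.25-0.35 for this checkpoint.
    return round(F.cosine_similarity(text_emb, image_emb).item(), 4)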
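On the Stable Diffusion load: float32 is the safe default on CPU, but on a GPU Space the same pipeline can run in half precision to roughly halve VRAM use. A hedged variant of the load above, assuming the historical runwayml/stable-diffusion-v1-5 repo id (Spaces created after its takedown may need the stable-diffusion-v1-5/stable-diffusion-v1-5 mirror instead):

import torch
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    # fp16 is only reliable on CUDA; keep float32 on CPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
pipe.enable_attention_slicing()  # lowers peak memory at a small speed cost
image = pipe("a watercolor of a Chennai street at dusk").images[0]  # hypothetical prompt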
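The repetition-rate metric in generate_creative_text counts adjacent duplicate tokens but divides by the total token count rather than the number of adjacent pairs, so its ceiling is (n-1)/n rather than 1. A worked example with made-up tokens:

tokens = "the the cat sat sat sat".split()  # hypothetical output, 6 tokens
repeats = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:]))  # 3 repeated adjacent pairs
repetition_rate = repeats / len(tokens) if len(tokens) > 1 else 0
print(round(repetition_rate, 4))  # 0.5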
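Likewise, the perplexity line scores the sampled text under the same GPT-2 that produced it: with labels=input_ids, transformers shifts the labels internally, so outputs.loss is the mean next-token cross-entropy and exp(loss) is the text's perplexity. A self-contained sketch (the example sentence is made up):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

ids = tok.encode("A sample sentence to score.", return_tensors="pt")
with torch.no_grad():
    loss = model(ids, labels=ids).loss  # mean negative log-likelihood per predicted token
print(torch.exp(loss).item())           # perplexity; lower means more fluent under GPT-2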