Update app.py
Browse files
app.py
CHANGED
@@ -1,20 +1,21 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
-
|
4 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
|
5 |
-
from diffusers import StableDiffusionPipeline
|
6 |
-
from rouge_score import rouge_scorer
|
7 |
-
from PIL import Image
|
8 |
-
import tempfile
|
9 |
import os
|
10 |
import time
|
11 |
-
|
|
|
12 |
import clip # from OpenAI CLIP repo
|
|
|
|
|
|
|
|
|
13 |
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
14 |
|
15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
16 |
|
17 |
-
# Load MBart
|
18 |
translator_model = MBartForConditionalGeneration.from_pretrained(
|
19 |
"facebook/mbart-large-50-many-to-many-mmt"
|
20 |
).to(device)
|
@@ -23,33 +24,15 @@ translator_tokenizer = MBart50TokenizerFast.from_pretrained(
|
|
23 |
)
|
24 |
translator_tokenizer.src_lang = "ta_IN"
|
25 |
|
26 |
-
#
|
27 |
gen_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
|
28 |
gen_model.eval()
|
29 |
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
30 |
|
31 |
-
#
|
32 |
-
try:
|
33 |
-
pipe = StableDiffusionPipeline.from_pretrained(
|
34 |
-
"stabilityai/stable-diffusion-2-1",
|
35 |
-
torch_dtype=torch.float32,
|
36 |
-
use_auth_token=os.getenv("HF_TOKEN")
|
37 |
-
).to(device)
|
38 |
-
pipe.safety_checker = None
|
39 |
-
model_loaded = "stabilityai/stable-diffusion-2-1"
|
40 |
-
except Exception as e:
|
41 |
-
st.warning("โ ๏ธ SD-2.1 failed. Using lightweight fallback model.")
|
42 |
-
pipe = StableDiffusionPipeline.from_pretrained(
|
43 |
-
"OFA-Sys/small-stable-diffusion-v0",
|
44 |
-
torch_dtype=torch.float32
|
45 |
-
).to(device)
|
46 |
-
pipe.safety_checker = None
|
47 |
-
model_loaded = "OFA-Sys/small-stable-diffusion-v0"
|
48 |
-
|
49 |
-
# Load CLIP for image-text similarity
|
50 |
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
|
51 |
|
52 |
-
# Translation
|
53 |
def translate_tamil_to_english(text, reference=None):
|
54 |
start = time.time()
|
55 |
inputs = translator_tokenizer(text, return_tensors="pt").to(device)
|
@@ -68,7 +51,7 @@ def translate_tamil_to_english(text, reference=None):
|
|
68 |
|
69 |
return translated, duration, rouge_l
|
70 |
|
71 |
-
# Creative
|
72 |
def generate_creative_text(prompt, max_length=100):
|
73 |
start = time.time()
|
74 |
input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
|
@@ -85,7 +68,6 @@ def generate_creative_text(prompt, max_length=100):
|
|
85 |
tokens = text.split()
|
86 |
rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
|
87 |
|
88 |
-
# Calculate perplexity
|
89 |
with torch.no_grad():
|
90 |
input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
|
91 |
outputs = gen_model(input_ids, labels=input_ids)
|
@@ -94,18 +76,28 @@ def generate_creative_text(prompt, max_length=100):
|
|
94 |
|
95 |
return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)
|
96 |
|
97 |
-
#
|
98 |
def generate_image(prompt):
|
99 |
try:
|
100 |
start = time.time()
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
|
104 |
-
|
105 |
duration = round(time.time() - start, 2)
|
106 |
|
107 |
-
#
|
108 |
-
image_input = clip_preprocess(
|
109 |
text_input = clip.tokenize([prompt]).to(device)
|
110 |
with torch.no_grad():
|
111 |
image_features = clip_model.encode_image(image_input)
|
@@ -116,7 +108,7 @@ def generate_image(prompt):
|
|
116 |
except Exception as e:
|
117 |
return None, None, f"Image generation failed: {str(e)}"
|
118 |
|
119 |
-
#
|
120 |
st.set_page_config(page_title="Tamil โ English + AI Art", layout="centered")
|
121 |
st.title("๐ง Tamil โ English + ๐จ Creative Text + ๐ผ๏ธ AI Image")
|
122 |
|
@@ -139,7 +131,7 @@ if st.button("๐ Generate Output"):
|
|
139 |
image_path, img_time, clip_score = generate_image(english_text)
|
140 |
|
141 |
if image_path:
|
142 |
-
st.success(f"๐ผ๏ธ Image generated in {img_time}s using
|
143 |
st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
|
144 |
st.markdown(f"๐ **CLIP Text-Image Similarity:** `{clip_score}`")
|
145 |
else:
|
@@ -153,4 +145,4 @@ if st.button("๐ Generate Output"):
|
|
153 |
st.markdown(f"๐ Tokens: `{tokens}`, ๐ Repetition Rate: `{rep_rate}`, ๐ Perplexity: `{ppl}`")
|
154 |
|
155 |
st.markdown("---")
|
156 |
-
st.caption("Built by Sureshkumar R | MBart + GPT-2 +
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
+
import openai
|
|
|
|
|
|
|
|
|
|
|
4 |
import os
|
5 |
import time
|
6 |
+
from PIL import Image
|
7 |
+
import tempfile
|
8 |
import clip # from OpenAI CLIP repo
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
11 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
|
12 |
+
from rouge_score import rouge_scorer
|
13 |
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
14 |
|
15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
+
openai.api_key = os.getenv("OPENAI_API_KEY") # Set this from env
|
17 |
|
18 |
+
# Load MBart
|
19 |
translator_model = MBartForConditionalGeneration.from_pretrained(
|
20 |
"facebook/mbart-large-50-many-to-many-mmt"
|
21 |
).to(device)
|
|
|
24 |
)
|
25 |
translator_tokenizer.src_lang = "ta_IN"
|
26 |
|
27 |
+
# GPT-2
|
28 |
gen_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
|
29 |
gen_model.eval()
|
30 |
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
31 |
|
32 |
+
# CLIP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
|
34 |
|
35 |
+
# ---- Translation ----
|
36 |
def translate_tamil_to_english(text, reference=None):
|
37 |
start = time.time()
|
38 |
inputs = translator_tokenizer(text, return_tensors="pt").to(device)
|
|
|
51 |
|
52 |
return translated, duration, rouge_l
|
53 |
|
54 |
+
# ---- Creative Text ----
|
55 |
def generate_creative_text(prompt, max_length=100):
|
56 |
start = time.time()
|
57 |
input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
|
|
|
68 |
tokens = text.split()
|
69 |
rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
|
70 |
|
|
|
71 |
with torch.no_grad():
|
72 |
input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
|
73 |
outputs = gen_model(input_ids, labels=input_ids)
|
|
|
76 |
|
77 |
return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)
|
78 |
|
79 |
+
# ---- Image Generation using DALLยทE 3 ----
|
80 |
def generate_image(prompt):
|
81 |
try:
|
82 |
start = time.time()
|
83 |
+
response = openai.images.generate(
|
84 |
+
model="dall-e-3",
|
85 |
+
prompt=prompt,
|
86 |
+
size="512x512",
|
87 |
+
quality="standard",
|
88 |
+
n=1
|
89 |
+
)
|
90 |
+
image_url = response.data[0].url
|
91 |
+
image_data = Image.open(tempfile.NamedTemporaryFile(delete=False, suffix=".png"))
|
92 |
+
image_data = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))
|
93 |
+
|
94 |
+
# Save locally
|
95 |
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
|
96 |
+
image_data.save(tmp_file.name)
|
97 |
duration = round(time.time() - start, 2)
|
98 |
|
99 |
+
# CLIP similarity
|
100 |
+
image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
|
101 |
text_input = clip.tokenize([prompt]).to(device)
|
102 |
with torch.no_grad():
|
103 |
image_features = clip_model.encode_image(image_input)
|
|
|
108 |
except Exception as e:
|
109 |
return None, None, f"Image generation failed: {str(e)}"
|
110 |
|
111 |
+
# ---- UI ----
|
112 |
st.set_page_config(page_title="Tamil โ English + AI Art", layout="centered")
|
113 |
st.title("๐ง Tamil โ English + ๐จ Creative Text + ๐ผ๏ธ AI Image")
|
114 |
|
|
|
131 |
image_path, img_time, clip_score = generate_image(english_text)
|
132 |
|
133 |
if image_path:
|
134 |
+
st.success(f"๐ผ๏ธ Image generated in {img_time}s using OpenAI DALLยทE 3")
|
135 |
st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
|
136 |
st.markdown(f"๐ **CLIP Text-Image Similarity:** `{clip_score}`")
|
137 |
else:
|
|
|
145 |
st.markdown(f"๐ Tokens: `{tokens}`, ๐ Repetition Rate: `{rep_rate}`, ๐ Perplexity: `{ppl}`")
|
146 |
|
147 |
st.markdown("---")
|
148 |
+
st.caption("Built by Sureshkumar R | MBart + GPT-2 + OpenAI DALLยทE 3")
|