Spaces:
Build error
Build error
adymaharana
commited on
Commit
·
1cac669
1
Parent(s):
77e955b
restart
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import os, torch
|
| 2 |
import gradio as gr
|
| 3 |
import torchvision.utils as vutils
|
| 4 |
import torchvision.transforms as transforms
|
|
@@ -68,6 +68,7 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
|
|
| 68 |
def main(args):
|
| 69 |
#device = 'cuda:0'
|
| 70 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
| 71 |
|
| 72 |
model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
|
| 73 |
|
|
@@ -77,7 +78,7 @@ def main(args):
|
|
| 77 |
#if not os.path.exists("./ckpt/25.pth"):
|
| 78 |
# gdown.download(model_url, quiet=False, use_cookies=False, output="./ckpt/25.pth")
|
| 79 |
# print("Downloaded checkpoint")
|
| 80 |
-
assert os.path.exists("./ckpt/25.pth")
|
| 81 |
gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
|
| 82 |
|
| 83 |
if args.debug:
|
|
@@ -102,6 +103,9 @@ def main(args):
|
|
| 102 |
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
|
| 103 |
)
|
| 104 |
|
|
|
|
|
|
|
|
|
|
| 105 |
def predict(caption_1, caption_2, caption_3, caption_4, source='Pororo', top_k=32, top_p=0.2, n_candidates=4,
|
| 106 |
supercondition=False):
|
| 107 |
|
|
|
|
| 1 |
+
import os, sys, torch
|
| 2 |
import gradio as gr
|
| 3 |
import torchvision.utils as vutils
|
| 4 |
import torchvision.transforms as transforms
|
|
|
|
| 68 |
def main(args):
|
| 69 |
#device = 'cuda:0'
|
| 70 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 71 |
+
#device = torch.device('cpu')
|
| 72 |
|
| 73 |
model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
|
| 74 |
|
|
|
|
| 78 |
#if not os.path.exists("./ckpt/25.pth"):
|
| 79 |
# gdown.download(model_url, quiet=False, use_cookies=False, output="./ckpt/25.pth")
|
| 80 |
# print("Downloaded checkpoint")
|
| 81 |
+
#assert os.path.exists("./ckpt/25.pth")
|
| 82 |
gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
|
| 83 |
|
| 84 |
if args.debug:
|
|
|
|
| 103 |
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
|
| 104 |
)
|
| 105 |
|
| 106 |
+
#torch.save(model, './ckpt/checkpoint.pt')
|
| 107 |
+
#sys.exit()
|
| 108 |
+
|
| 109 |
def predict(caption_1, caption_2, caption_3, caption_4, source='Pororo', top_k=32, top_p=0.2, n_candidates=4,
|
| 110 |
supercondition=False):
|
| 111 |
|
dalle/__pycache__/__init__.cpython-38.pyc
CHANGED
|
Binary files a/dalle/__pycache__/__init__.cpython-38.pyc and b/dalle/__pycache__/__init__.cpython-38.pyc differ
|
|
|
dalle/models/__init__.py
CHANGED
|
@@ -23,6 +23,7 @@ from ..utils.utils import save_image
|
|
| 23 |
from .tokenizer import build_tokenizer
|
| 24 |
import numpy as np
|
| 25 |
from .stage2.layers import CrossAttentionLayer
|
|
|
|
| 26 |
|
| 27 |
_MODELS = {
|
| 28 |
'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
|
|
@@ -1191,7 +1192,9 @@ class StoryDalle(Dalle):
|
|
| 1191 |
print("Loaded tokenizer from finetuned checkpoint")
|
| 1192 |
print(model.cross_attention_idxs)
|
| 1193 |
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
|
|
|
| 1194 |
# model.from_ckpt(args.model_name_or_path)
|
|
|
|
| 1195 |
try:
|
| 1196 |
model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['state_dict'])
|
| 1197 |
except KeyError:
|
|
@@ -1248,9 +1251,9 @@ class StoryDalle(Dalle):
|
|
| 1248 |
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
|
| 1249 |
|
| 1250 |
with torch.no_grad():
|
| 1251 |
-
with autocast(enabled=False):
|
| 1252 |
-
|
| 1253 |
-
|
| 1254 |
|
| 1255 |
B, C, H, W = images.shape
|
| 1256 |
|
|
@@ -1310,8 +1313,8 @@ class StoryDalle(Dalle):
|
|
| 1310 |
# Check if the encoding works as intended
|
| 1311 |
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
| 1312 |
|
| 1313 |
-
tokens = tokens.to(device)
|
| 1314 |
-
source = source.to(device)
|
| 1315 |
|
| 1316 |
# print(tokens.shape, sent_embeds.shape, prompt.shape)
|
| 1317 |
B, L, _ = sent_embeds.shape
|
|
@@ -1322,8 +1325,8 @@ class StoryDalle(Dalle):
|
|
| 1322 |
prompt = sent_embeds
|
| 1323 |
pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B*L, -1).to(self.device), mode='1d')
|
| 1324 |
|
| 1325 |
-
with autocast(enabled=False):
|
| 1326 |
-
|
| 1327 |
src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
|
| 1328 |
print(tokens.shape, src_codes.shape, prompt.shape)
|
| 1329 |
if self.config.story.condition:
|
|
@@ -1378,8 +1381,8 @@ class StoryDalle(Dalle):
|
|
| 1378 |
# Check if the encoding works as intended
|
| 1379 |
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
| 1380 |
|
| 1381 |
-
tokens = tokens.to(device)
|
| 1382 |
-
source = source.to(device)
|
| 1383 |
|
| 1384 |
# print(tokens.shape, sent_embeds.shape, prompt.shape)
|
| 1385 |
B, L, _ = sent_embeds.shape
|
|
@@ -1389,10 +1392,10 @@ class StoryDalle(Dalle):
|
|
| 1389 |
else:
|
| 1390 |
prompt = sent_embeds
|
| 1391 |
pos_enc_prompt = get_positional_encoding(
|
| 1392 |
-
torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(
|
| 1393 |
|
| 1394 |
-
with autocast(enabled=False):
|
| 1395 |
-
|
| 1396 |
|
| 1397 |
# repeat inputs to adjust to n_candidates and story length
|
| 1398 |
src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
|
|
|
|
| 23 |
from .tokenizer import build_tokenizer
|
| 24 |
import numpy as np
|
| 25 |
from .stage2.layers import CrossAttentionLayer
|
| 26 |
+
from huggingface_hub import hf_hub_download
|
| 27 |
|
| 28 |
_MODELS = {
|
| 29 |
'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
|
|
|
|
| 1192 |
print("Loaded tokenizer from finetuned checkpoint")
|
| 1193 |
print(model.cross_attention_idxs)
|
| 1194 |
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
| 1195 |
+
|
| 1196 |
# model.from_ckpt(args.model_name_or_path)
|
| 1197 |
+
|
| 1198 |
try:
|
| 1199 |
model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['state_dict'])
|
| 1200 |
except KeyError:
|
|
|
|
| 1251 |
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
|
| 1252 |
|
| 1253 |
with torch.no_grad():
|
| 1254 |
+
#with autocast(enabled=False):
|
| 1255 |
+
codes = self.stage1.get_codes(images).detach()
|
| 1256 |
+
src_codes = self.stage1.get_codes(src_images).detach()
|
| 1257 |
|
| 1258 |
B, C, H, W = images.shape
|
| 1259 |
|
|
|
|
| 1313 |
# Check if the encoding works as intended
|
| 1314 |
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
| 1315 |
|
| 1316 |
+
#tokens = tokens.to(device)
|
| 1317 |
+
#source = source.to(device)
|
| 1318 |
|
| 1319 |
# print(tokens.shape, sent_embeds.shape, prompt.shape)
|
| 1320 |
B, L, _ = sent_embeds.shape
|
|
|
|
| 1325 |
prompt = sent_embeds
|
| 1326 |
pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B*L, -1).to(self.device), mode='1d')
|
| 1327 |
|
| 1328 |
+
#with autocast(enabled=False):
|
| 1329 |
+
src_codes = self.stage1.get_codes(source).detach()
|
| 1330 |
src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
|
| 1331 |
print(tokens.shape, src_codes.shape, prompt.shape)
|
| 1332 |
if self.config.story.condition:
|
|
|
|
| 1381 |
# Check if the encoding works as intended
|
| 1382 |
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
| 1383 |
|
| 1384 |
+
#tokens = tokens.to(device)
|
| 1385 |
+
#source = source.to(device)
|
| 1386 |
|
| 1387 |
# print(tokens.shape, sent_embeds.shape, prompt.shape)
|
| 1388 |
B, L, _ = sent_embeds.shape
|
|
|
|
| 1392 |
else:
|
| 1393 |
prompt = sent_embeds
|
| 1394 |
pos_enc_prompt = get_positional_encoding(
|
| 1395 |
+
torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(tokens.device), mode='1d')
|
| 1396 |
|
| 1397 |
+
#with autocast(enabled=False):
|
| 1398 |
+
src_codes = self.stage1.get_codes(source).detach()
|
| 1399 |
|
| 1400 |
# repeat inputs to adjust to n_candidates and story length
|
| 1401 |
src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
|
dalle/models/__pycache__/__init__.cpython-38.pyc
CHANGED
|
Binary files a/dalle/models/__pycache__/__init__.cpython-38.pyc and b/dalle/models/__pycache__/__init__.cpython-38.pyc differ
|
|
|