Spaces: MoMA
Kunpeng Song committed · Commit 73c6f92
1 Parent(s): 7d79282
zero

Files changed:
- .DS_Store +0 -0
- README.md +8 -6
- app.py +51 -0
- backupApp_version1.py +52 -0
- checkpoints/.DS_Store +0 -0
- checkpoints/attn_adapters_projectors.th +1 -0
- checkpoints/ckpt_saving_path.txt +0 -0
- dataset_lib/__pycache__/dataset_eval_MoMA.cpython-310.pyc +0 -0
- dataset_lib/dataset_eval_MoMA.py +44 -0
- example_images/newImages/.DS_Store +0 -0
- example_images/newImages/02.jpg +0 -0
- example_images/newImages/03.jpg +0 -0
- example_images/newImages/1.jpeg +0 -0
- example_images/newImages/17.jpg +0 -0
- example_images/newImages/2.jpg +0 -0
- example_images/newImages/3.jpg +0 -0
- flagged/log.csv +2 -0
- model_lib/__init__.py +0 -0
- model_lib/__pycache__/__init__.cpython-310.pyc +0 -0
- model_lib/__pycache__/__init__.cpython-39.pyc +0 -0
- model_lib/__pycache__/attention_processor.cpython-310.pyc +0 -0
- model_lib/__pycache__/moMA_generator.cpython-310.pyc +0 -0
- model_lib/__pycache__/moMA_generator.cpython-39.pyc +0 -0
- model_lib/__pycache__/modules.cpython-310.pyc +0 -0
- model_lib/__pycache__/modules.cpython-39.pyc +0 -0
- model_lib/__pycache__/utils.cpython-310.pyc +0 -0
- model_lib/attention_processor.py +245 -0
- model_lib/moMA_generator.py +285 -0
- model_lib/modules.py +151 -0
- model_lib/utils.py +27 -0
- output/car_A car in autumn with falling leaves..jpg +0 -0
- output/car_A wooden sculpture of a car on the table..jpg +0 -0
- requirements.txt +32 -0

.DS_Store
ADDED
Binary file (8.2 kB)

README.md
CHANGED
@@ -1,12 +1,14 @@
 ---
-title: MoMA
-emoji:
-colorFrom:
-colorTo:
+title: MoMA
+emoji: 🌍
+colorFrom: yellow
+colorTo: green
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.31.4
 app_file: app.py
 pinned: false
+license: apache-2.0
+short_description: Multi-modal LLM for image personalization
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app.py
ADDED
@@ -0,0 +1,51 @@
+import gradio as gr
+import cv2
+import torch
+import numpy as np
+from torchvision import transforms
+import torch
+from pytorch_lightning import seed_everything
+from torchvision.utils import save_image
+from model_lib.modules import MoMA_main_modal
+from model_lib.utils import parse_args
+import os
+os.environ["CUDA_VISIBLE_DEVICES"]="0"
+
+title = "MoMA"
+description = "This model has to run on GPU. By default, we load the model with 4-bit quantization to make it fit in smaller hardwares."
+
+def MoMA_demo(rgb, subject, prompt, strength, seed):
+    seed = int(seed) if seed else 0
+    try:
+        seed = int(seed)
+    except ValueError:
+        seed = 0
+    seed = seed if not seed == 0 else np.random.randint(0,1000)
+    print(f"Seed: {seed}")
+
+    with torch.no_grad():
+        generated_image = model.generate_images(rgb, subject, prompt, strength=strength, seed=seed)
+    return generated_image
+
+def inference(rgb, subject, prompt, strength, seed):
+    result = MoMA_demo(rgb, subject, prompt, strength, seed)
+    return result
+
+seed_everything(0)
+args = parse_args()
+#load MoMA from HuggingFace. Auto download
+model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)
+
+gr.Interface(
+    inference,
+    [gr.Image(type="pil", label="Input RGB"),
+    gr.Textbox(lines=1, label="subject"),
+    gr.Textbox(lines=1, label="Prompt"),
+    gr.Slider(minimum=0.2, maximum=1.2, step=0.1,label="Strength. Recommend: 1.0 for context editing; 0.4 for texture editing",value=1.0),
+    gr.Textbox(lines=1, label="Seed. Use 0 for a random seed")],
+    gr.Image(type="pil", label="Output"),
+    title=title,
+    description=description,
+    examples=[["example_images/newImages/3.jpg",'car','A car in autumn with falling leaves.',1.0,"6"],["example_images/newImages/3.jpg",'car','A wooden sculpture of a car on a table.',0.4,"4"],["example_images/newImages/2.jpg",'car','A car on a city road with green trees and buildings.',1.0,"4"],["example_images/newImages/03.jpg",'cat','A cat at the Grand Canyon.',1.0,"2"],["example_images/newImages/02.jpg",'dog','A dog in a spring garden with flowers.',1.0,"6"],["example_images/newImages/1.jpeg",'bird','A bird in spring with flowers.',1.0,"1"],["example_images/newImages/17.jpg",'robot','A robot in autumn mountain and lake.',1,"5"]],
+    allow_flagging='never'
+    ).launch(debug=False)

backupApp_version1.py
ADDED
@@ -0,0 +1,52 @@
+import gradio as gr
+import cv2
+import torch
+import numpy as np
+from torchvision import transforms
+import torch
+from pytorch_lightning import seed_everything
+from torchvision.utils import save_image
+from model_lib.modules import MoMA_main_modal
+from model_lib.utils import parse_args
+import os
+os.environ["CUDA_VISIBLE_DEVICES"]="0"
+
+title = "MoMA"
+description = "This model has to run on GPU"
+article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"
+
+def MoMA_demo(rgb, mask, subject, prompt):
+    # move the input and model to GPU for speed if available
+    with torch.no_grad():
+        generated_image = model.generate_images(rgb, mask, subject, prompt, strength=1.0, seed=2)
+    return generated_image
+
+def inference(rgb, mask, subject, prompt):
+    result = MoMA_demo(rgb, mask, subject, prompt)
+    return result
+
+seed_everything(0)
+args = parse_args()
+#load MoMA from HuggingFace. Auto download
+model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)
+
+
+################ change texture ##################
+# prompt = "A wooden sculpture of a car on the table."
+# generated_image = model.generate_images(rgb_path, mask_path, subject, prompt, strength=0.4, seed=4, return_mask=True) # set strength to 0.4 for better prompt fidelity
+# save_image(generated_image,f"{args.output_path}/{subject}_{prompt}.jpg")
+
+
+gr.Interface(
+    inference,
+    [gr.Image(type="pil", label="Input RGB"),
+    gr.Image(type="pil", label="Input Mask"),
+    gr.Textbox(lines=1, label="subject"),
+    gr.Textbox(lines=5, label="Prompt")],
+    gr.Image(type="pil", label="Output"),
+    title=title,
+    description=description,
+    article=article,
+    examples=[["example_images/newImages/3.jpg",'example_images/newImages/3_mask.jpg','car','A car in autumn with falling leaves.']],
+    # enable_queue=True
+    ).launch(debug=False)

checkpoints/.DS_Store
ADDED
Binary file (6.15 kB)

checkpoints/attn_adapters_projectors.th
ADDED
@@ -0,0 +1 @@
+../../../../../../../home/ks1418/.cache/huggingface/hub/models--KunpengSong--MoMA_llava_7b/blobs/0b432a39e46f01cd9cdb4794b8ef13b9bb0aff2ad6da6800d67fd2ca4af21fa6

checkpoints/ckpt_saving_path.txt
ADDED
File without changes

dataset_lib/__pycache__/dataset_eval_MoMA.cpython-310.pyc
ADDED
Binary file (1.76 kB)

dataset_lib/dataset_eval_MoMA.py
ADDED
@@ -0,0 +1,44 @@
+from PIL import Image
+import numpy as np
+import torch
+from torchvision import transforms
+from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+from rembg import remove
+
+def create_binary_mask(image):
+    grayscale = image.convert("L")
+    mask = grayscale.point(lambda x: 255 if x > 1 else 0, '1')
+    return mask
+
+def Dataset_evaluate_MoMA(image_pil, prompt,subject, moMA_main_modal):
+
+    LLaVa_processor = moMA_main_modal.image_processor_llava
+    llava_config = moMA_main_modal.model_llava.config
+
+    transform = transforms.Compose([
+        transforms.Resize((512, 512)),
+    ])
+
+    mask_pil = create_binary_mask(remove(image_pil)) # Image.open(mask_path)
+    blip2_opt = prompt
+
+    if transform is not None:
+        image_pil = transform(image_pil)
+        mask_pil = transform(mask_pil)
+
+    mask_pil = np.array(mask_pil)
+    mask_pil = mask_pil[:,:,0] if len(mask_pil.shape)==3 else mask_pil
+    image = torch.from_numpy(np.array(image_pil)).permute(2,0,1)
+    mask = (torch.clamp((torch.from_numpy(mask_pil).unsqueeze(0)).float(),min=0.0,max=1.0)>0).float()
+
+    res = {'image': (image/127.5-1).unsqueeze(0),\
+        'mask': mask.unsqueeze(0), \
+        'text': [blip2_opt]}
+
+    image_wb = image * mask + torch.ones_like(image)* (1-mask)*255
+    image_pil = Image.fromarray(image_wb.permute(1,2,0).numpy().astype(np.uint8))
+
+    res['llava_processed'] = process_images([image_pil], LLaVa_processor, llava_config)
+    res['label'] = [subject]
+    return res
+

example_images/newImages/.DS_Store
ADDED
Binary file (6.15 kB)

example_images/newImages/02.jpg
ADDED

example_images/newImages/03.jpg
ADDED

example_images/newImages/1.jpeg
ADDED

example_images/newImages/17.jpg
ADDED

example_images/newImages/2.jpg
ADDED

example_images/newImages/3.jpg
ADDED

flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
+Input RGB,subject,Prompt,Strength. Recommend: 1.0 for context editing; 0.4 for texture editing,Seed. Use 0 for a random seed,Output,flag,username,timestamp
+,,,1,,,,,2024-05-21 19:36:27.802622

model_lib/__init__.py
ADDED
File without changes

model_lib/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (191 Bytes)

model_lib/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (189 Bytes)

model_lib/__pycache__/attention_processor.cpython-310.pyc
ADDED
Binary file (7.06 kB)

model_lib/__pycache__/moMA_generator.cpython-310.pyc
ADDED
Binary file (10 kB)

model_lib/__pycache__/moMA_generator.cpython-39.pyc
ADDED
Binary file (10 kB)

model_lib/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (6.96 kB)

model_lib/__pycache__/modules.cpython-39.pyc
ADDED
Binary file (6.94 kB)

model_lib/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.45 kB)

model_lib/attention_processor.py
ADDED
@@ -0,0 +1,245 @@
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import math
+from torchvision.utils import save_image
+import torchvision.transforms as T
+
+def get_mask_from_cross(attn_processors):
+    reference_masks = []
+    for attn_processor in attn_processors.values():
+        if isinstance(attn_processor, IPAttnProcessor):
+            reference_masks.append(attn_processor.mask_i)
+    mask = torch.cat(reference_masks,dim=1).mean(dim=1)
+    mask = (mask-mask.min())/(mask.max()-mask.min())
+    mask = (mask>0.2).to(torch.float32)*mask
+    mask = (mask-mask.min())/(mask.max()-mask.min())
+    return mask.unsqueeze(1)
+
+class IPAttnProcessor(nn.Module):
+    r"""
+    Attention processor for IP-Adapater.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+        self.store_attn = None
+        self.enabled = True
+        self.mode = 'inject'
+
+        self.subject_idxs = None
+        self.mask_i = None
+        self.mask_ig_prev = None
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # for ip-adapter
+        if self.enabled:
+            if self.mode == 'inject' or self.mode == 'masked_generation':
+                ip_key = self.to_k_ip(ip_hidden_states.to(torch.float16))
+                ip_value = self.to_v_ip(ip_hidden_states.to(torch.float16))
+                ip_key = attn.head_to_batch_dim(ip_key)
+                ip_value = attn.head_to_batch_dim(ip_value)
+                ip_attention_probs = attn.get_attention_scores(query, ip_key.to(torch.float32), None)
+                ip_hidden_states = torch.bmm(ip_attention_probs, ip_value.to(torch.float32))
+                ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+                if (self.mask_ig_prev is not None) and self.mode == 'masked_generation':
+                    mask_ig_prev = rearrange(F.interpolate(self.mask_ig_prev,size=int(math.sqrt(query.shape[1]))),"b c h w -> b (h w) c")
+                    if not mask_ig_prev.shape[0]==ip_hidden_states.shape[0]: mask_ig_prev = mask_ig_prev.repeat(2,1,1)
+                    ip_hidden_states = ip_hidden_states * mask_ig_prev
+                hidden_states = hidden_states + self.scale * ip_hidden_states
+            if self.mode == 'extract' or self.mode == 'masked_generation':
+                subject_idxs = self.subject_idxs*2 if not (hidden_states.shape[0] == len(self.subject_idxs)) else self.subject_idxs
+                assert (hidden_states.shape[0] == len(subject_idxs))
+                attentions = rearrange(attention_probs, '(b h) n d -> b h n d', h=8).mean(1)
+                attn_extracted = [attentions[i, :, subject_idxs[i]].sum(-1) for i in range(hidden_states.shape[0])]
+                attn_extracted = [(atn-atn.min())/(atn.max()-atn.min()) for atn in attn_extracted]
+                attn_extracted = torch.stack(attn_extracted, dim=0)
+                attn_extracted = rearrange(attn_extracted, 'b (h w) -> b h w', h=int(math.sqrt(attention_probs.shape[1])))
+                attn_extracted = torch.clamp(F.interpolate(attn_extracted.unsqueeze(1),size=512),min=0,max=1)
+                self.mask_i = attn_extracted
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states
+
+### added for self attention
+class IPAttnProcessor_Self(nn.Module):
+    r"""
+    Attention processor for IP-Adapater. (But for self attention)
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(hidden_size, hidden_size, bias=False)
+
+        self.scale_learnable = torch.nn.Parameter(torch.zeros(1),requires_grad=True)
+
+        self.enabled = True
+        self.mode = 'extract'
+
+        self.store_ks, self.store_vs = [], []
+        self.mask_id, self.mask_ig = None, None
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key_0 = attn.to_k(encoder_hidden_states)
+        value_0 = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key_0)
+        value = attn.head_to_batch_dim(value_0)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.enabled:
+            if self.mode == 'extract':
+                ks, vs = attn.head_to_batch_dim(self.to_k_ip(key_0)), attn.head_to_batch_dim(self.to_v_ip(value_0))
+                self.store_ks, self.store_vs = self.store_ks+[ks], self.store_vs+[vs]
+                self.store_ks, self.store_vs = torch.cat(self.store_ks,dim=0), torch.cat(self.store_vs,dim=0)
+
+            if self.mode == 'masked_generation':
+                if not self.store_ks.shape[0]==query.shape[0]: self.store_ks,self.store_vs = self.store_ks.repeat(2,1,1), self.store_vs.repeat(2,1,1)
+                mask_id = self.mask_id.clone()
+                mask_id.masked_fill_(self.mask_id==False, -torch.finfo(mask_id.dtype).max)
+                mask_id = rearrange(F.interpolate(mask_id,size=int(math.sqrt(query.shape[1]))),"b c h w -> b c (h w)").repeat(1,query.shape[1],1)
+                mask_id = mask_id.repeat(8,1,1) # 8 is head dim
+                if not mask_id.shape[0]==int(query.shape[0]): mask_id = mask_id.repeat(2,1,1)
+                attention_probs_ref = attn.get_attention_scores(query, self.store_ks, mask_id.to(query.dtype))
+                hidden_states_ref = torch.bmm(attention_probs_ref, self.store_vs)
+                hidden_states_ref = attn.batch_to_head_dim(hidden_states_ref)
+                scale = self.scale.repeat(int(batch_size/self.scale.shape[0])).unsqueeze(-1).unsqueeze(-1) if type(self.scale)==torch.Tensor else self.scale
+                if self.mask_ig == None:
+                    hidden_states = hidden_states + scale * hidden_states_ref * self.scale_learnable
+                else:
+                    mask_ig = rearrange(F.interpolate(self.mask_ig,size=int(math.sqrt(query.shape[1]))),"b c h w -> b (h w) c")
+                    if not mask_ig.shape[0]==hidden_states_ref.shape[0]: mask_ig = mask_ig.repeat(2,1,1)
+                    hidden_states = hidden_states + scale * hidden_states_ref * mask_ig * self.scale_learnable
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states

model_lib/moMA_generator.py
ADDED
@@ -0,0 +1,285 @@
+from typing import List
+import torch
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+from PIL import Image
+from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self, get_mask_from_cross
+from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
+import tqdm
+
+
+def get_subject_idx(model,prompt,src_subject,device):
+    tokenized_prompt = model.tokenizer(prompt,padding="max_length",max_length=model.tokenizer.model_max_length,truncation=True,return_tensors="pt",).to(device)
+    input_ids = tokenized_prompt['input_ids']
+    src_subject_idxs = []
+    for subject,input_id in zip(src_subject,input_ids):
+        src_subject_token_id = [model.tokenizer.encode(i, add_special_tokens=False)[0] for i in subject.split(' ')]
+        src_subject_idxs = [i for i, x in enumerate(input_id.tolist()) if x in src_subject_token_id]
+    return [src_subject_idxs]
+
+
+def add_function(model):
+    @torch.no_grad()
+    def generate_with_adapters(
+        model,
+        prompt_embeds,
+        num_inference_steps,
+        generator,
+        t_range=list(range(0,950)),
+    ):
+
+        latents = model.prepare_latents(prompt_embeds.shape[0]//2,4,512,512,prompt_embeds.dtype,prompt_embeds.device,generator)
+
+        model.scheduler.set_timesteps(num_inference_steps)
+
+        iterator = tqdm.tqdm(model.scheduler.timesteps)
+        mask_ig_prev = None
+        for i, t in enumerate(iterator):
+            if not t in t_range:
+                model.moMA_generator.toggle_enable_flag('cross')
+            else:
+                model.moMA_generator.toggle_enable_flag('all')
+
+            latent_model_input = torch.cat([latents] * 2)
+            noise_pred = model.unet(
+                latent_model_input,
+                t,
+                encoder_hidden_states=prompt_embeds,
+                return_dict=False,
+            )[0]
+
+            # perform guidance
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
+
+            latents = model.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+            mask_ig_prev = (get_mask_from_cross(model.unet.attn_processors))[latents.shape[0]:]
+
+            model.moMA_generator.set_self_mask('self','ig',mask_ig_prev)
+            model.moMA_generator.set_self_mask('cross',mask=mask_ig_prev.clone().detach())
+
+        image = model.vae.decode(latents / model.vae.config.scaling_factor, return_dict=False)[0]
+        return image ,mask_ig_prev.repeat(1,3,1,1) if (not mask_ig_prev==None) else None
+    model.generate_with_adapters = generate_with_adapters
+
+
+class ImageProjModel(torch.nn.Module):
+    """Projection Model"""
+    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds):
+        embeds = image_embeds
+        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+
+
+class MoMA_generator:
+    def __init__(self, device,args):
+        self.args = args
+        self.device = device
+
+        noise_scheduler = DDIMScheduler(num_train_timesteps=1000,beta_start=0.00085,beta_end=0.012,beta_schedule="scaled_linear",clip_sample=False,set_alpha_to_one=False,steps_offset=1,)
+
+        print('Loading VAE: stabilityai--sd-vae-ft-mse...')
+        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
+
+        print('Loading StableDiffusion: Realistic_Vision...')
+        self.pipe = StableDiffusionPipeline.from_pretrained(
+            "SG161222/Realistic_Vision_V4.0_noVAE",
+            torch_dtype=torch.float16,
+            scheduler=noise_scheduler,
+            vae=vae,
+            feature_extractor=None,
+            safety_checker=None,
+        ).to(self.device)
+
+        self.unet = self.pipe.unet
+        add_function(self.pipe)
+        self.pipe.moMA_generator = self
+
+        self.set_ip_adapter()
+        self.image_proj_model = self.init_proj()
+
+    def init_proj(self):
+        image_proj_model = ImageProjModel(
+            cross_attention_dim=768,
+            clip_embeddings_dim=1024,
+            clip_extra_context_tokens=4,
+        ).to(self.device, dtype=torch.float16)
+        return image_proj_model
+
+    def set_ip_adapter(self):
+        unet = self.unet
+        attn_procs = {}
+        for name in unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = unet.config.block_out_channels[block_id]
+            if cross_attention_dim is None:
+                attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
+            else:
+                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
+        unet.set_attn_processor(attn_procs)
+
+    @torch.inference_mode()
+    def get_image_embeds_CFG(self, llava_emb):
+        clip_image_embeds = llava_emb
+        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
+        return image_prompt_embeds, uncond_image_prompt_embeds
+
+    def get_image_crossAttn_feature(
+        self,
+        llava_emb,
+        num_samples=1,
+    ):
+        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_CFG(llava_emb)
+        bs_embed, seq_len, _ = image_prompt_embeds.shape
+        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+        return image_prompt_embeds, uncond_image_prompt_embeds
+
+    # feature are from self-attention layers of Unet: feed reference image to Unet with t=0
+    def get_image_selfAttn_feature(
+        self,
+        pil_image,
+        prompt,
+    ):
+        self.toggle_enable_flag('self')
+        self.toggle_extract_inject_flag('self', 'extract')
+        tokenized_prompt = self.pipe.tokenizer(prompt,padding="max_length",truncation=True,return_tensors="pt",).to(self.device)
+        text_embeddings = self.pipe.text_encoder(input_ids=tokenized_prompt.input_ids)[0]
+
+        ref_image = pil_image
+        ref_image.to(self.device)
+
+        with torch.no_grad(): latents = self.pipe.vae.encode(ref_image).latent_dist.sample()
+        latents = latents * self.pipe.vae.config.scaling_factor
+
+        noise = torch.randn_like(latents)
+        timesteps = torch.tensor([0],device=latents.device).long() # fixed to 0
+        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timesteps)
+
+        _ = self.unet(noisy_latents,timestep=timesteps,encoder_hidden_states=text_embeddings)["sample"]
+        # features are stored in attn_processors
+
+        return None
+
+    @torch.no_grad()
+    def generate_with_MoMA(
+        self,
+        batch,
+        llava_emb=None,
+        seed=None,
+        device='cuda',
+    ):
+        self.reset_all()
+        img_ig,mask_id,subject,prompt = batch['image'].half().to(device),batch['mask'].half().to(device),batch['label'][0],batch['text'][0]
+
+        prompt = [f"photo of a {subject}. "+ prompt]
+        subject_idx = get_subject_idx(self.pipe,prompt,[subject],self.device)
+        negative_prompt = None
+
+        # get context-cross-attention feature (from MLLM decoder)
+        cond_llava_embeds, uncond_llava_embeds = self.get_image_crossAttn_feature(llava_emb,num_samples=1)
+        # get subject-cross-attention feature (from Unet)
+        self.get_image_selfAttn_feature(img_ig,subject) # features are stored in attn_processors
+
+        with torch.inference_mode():
+            prompt_embeds = self.pipe._encode_prompt(
+                prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
+            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
+            prompt_embeds = torch.cat([prompt_embeds_, cond_llava_embeds], dim=1)
+            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_llava_embeds], dim=1)
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+
+        self.set_self_mask('eraseAll')
+        self.toggle_enable_flag('all')
+        self.toggle_extract_inject_flag('all','masked_generation')
+        self.set_self_mask('self','id',mask_id)
+        self.set_cross_subject_idxs(subject_idx)
+
+        images, mask = self.pipe.generate_with_adapters(
+            self.pipe,
+            prompt_embeds,
+            50,
+            generator,
+        )
+        images = torch.clip((images+1)/2.0,min=0.0,max=1.0)
+
+        return images.cpu(), mask.cpu()
+
+    def set_selfAttn_strength(self, strength):
+        for attn_processor in self.unet.attn_processors.values():
+            if isinstance(attn_processor, IPAttnProcessor):
+                attn_processor.scale = 1.0
+            if isinstance(attn_processor, IPAttnProcessor_Self):
+                attn_processor.scale = strength
+
+    def set_cross_subject_idxs(self, subject_idxs):
+        for attn_processor in self.unet.attn_processors.values():
+            if isinstance(attn_processor, IPAttnProcessor):
+                attn_processor.subject_idxs = subject_idxs
+
+    def set_self_mask(self,mode,id_ig='', mask=None): #only have effect on self attn of the generation process
+        for attn_processor in self.unet.attn_processors.values():
+            if mode == 'eraseAll':
+                if isinstance(attn_processor, IPAttnProcessor_Self):
+                    attn_processor.mask_id,attn_processor.mask_ig = None,None
+                if isinstance(attn_processor, IPAttnProcessor):
+                    attn_processor.mask_i, attn_processor.mask_ig_prev = None, None
+            if mode == 'self':
+                if isinstance(attn_processor, IPAttnProcessor_Self):
+                    if id_ig == 'id':attn_processor.mask_id = mask
+                    if id_ig == 'ig':attn_processor.mask_ig = mask
+            if mode == 'cross':
+                if isinstance(attn_processor, IPAttnProcessor):
+                    attn_processor.mask_ig_prev = mask
+
+    def toggle_enable_flag(self, processor_enable_mode):
+        for attn_processor in self.unet.attn_processors.values():
+            if processor_enable_mode == 'cross':
+                if isinstance(attn_processor, IPAttnProcessor):attn_processor.enabled = True
+                if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.enabled = False
+            if processor_enable_mode == 'self':
+                if isinstance(attn_processor, IPAttnProcessor):attn_processor.enabled = False
+                if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.enabled = True
+            if processor_enable_mode == 'all':
+                attn_processor.enabled = True
+            if processor_enable_mode == 'none':
+                attn_processor.enabled = False
+
+    def toggle_extract_inject_flag(self, processor_name, mode): # mode: str, 'extract' or 'inject' or 'both'(cross only)
+        for attn_processor in self.unet.attn_processors.values():
+            if processor_name == 'cross':
+                if isinstance(attn_processor, IPAttnProcessor):attn_processor.mode = mode
+            if processor_name == 'self':
+                if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.mode = mode
+            if processor_name == 'all':
+                attn_processor.mode = mode
+
+    def reset_all(self,keep_self=False):
+        for attn_processor in self.unet.attn_processors.values():
+            if isinstance(attn_processor, IPAttnProcessor):
+                attn_processor.store_attn, attn_processor.subject_idxs, attn_processor.mask_i, attn_processor.mask_ig_prev, self.subject_idxs = None, None, None, None, None
+
+            if isinstance(attn_processor, IPAttnProcessor_Self):
+                attn_processor.mask_id, attn_processor.mask_ig = None, None
+                if not keep_self: attn_processor.store_ks, attn_processor.store_vs = [], []

model_lib/modules.py
ADDED
@@ -0,0 +1,151 @@
+import os
+from PIL import Image
+import torch
+import torch.nn as nn
+from typing import List, Optional
+import torch.utils.checkpoint
+from torchvision.transforms import ToPILImage
+from model_lib.moMA_generator import MoMA_generator
+from transformers.activations import ACT2FN
+from huggingface_hub import hf_hub_download
+
+from dataset_lib.dataset_eval_MoMA import Dataset_evaluate_MoMA
+
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
+from llava.constants import IMAGE_TOKEN_INDEX
+
+def add_function(model):
+    def my_llava_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        (_,position_ids,attention_mask,_,inputs_embeds,_) = self.prepare_inputs_labels_for_multimodal(input_ids,position_ids,attention_mask,None,None,images)
+
+        outputs = self.model(
+            input_ids=None,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=None,
+            inputs_embeds=inputs_embeds,
+            use_cache=True,
+            output_attentions=False,
+            output_hidden_states=False,
+            return_dict=True,
+        )
+        return outputs[0]
+
+    model.my_llava_forward = my_llava_forward
+
+
+class LlamaMLP_mapping(nn.Module):
+    def __init__(self, hidden_size,hidden_size_out):
+        super().__init__()
+        self.hidden_size, self.hidden_size_out = hidden_size,hidden_size_out
+        self.gate_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
+        self.down_proj = nn.Linear(self.hidden_size_out, self.hidden_size_out, bias=False)
+        self.act_fn = ACT2FN["silu"]
+        self.act_fn_output = ACT2FN["tanh"]
+        self.init_linear()
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+    def init_linear(self):
+        torch.nn.init.xavier_normal_(self.gate_proj.weight)
+        self.gate_proj.weight.data=self.gate_proj.weight.data/4.0
+        torch.nn.init.xavier_normal_(self.up_proj.weight)
+        self.up_proj.weight.data=self.up_proj.weight.data/4.0
+        torch.nn.init.xavier_normal_(self.down_proj.weight)
+        self.down_proj.weight.data=self.down_proj.weight.data/4.0
+
+class MoMA_main_modal(nn.Module):
+    def __init__(self,args):
+        super().__init__()
+        self.args = args
+        self.device = args.device
+
+        self.moMA_generator = MoMA_generator(self.device,args)
+        self.unet = self.moMA_generator.pipe.unet
+        self.vae = self.moMA_generator.pipe.vae
+
+        print('Loading MoMA: its Multi-modal LLM...')
+        model_name = get_model_name_from_path(args.model_path)
+        self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit, device=args.device)
+
+        add_function(self.model_llava)
+
+        self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.float16)
+        self.load_saved_components()
+        self.freeze_modules()
+
+    def load_saved_components(self):
+        if not os.path.exists(self.args.load_attn_adapters):
+            print('Loading Attentions and LLM mappings...')
+            hf_hub_download(repo_id=self.args.model_path, filename="attn_adapters_projectors.th",local_dir='/'.join(self.args.load_attn_adapters.split('/')[:-1]))
+
+        #load attention adapters and self cross attentions
+        state_dict = torch.load(self.args.load_attn_adapters, map_location="cpu")
+        self.moMA_generator.image_proj_model.load_state_dict(state_dict["projectors"])
+        attn_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
+        attn_layers.load_state_dict(state_dict["self_cross_attentions"],strict=False)
+
+        #load LLM projectors
+        self.load_state_dict(state_dict['llm_mapping'],strict=False)
+
+    def freeze_modules(self):
+        all_modules = [self.moMA_generator.pipe.vae,self.moMA_generator.pipe.text_encoder,self.unet,self.model_llava,self.mapping]
+        for module in all_modules:
+            module.train = False
+            module.requires_grad_(False)
+
+    def forward_MLLM(self,batch):
+        llava_processeds,subjects,prompts = batch['llava_processed'].half().to(self.device),batch['label'],batch['text']
+
+        input_ids,attention_masks,position_ids = [],[],[]
+        for subject,prompt in zip(subjects,prompts):
+            prompt_construct = f"USER: <image>\n A photo of a {subject}. Describe a new image of the same {subject} in: {prompt}. ASSISTANT: *"
+            input_id = tokenizer_image_token(prompt_construct, self.tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+            attention_mask = torch.ones(input_id.shape, dtype=torch.long, device=self.device)
+            position_id = torch.tensor(list(range(input_id.shape[-1])), device=self.device)
+
+            position_ids += [position_id]
+            attention_masks += [attention_mask[0]]
+            input_ids += [input_id[0]]
+
+        input_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in input_ids],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
+        position_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in position_ids],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
+        attention_masks = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in attention_masks],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
+
+        output = self.model_llava.my_llava_forward(self.model_llava,input_ids=input_ids,attention_mask=attention_masks,position_ids=position_ids,images=llava_processeds)
+        output = self.mapping(output)
+        return output[:,-1,:]
+
+    def reset(self):
+        self.moMA_generator.reset_all()
+
+    def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
+        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject,self)
+        self.moMA_generator.set_selfAttn_strength(strength)
+
+        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
+            with torch.no_grad():
+                ### key steps
+                llava_emb = self.forward_MLLM(batch).clone().detach()
+                img,mask = self.moMA_generator.generate_with_MoMA(batch,llava_emb=llava_emb,seed=seed,device=self.args.device)
+                self.reset()
+
+        result = ToPILImage()(img[0])
+        return result

model_lib/utils.py
ADDED
@@ -0,0 +1,27 @@
+import argparse
+import torch
+from torchvision.transforms import ToPILImage
+from PIL import Image
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of MoMA.")
+    parser.add_argument("--load_attn_adapters",type=str,default="checkpoints/attn_adapters_projectors.th",help="self_cross attentions and LLM projectors.")
+    parser.add_argument("--output_path",type=str,default="output",help="output directory.")
+    parser.add_argument("--model_path",type=str,default="KunpengSong/MoMA_llava_7b",help="fine tuned llava (Multi-modal LLM decoder)")
+    args = parser.parse_known_args()[0]
+    args.device = torch.device("cuda", 0)
+    args.load_8bit, args.load_4bit = False, True
+    return args
+
+def show_PIL_image(tensor):
+    # tensor of shape [3, 3, 512, 512]
+    to_pil = ToPILImage()
+    images = [to_pil(tensor[i]) for i in range(tensor.shape[0])]
+
+    concatenated_image = Image.new('RGB', (images[0].width * 3, images[0].height))
+    x_offset = 0
+    for img in images:
+        concatenated_image.paste(img, (x_offset, 0))
+        x_offset += img.width
+
+    return concatenated_image

output/car_A car in autumn with falling leaves..jpg
ADDED

output/car_A wooden sculpture of a car on the table..jpg
ADDED

requirements.txt
ADDED
@@ -0,0 +1,32 @@
+pip
+einops
+fastapi
+gradio
+numpy
+requests
+sentencepiece
+tokenizers>=0.12.1
+torch==2.0.1
+torchvision==0.15.2
+uvicorn
+wandb
+shortuuid
+httpx==0.24.0
+deepspeed
+peft==0.4.0
+transformers==4.36.2
+accelerate==0.21.0
+bitsandbytes==0.41.0
+scikit-learn==1.2.2
+sentencepiece==0.1.99
+einops==0.6.1
+einops-exts==0.0.4
+timm==0.6.13
+gradio_client
+opencv-python
+diffusers
+torchaudio
+torchmetrics
+llava-torch
+rembg
+pytorch_lightning