|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from einops import rearrange |
|
from mmengine.config import Config |
|
from xtuner.registry import BUILDER |
|
from torch.nn.utils.rnn import pad_sequence |
|
import os |
|
import json |
|
from mmengine.logging import print_log |
|
import spaces |
|
|
|
def crop2square(pil_img): |
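    """Center-crop a PIL image to its largest centered square."""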
|
width, height = pil_img.width, pil_img.height |
|
short = min(width, height) |
|
left = (width - short) // 2 |
|
upper = (height - short) // 2 |
|
return pil_img.crop((left, upper, left + short, upper + short)) |
|
|
|
def preprocess_image(image: Image.Image, image_size: int, dtype: torch.dtype): |
|
"""将 PIL Image 缩放(使用邻近插值)、归一化并返回 [1, C, H, W] Tensor。""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
img = crop2square(image) |
|
img = img.resize((image_size, image_size)) |
|
|
|
arr = np.asarray(img).astype(np.float32) / 255.0 |
|
arr = 2 * arr - 1 |
|
tensor = torch.from_numpy(arr).to(dtype=dtype) |
|
return rearrange(tensor, "h w c -> 1 c h w") |
|
|
|
def expand2square(pil_img, target_size=1024, background_color=(127, 127, 127)): |
|
""" |
|
Resize an image to fit within a square of size target_size x target_size, |
|
padding with background_color to make it exactly square. |
|
|
|
Args: |
|
pil_img (PIL.Image.Image): The input image. |
|
target_size (int): The desired square resolution. |
|
background_color (tuple): RGB color to pad with. |
|
|
|
Returns: |
|
PIL.Image.Image: The resized and padded square image. |
|
""" |
|
original_width, original_height = pil_img.size |
|
scale = min(target_size / original_width, target_size / original_height) |
|
new_width = int(original_width * scale) |
|
new_height = int(original_height * scale) |
|
|
|
|
|
resized_img = pil_img.resize((new_width, new_height), resample=Image.Resampling.BICUBIC) |
|
|
|
|
|
new_img = Image.new(pil_img.mode, (target_size, target_size), background_color) |
|
paste_position = ((target_size - new_width) // 2, (target_size - new_height) // 2) |
|
new_img.paste(resized_img, paste_position) |
|
|
|
return new_img |
|
|
|
def _print_load_result(module_name, missing, unexpected): |
|
print_log( |
|
f"[INFO] Loaded {module_name}. missing={len(missing)}, unexpected={len(unexpected)}" |
|
) |
|
|
|
|
|
class Inferencer: |
|
def __init__( |
|
self, config_file, model_path, image_size=1024, cfg_prompt="Generate an image." |
|
): |
|
self.config_file = config_file |
|
self.cfg = Config.fromfile(self.config_file) |
|
|
|
self.model_path = model_path |
|
self.device = "cuda" |
|
self.image_size = image_size |
|
self.image_shape = (image_size // 16, image_size // 16) |
|
self.cfg_prompt = cfg_prompt |
|
self.model = None |
|
|
|
def init_model(self): |
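        """Build the Harmon model from the mmengine config and load its checkpoint (sharded directory or single file)."""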
|
|
|
|
|
|
|
model = BUILDER.build(self.cfg.model) |
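        # A directory holds a sharded checkpoint described by an index file; otherwise load a single state-dict file.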
|
|
|
if os.path.isdir(self.model_path): |
|
index_path = os.path.join(self.model_path, "pytorch_model.bin.index.json") |
|
print_log( |
|
f"[INFO] Loading sharded Harmon checkpoint from: {self.model_path}" |
|
) |
|
state_dict = {} |
|
with open(index_path, "r") as f: |
|
index = json.load(f) |
|
for shard in sorted(set(index["weight_map"].values())): |
|
shard_path = os.path.join(self.model_path, shard) |
|
print_log(f"[INFO] Loading shard: {shard_path}") |
|
state_dict.update(torch.load(shard_path, map_location=self.device)) |
|
else: |
|
print_log(f"[INFO] Loading full Harmon checkpoint from: {self.model_path}") |
|
state_dict = torch.load(self.model_path, map_location=self.device) |
|
|
|
m, u = model.load_state_dict(state_dict, strict=False) |
|
_print_load_result("Harmon", m, u) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = model.to(self.device, dtype=model.dtype) |
|
model.eval() |
|
return model |
|
|
|
@spaces.GPU(duration=120) |
|
def gen_image( |
|
self, |
|
raw_prompt, |
|
images_to_generate=1, |
|
cfg=3.0, |
|
num_iter=64, |
|
cfg_schedule="constant", |
|
temperature=1.0, |
|
): |
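        """Generate images_to_generate images from raw_prompt via iterative sampling with classifier-free guidance."""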
|
if not self.model: |
|
self.model = self.init_model() |
|
prompt = self.model.prompt_template["INSTRUCTION"].format( |
|
input=f"Generate an image: {raw_prompt.strip()}." |
|
) |
|
prompts = [prompt] * images_to_generate |
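        # When cfg != 1.0, append an equal number of "null" prompts (cfg_prompt) for classifier-free guidance.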
|
if cfg != 1.0: |
|
prompts += [ |
|
self.model.prompt_template["INSTRUCTION"].format(input=self.cfg_prompt) |
|
] * images_to_generate |
|
|
|
inputs = self.model.tokenizer( |
|
prompts, add_special_tokens=True, return_tensors="pt", padding=True |
|
).to(self.device) |
|
|
|
        print_log(f"[INFO] Prompts: {prompts}")
|
|
|
images = self.model.sample( |
|
**inputs, |
|
num_iter=num_iter, |
|
cfg=cfg, |
|
cfg_schedule=cfg_schedule, |
|
temperature=temperature, |
|
progress=False, |
|
image_shape=self.image_shape, |
|
) |
|
images = rearrange(images, "(n b) c h w -> b n h w c", n=images_to_generate) |
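        # Map samples from [-1, 1] back to 0-255 uint8 pixels.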
|
images = ( |
|
torch.clamp(127.5 * images + 128.0, 0, 255) |
|
.to("cpu", dtype=torch.uint8) |
|
.numpy() |
|
) |
|
|
|
return [Image.fromarray(img) for img in images[0]] |
|
|
|
@spaces.GPU(duration=120) |
|
def query_image(self, img: Image.Image, prompt=""): |
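        """Answer prompt about img by injecting the image's visual tokens into the LLM input embeddings."""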
|
model = self.model |
|
if not model: |
|
model = self.init_model() |
|
tokenizer = model.tokenizer |
|
special_tokens_dict = {"additional_special_tokens": ["<image>"]} |
|
tokenizer.add_special_tokens(special_tokens_dict) |
|
image_token_idx = tokenizer.encode("<image>", add_special_tokens=False)[-1] |
|
|
|
|
|
image = img.convert("RGB") |
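        # Pad the image to a square so the visual token count matches image_length below.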
|
        image = expand2square(image, target_size=self.image_size)
|
image = torch.from_numpy(np.array(image)).to( |
|
dtype=model.dtype, device=self.device |
|
) |
|
image = rearrange(image, "h w c -> c h w")[None] |
|
image = 2 * (image / 255) - 1 |
|
|
|
|
|
full_prompt = model.prompt_template["INSTRUCTION"].format( |
|
input="<image>\n" + prompt |
|
) |
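        # Number of <image> placeholders: one per 16x16 patch plus 64 extra tokens from extract_visual_feature.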
|
image_length = (self.image_size // 16) ** 2 + 64 |
|
full_prompt = full_prompt.replace("<image>", "<image>" * image_length) |
|
input_ids = tokenizer.encode( |
|
full_prompt, add_special_tokens=True, return_tensors="pt" |
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
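            # Encode the image, then place its visual features at the <image> token positions and regular text embeddings everywhere else.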
|
_, z_enc = model.extract_visual_feature(model.encode(image)) |
|
inputs_embeds = z_enc.new_zeros(*input_ids.shape, model.llm.config.hidden_size) |
|
inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1) |
|
inputs_embeds[input_ids != image_token_idx] = model.llm.get_input_embeddings()( |
|
input_ids[input_ids != image_token_idx] |
|
) |
|
|
|
|
|
with torch.no_grad(): |
|
output = model.llm.generate( |
|
inputs_embeds=inputs_embeds, |
|
use_cache=True, |
|
do_sample=False, |
|
max_new_tokens=4096, |
|
eos_token_id=tokenizer.eos_token_id, |
|
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, |
|
) |
|
|
|
        return tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
@spaces.GPU(duration=120) |
|
def edit_image( |
|
self, |
|
source_image: Image.Image, |
|
prompt: str, |
|
num_iter: int = 48, |
|
cfg: float = 3.0, |
|
cfg_prompt: str = "Repeat this image.", |
|
cfg_schedule: str = "constant", |
|
temperature: float = 0.85, |
|
grid_size: int = 1 |
|
    ) -> list[Image.Image]:

        """Edit a single image according to prompt and return a one-element list with the result."""
|
|
|
model = self.model |
|
if not model: |
|
model = self.init_model() |
|
tokenizer = model.tokenizer |
|
special_tokens_dict = {"additional_special_tokens": ["<image>"]} |
|
tokenizer.add_special_tokens(special_tokens_dict) |
|
image_token_idx = tokenizer.encode("<image>", add_special_tokens=False)[-1] |
|
device = "cuda" |
|
|
|
img_tensor = preprocess_image(source_image, self.image_size, model.dtype).to(device) |
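        # Encode the source image: z_enc is injected into the LLM input embeddings, x_con is passed to model.sample as extra conditioning.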
|
|
|
|
|
with torch.no_grad(): |
|
x_enc = model.encode(img_tensor) |
|
x_con, z_enc = model.extract_visual_feature(x_enc) |
|
|
|
|
|
m = n = self.image_size // 16 |
|
image_length = m * n + 64 |
|
|
|
if hasattr(self.cfg.model, 'prompt_template'): |
|
prompt_str = self.cfg.model.prompt_template['INSTRUCTION'].format( |
|
input="<image>\n" + prompt.strip() |
|
) |
|
cfg_prompt_str = self.cfg.model.prompt_template['INSTRUCTION'].format( |
|
input="<image>\n" + cfg_prompt.strip() |
|
) |
|
else: |
|
prompt_str = f"<image>\n{prompt.strip()}" |
|
cfg_prompt_str = f"<image>\n{cfg_prompt.strip()}" |
|
|
|
|
|
prompt_str = prompt_str.replace('<image>', '<image>' * image_length) |
|
cfg_prompt_str = cfg_prompt_str.replace('<image>', '<image>' * image_length) |
|
|
|
|
|
input_ids = model.tokenizer.encode( |
|
prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda() |
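        # With CFG enabled, build a parallel "null" sequence from cfg_prompt and pad both sequences to the same length.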
|
|
|
if cfg != 1.0: |
|
null_input_ids = model.tokenizer.encode( |
|
cfg_prompt_str, add_special_tokens=True, return_tensors='pt')[0].cuda() |
|
attention_mask = pad_sequence( |
|
[torch.ones_like(input_ids), torch.ones_like(null_input_ids)], |
|
batch_first=True, padding_value=0).to(torch.bool) |
|
input_ids = pad_sequence( |
|
[input_ids, null_input_ids], |
|
batch_first=True, padding_value=model.tokenizer.eos_token_id) |
|
else: |
|
input_ids = input_ids[None] |
|
attention_mask = torch.ones_like(input_ids).to(torch.bool) |
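        # Duplicate the visual features so the conditional and unconditional branches see the same image.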
|
|
|
|
|
if cfg != 1.0: |
|
z_enc = torch.cat([z_enc, z_enc], dim=0) |
|
x_con = torch.cat([x_con, x_con], dim=0) |
|
|
|
inputs_embeds = z_enc.new_zeros(*input_ids.shape, model.llm.config.hidden_size) |
|
|
|
inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1) |
|
inputs_embeds[input_ids != image_token_idx] = model.llm.get_input_embeddings()( |
|
input_ids[input_ids != image_token_idx] |
|
) |
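        # Replicate embeddings and attention masks for a grid_size x grid_size batch of samples.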
|
|
|
|
|
bsz = grid_size ** 2 |
|
x_con = torch.cat([x_con] * bsz) |
|
if cfg != 1.0: |
|
inputs_embeds = torch.cat([ |
|
inputs_embeds[:1].expand(bsz, -1, -1), |
|
inputs_embeds[1:].expand(bsz, -1, -1), |
|
]) |
|
attention_mask = torch.cat([ |
|
attention_mask[:1].expand(bsz, -1), |
|
attention_mask[1:].expand(bsz, -1), |
|
]) |
|
else: |
|
inputs_embeds = inputs_embeds.expand(bsz, -1, -1) |
|
attention_mask = attention_mask.expand(bsz, -1) |
|
|
|
|
|
samples = model.sample( |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=attention_mask, |
|
num_iter=num_iter, |
|
cfg=cfg, |
|
cfg_schedule=cfg_schedule, |
|
temperature=temperature, |
|
progress=False, |
|
image_shape=(m, n), |
|
x_con=x_con |
|
) |
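        # Stitch the batch of samples into a grid_size x grid_size image and convert to uint8 pixels.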
|
|
|
|
|
samples = rearrange(samples, '(m n) c h w -> (m h) (n w) c', m=grid_size, n=grid_size) |
|
samples = torch.clamp(127.5 * samples + 128.0, 0, 255) |
|
out = samples.to("cpu", torch.uint8).numpy() |
|
|
|
        return [Image.fromarray(out)]
|
|
|
@spaces.GPU(duration=120) |
|
def query_text(self, prompt=""): |
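        """Generate a text-only response to prompt with the underlying LLM."""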
|
model = self.model |
|
if not model: |
|
model = self.init_model() |
|
tokenizer = model.tokenizer |
|
|
|
|
|
full_prompt = model.prompt_template["INSTRUCTION"].format(input=prompt) |
|
input_ids = tokenizer.encode( |
|
full_prompt, add_special_tokens=True, return_tensors="pt" |
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
output = model.llm.generate( |
|
input_ids=input_ids, |
|
use_cache=True, |
|
do_sample=True, |
|
max_new_tokens=1024, |
|
eos_token_id=tokenizer.eos_token_id, |
|
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, |
|
) |
|
|
|
res = tokenizer.decode(output[0], skip_special_tokens=True) |
|
|
|
return res |
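

# Minimal usage sketch (the paths below are placeholders, not files shipped with this repo):
#
#     inferencer = Inferencer(
#         config_file="configs/example_config.py",  # hypothetical config path
#         model_path="checkpoints/harmon.pth",      # hypothetical checkpoint path
#     )
#     images = inferencer.gen_image("a corgi wearing sunglasses", images_to_generate=2)
#     answer = inferencer.query_image(images[0], prompt="Describe this image.")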
|
|