import glob
import logging
import os
import shutil
import subprocess

import cv2
import numpy as np
import torch
from diffusers.models.autoencoders.vq_model import VQModel
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from .auto_encoder import Autoencoder, Autoencoder_dataset
from .pose_estimator import get_pose_estimator
from .utils.loss_utils import cos_loss, l2_loss
from .video_preprocessor import VideoPreprocessor

def extract_with_openseg(cfg):
    import tensorflow as tf2
    import tensorflow._api.v2.compat.v1 as tf

    # Let TF grow GPU memory on demand so it can coexist with PyTorch.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)
    openseg = tf2.saved_model.load(
        cfg.feature_extractor.model_path,
        tags=[tf.saved_model.tag_constants.SERVING],
    )
    imgs_path = os.path.join(cfg.pipeline.data_path, "input")
    img_names = list(
        filter(
            lambda x: x.endswith("png") or x.endswith("jpg"),
            sorted(os.listdir(imgs_path)),
        )
    )
    img_list = []
    np_image_string_list = []
    for img_name in img_names:
        img_path = os.path.join(imgs_path, img_name)
        image = cv2.imread(img_path)
        with tf.gfile.GFile(img_path, 'rb') as f:
            np_image_string = np.array([f.read()])
        image = torch.from_numpy(image)
        img_list.append(image)
        np_image_string_list.append(np_image_string)
    images = [img_list[i].permute(2, 0, 1)[None, ...] for i in range(len(img_list))]
    imgs = torch.cat(images)
    save_path = os.path.join(cfg.pipeline.data_path, "lang_features")
    os.makedirs(save_path, exist_ok=True)
    embed_size = 768
    for i, (img, np_image_string) in enumerate(
        tqdm(zip(imgs, np_image_string_list), desc="Extracting lang features", total=len(imgs))
    ):
        # Query OpenSeg with an empty text embedding: only the per-pixel image
        # embeddings are needed, not a text-conditioned segmentation.
        text_emb = tf.zeros([1, 1, embed_size])
        results = openseg.signatures["serving_default"](
            inp_image_bytes=tf.convert_to_tensor(np_image_string[0]),
            inp_text_emb=text_emb,
        )
        img_info = results['image_info']
        crop_sz = [
            int(img_info[0, 0] * img_info[2, 0]),
            int(img_info[0, 1] * img_info[2, 1]),
        ]
        image_embedding_feat = results['image_embedding_feat'][:, :crop_sz[0], :crop_sz[1]]
        img_size = (img.shape[1], img.shape[2])
        feat_2d = tf.cast(
            tf.image.resize_nearest_neighbor(
                image_embedding_feat, img_size, align_corners=True
            )[0],
            dtype=tf.float32,
        ).numpy()
        # Mask-pool feat_2d: average the per-pixel embeddings inside each
        # segment of the precomputed mask and L2-normalize the result.
        feat_2d = np.transpose(feat_2d, axes=(2, 0, 1))
        pooled_feats2d = []
        curr_mask = np.load(
            os.path.join(cfg.pipeline.data_path, "lang_features_dim3", str(i + 1).zfill(4) + "_s.npy")
        )
        for color_id in range(-1, curr_mask.max() + 1):
            if not feat_2d[:, curr_mask == color_id].shape[-1]:
                continue
            pooled = feat_2d[:, curr_mask == color_id].mean(axis=-1)
            pooled /= np.linalg.norm(pooled)
            pooled_feats2d.append(pooled)
        pooled_feats2d = np.stack(pooled_feats2d)
        np.save(os.path.join(save_path, str(i + 1).zfill(4) + ".npy"), pooled_feats2d)
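
# A minimal, self-contained sketch of the mask-pooling step above (illustrative
# only; shapes are toy values, not pipeline data): average the C-dim per-pixel
# features inside each mask segment, then L2-normalize. Segment id -1 marks
# unlabeled pixels and is skipped when empty.
def _mask_pool_example():
    feat_2d = np.random.rand(768, 4, 4).astype(np.float32)  # [C, H, W]
    mask = np.array([[0, 0, 1, 1]] * 4)                     # [H, W] segment ids
    pooled = []
    for seg_id in range(-1, mask.max() + 1):
        seg_feats = feat_2d[:, mask == seg_id]              # [C, n_pixels]
        if not seg_feats.shape[-1]:
            continue
        v = seg_feats.mean(axis=-1)
        pooled.append(v / np.linalg.norm(v))
    return np.stack(pooled)                                 # [n_segments, C]
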

class Preprocessor:
    def __init__(self, cfg):
        self.cfg = cfg
        if not cfg.pipeline.skip_video_process:
            self.video_processor = VideoPreprocessor(cfg)
        else:
            self.video_processor = None
        if not cfg.pipeline.skip_pose_estimate:
            self.pose_estimator = get_pose_estimator(cfg)
        else:
            self.pose_estimator = None
        if not cfg.pipeline.skip_lang_feature_extraction:
            # Load the feature extractor and the matching semantic autoencoder.
            if cfg.feature_extractor.type == "open-seg":
                self.lseg = None
                self.sem_ae = Autoencoder()
                self.sem_ae.cuda()
                self.img_transform = None
            elif cfg.feature_extractor.type == "lseg":
                from cogvideox_interpolation.lseg import LSegFeatureExtractor

                self.lseg = LSegFeatureExtractor.from_pretrained(cfg.lseg.model_path)
                self.lseg.to(cfg.lseg.device, dtype=torch.float32).eval()
                self.sem_ae = VQModel(
                    in_channels=512,
                    out_channels=512,
                    latent_channels=4,
                    norm_num_groups=2,
                    block_out_channels=[256, 64, 16],
                    down_block_types=["DownEncoderBlock2D"] * 3,
                    up_block_types=["UpDecoderBlock2D"] * 3,
                    layers_per_block=1,
                    norm_type="spatial",
                    num_vq_embeddings=1024,
                )
                self.sem_ae.load_state_dict(load_file(cfg.ae.model_path))
                self.sem_ae.to(cfg.ae.device, dtype=torch.float32).eval()
                self.img_transform = transforms.Compose(
                    [
                        transforms.Lambda(lambda x: x / 255),
                        transforms.Normalize(
                            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
                        ),
                    ]
                )
        else:
            self.lseg = None
            self.sem_ae = None
            self.img_transform = None
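
    # A minimal sketch (illustrative only, assuming a VQModel configured as in
    # __init__): round-trip a dummy feature map through the VQ autoencoder.
    # encode() returns the pre-quantization latents; decode() quantizes and
    # reconstructs the 512-channel features.
    @staticmethod
    def _vq_roundtrip_example(vq: VQModel):
        x = torch.randn(1, 512, 64, 64)
        with torch.no_grad():
            z = vq.encode(x).latents     # [1, 4, h, w]; h, w depend on the downsampling blocks
            x_rec = vq.decode(z).sample  # [1, 512, 64, 64]
        return z, x_rec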

    def generate_lang_features_with_openseg(self):
        extract_with_openseg(self.cfg)
        logging.info("Done feature extraction.")
        # Train the autoencoder that compresses the pooled OpenSeg features to
        # 3 dimensions; the train and eval loaders share the same dataset.
        num_epochs = 400
        os.makedirs(os.path.join(self.cfg.pipeline.data_path, "ckpt"), exist_ok=True)
        save_path = os.path.join(self.cfg.pipeline.data_path, "lang_features")
        train_dataset = Autoencoder_dataset(save_path)
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=512,
            shuffle=True,
            num_workers=32,
            drop_last=False,
        )
        test_loader = DataLoader(
            dataset=train_dataset,
            batch_size=512,
            shuffle=False,
            num_workers=32,
            drop_last=False,
        )
        optimizer = torch.optim.Adam(self.sem_ae.parameters(), lr=1e-4)
        pbar = tqdm(range(num_epochs))
        best_eval_loss = 100.0
        best_epoch = 0
        for epoch in pbar:
            self.sem_ae.train()
            for idx, feature in enumerate(train_loader):
                data = feature.to("cuda")
                outputs_dim3 = self.sem_ae.encode(data)
                outputs = self.sem_ae.decode(outputs_dim3)
                l2loss = l2_loss(outputs, data)
                cosloss = cos_loss(outputs, data)
                loss = l2loss + cosloss * 0.001
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Only evaluate (and checkpoint the best model) in the last 100 epochs.
            if epoch > 300:
                eval_loss = 0.0
                self.sem_ae.eval()
                for idx, feature in enumerate(test_loader):
                    data = feature.to("cuda")
                    with torch.no_grad():
                        outputs = self.sem_ae(data)
                        loss = l2_loss(outputs, data) + cos_loss(outputs, data)
                        eval_loss += loss * len(feature)
                eval_loss = eval_loss / len(train_dataset)
                print("eval_loss:{:.8f}".format(eval_loss))
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    best_epoch = epoch
                    torch.save(
                        self.sem_ae.state_dict(),
                        os.path.join(self.cfg.pipeline.data_path, "ckpt", "best_ckpt.pth"),
                    )
            pbar.set_postfix({"Loss": f"{loss.item():.7f}"})
        print(f"best_epoch: {best_epoch}")
        print("best_loss: {:.8f}".format(best_eval_loss))
        # Compress the language features with the best checkpoint.
        logging.info("Compressing language features with best ckpt...")
        best_state_dict = torch.load(
            os.path.join(self.cfg.pipeline.data_path, "ckpt", "best_ckpt.pth"),
            weights_only=False,
        )
        self.sem_ae.load_state_dict(best_state_dict)
        orig_lang_feat_names = sorted(glob.glob(os.path.join(save_path, "*.npy")))
        dim3_save_path = os.path.join(self.cfg.pipeline.data_path, "lang_features_dim3")
        with torch.no_grad():
            for idx, orig_lang_feat_name in enumerate(orig_lang_feat_names):
                orig_lang_feat = torch.from_numpy(np.load(orig_lang_feat_name)).cuda()
                mask = np.load(os.path.join(dim3_save_path, str(idx + 1).zfill(4) + "_s.npy"))
                lang_feat = self.sem_ae.encode(orig_lang_feat).detach().cpu().numpy()
                # Scatter each segment's 3-dim code back onto the pixel grid.
                full_lang_feat = np.zeros((3, mask.shape[0], mask.shape[1]))
                curr_id = 0
                for color_id in range(-1, mask.max() + 1):
                    if not mask[mask == color_id].shape[-1]:
                        continue
                    full_lang_feat[:, mask == color_id] = lang_feat[curr_id][:, None]
                    curr_id += 1
                np.save(os.path.join(dim3_save_path, str(idx + 1).zfill(4) + "_f.npy"), full_lang_feat)
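
    # Hedged helper sketch (not called by the pipeline): how a compressed
    # feature map pairs with its segment mask on disk, following the naming
    # used above (NNNN_f.npy for the dense features, NNNN_s.npy for the mask).
    @staticmethod
    def _load_dim3_pair(dim3_dir, view_idx):
        stem = str(view_idx).zfill(4)
        full_feat = np.load(os.path.join(dim3_dir, stem + "_f.npy"))  # [3, H, W]
        mask = np.load(os.path.join(dim3_dir, stem + "_s.npy"))       # [H, W]
        return full_feat, mask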

    def generate_lang_features_with_lseg(self):
        imgs_path = os.path.join(self.cfg.pipeline.data_path, "input")
        img_names = list(
            filter(
                lambda x: x.endswith("png") or x.endswith("jpg"), os.listdir(imgs_path)
            )
        )
        save_path = os.path.join(self.cfg.pipeline.data_path, "lang_features_dim4")
        os.makedirs(save_path, exist_ok=True)
        for img_name in tqdm(img_names):
            img_path = os.path.join(imgs_path, img_name)
            img = cv2.imread(img_path)
            resolution = (640, 480)
            img = cv2.resize(img, resolution)
            frame_embed = self.img_transform(torch.from_numpy(img).permute(2, 0, 1)).to(
                self.cfg.lseg.device, dtype=torch.float32
            )[None, ...]
            lseg_features = self.lseg.extract_features(frame_embed)
            if lseg_features.device != self.sem_ae.device:
                lseg_features = lseg_features.to(self.sem_ae.device)
            # Compress the 512-dim LSeg features to 4 latent channels.
            z = self.sem_ae.encode(lseg_features).latents  # [1, 4, 240, 320]
            np.save(
                os.path.join(save_path, f"{img_name.split('.')[0]}_f.npy"),
                z.detach().cpu().numpy(),
            )
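
    # Hedged sketch (illustrative only, assuming the lseg branch so sem_ae is
    # the VQModel): the saved dim-4 latents can be mapped back to 512-dim
    # LSeg-space features with the VQ decoder.
    def _decode_dim4_example(self, npy_path):
        z = torch.from_numpy(np.load(npy_path)).to(self.cfg.ae.device, dtype=torch.float32)
        with torch.no_grad():
            feats = self.sem_ae.decode(z).sample  # [1, 512, H, W]
        return feats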

    def select_valid_data(self):
        cfg = self.cfg
        curr_data_path = cfg.pipeline.data_path
        raw_data_path = os.path.join(curr_data_path, "raw")
        os.makedirs(raw_data_path, exist_ok=True)
        dirs_to_move = ["camera", "input", "lang_features_dim3", "normal"]
        # Keep chunk_num * keep_num_per_chunk views, evenly spaced over all views.
        orig_view_nums = len(os.listdir(os.path.join(curr_data_path, "camera")))
        indexs = np.linspace(
            0, orig_view_nums - 1, cfg.pipeline.chunk_num * cfg.pipeline.keep_num_per_chunk
        )
        indexs = indexs.astype(np.int32).tolist()
        cfg.pipeline.selected_idxs = indexs
        for dir_to_move in dirs_to_move:
            shutil.move(os.path.join(curr_data_path, dir_to_move), raw_data_path)
            src_dir = os.path.join(raw_data_path, dir_to_move)
            tar_dir = os.path.join(curr_data_path, dir_to_move)
            os.makedirs(tar_dir, exist_ok=True)
            file_lst = sorted(os.listdir(src_dir))
            file_suffix = file_lst[0].split(".")[-1]
            if dir_to_move == "lang_features_dim3":
                # Each view has a feature file (*_f) and a mask file (*_s), so
                # the sorted listing interleaves them: 2*idx and 2*idx + 1.
                f_file_lst = [file_lst[2 * idx] for idx in cfg.pipeline.selected_idxs]
                s_file_lst = [file_lst[2 * idx + 1] for idx in cfg.pipeline.selected_idxs]
                for file_idx in range(len(f_file_lst)):
                    shutil.copy(
                        os.path.join(src_dir, f_file_lst[file_idx]),
                        os.path.join(tar_dir, f"{file_idx + 1:04d}_f.{file_suffix}"),
                    )
                    shutil.copy(
                        os.path.join(src_dir, s_file_lst[file_idx]),
                        os.path.join(tar_dir, f"{file_idx + 1:04d}_s.{file_suffix}"),
                    )
            else:
                file_lst = [file_lst[idx] for idx in cfg.pipeline.selected_idxs]
                for file_idx, file_name in enumerate(file_lst):
                    shutil.copy(
                        os.path.join(src_dir, file_name),
                        os.path.join(tar_dir, f"{file_idx + 1:04d}.{file_suffix}"),
                    )
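
    # Hedged sketch of the even-spacing selection above: e.g. 10 views kept as
    # 2 chunks of 2 yields indices [0, 3, 6, 9].
    @staticmethod
    def _selection_indices_example(num_views=10, chunk_num=2, keep_per_chunk=2):
        idxs = np.linspace(0, num_views - 1, chunk_num * keep_per_chunk)
        return idxs.astype(np.int32).tolist()  # [0, 3, 6, 9]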

    def preprocess(self):
        if not self.cfg.pipeline.skip_video_process:
            logging.info("Processing input videos...")
            self.video_processor.video_process()
        if not self.cfg.pipeline.skip_pose_estimate:
            logging.info("Estimating poses...")
            self.pose_estimator.get_poses()
        if not self.cfg.pipeline.skip_lang_feature_extraction:
            logging.info("Generating language features...")
            if self.cfg.feature_extractor.type == "lseg":
                self.generate_lang_features_with_lseg()
            elif self.cfg.feature_extractor.type == "open-seg":
                self.generate_lang_features_with_openseg()
        if self.cfg.pipeline.selection:
            logging.info("Selecting views with higher confidence...")
            self.select_valid_data()
        logging.info("Done all preprocessing!")