# LangScene-X/field_construction/preprocessor.py
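"""Preprocessing pipeline for LangScene-X field construction.

Covers video frame extraction, camera pose estimation, language-feature
extraction with OpenSeg or LSeg, autoencoder-based feature compression, and
selection of a subset of views for downstream field construction.
"""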
import glob
import logging
import os
import shutil
import subprocess
import cv2
import numpy as np
import torch
from diffusers.models.autoencoders.vq_model import VQModel
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from .auto_encoder import Autoencoder, Autoencoder_dataset
from .pose_estimator import get_pose_estimator
from .utils.loss_utils import cos_loss, l2_loss
from .video_preprocessor import VideoPreprocessor
def extract_with_openseg(cfg):
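    """Extract OpenSeg language features for every input image and mask-pool them.

    Loads the OpenSeg SavedModel at ``cfg.feature_extractor.model_path``, runs it
    on each image under ``<data_path>/input``, pools the per-pixel embeddings over
    the precomputed segmentation masks in ``<data_path>/lang_features_dim3``, and
    saves one pooled feature array per image to ``<data_path>/lang_features``.
    """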
import tensorflow as tf2
    import tensorflow.compat.v1 as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
try:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
print(e)
openseg = tf2.saved_model.load(
cfg.feature_extractor.model_path,
tags=[tf.saved_model.tag_constants.SERVING]
)
imgs_path = os.path.join(cfg.pipeline.data_path, "input")
img_names = list(
filter(
lambda x: x.endswith("png") or x.endswith("jpg"), sorted(os.listdir(imgs_path))
)
)
img_list = []
np_image_string_list = []
for img_name in img_names:
img_path = os.path.join(imgs_path, img_name)
image = cv2.imread(img_path)
with tf.gfile.GFile(img_path, 'rb') as f:
np_image_string = np.array([f.read()])
image = torch.from_numpy(image)
img_list.append(image)
np_image_string_list.append(np_image_string)
images = [img_list[i].permute(2, 0, 1)[None, ...] for i in range(len(img_list))]
imgs = torch.cat(images)
save_path = os.path.join(cfg.pipeline.data_path, "lang_features")
os.makedirs(save_path, exist_ok=True)
embed_size = 768
    for i, (img, np_image_string) in enumerate(tqdm(zip(imgs, np_image_string_list), desc="Extracting lang features", total=len(imgs))):
text_emb = tf.zeros([1, 1, embed_size])
results = openseg.signatures["serving_default"](
inp_image_bytes=tf.convert_to_tensor(np_image_string[0]),
inp_text_emb=text_emb
)
img_info = results['image_info']
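        # drop the padded border of the OpenSeg feature map before resizing it
        # back to the input image resolution below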
crop_sz = [
int(img_info[0, 0] * img_info[2, 0]),
int(img_info[0, 1] * img_info[2, 1])
]
image_embedding_feat = results['image_embedding_feat'][:, :crop_sz[0], :crop_sz[1]]
img_size = (img.shape[1], img.shape[2])
feat_2d = tf.cast(
tf.image.resize_nearest_neighbor(
image_embedding_feat, img_size, align_corners=True
)[0], dtype=tf.float32
).numpy()
# perform mask-pooling over feat2d
feat_2d = np.transpose(feat_2d, axes=(2, 0, 1))
pooled_feats2d = []
curr_mask = np.load(os.path.join(cfg.pipeline.data_path, "lang_features_dim3", str(i+1).zfill(4)+"_s.npy"))
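        # mask-pool: average the features over every segment id present in the
        # mask (ids run from -1 to mask.max()) and L2-normalize each pooled vector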
for color_id in range(-1, curr_mask.max() + 1):
if not feat_2d[:, curr_mask == color_id].shape[-1]:
continue
pooled = feat_2d[:, curr_mask == color_id].mean(axis=-1)
pooled /= np.linalg.norm(pooled)
pooled_feats2d.append(pooled)
pooled_feats2d = np.stack(pooled_feats2d)
np.save(os.path.join(save_path, str(i+1).zfill(4)+".npy"), pooled_feats2d)
class Preprocessor:
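    """Runs the LangScene-X preprocessing pipeline.

    Depending on the ``cfg.pipeline`` skip flags, this covers video frame
    processing, camera pose estimation, language-feature extraction (OpenSeg or
    LSeg) with autoencoder compression, and selection of a subset of views.
    """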
def __init__(self, cfg):
self.cfg = cfg
if not cfg.pipeline.skip_video_process:
self.video_processor = VideoPreprocessor(cfg)
else:
self.video_processor = None
if not cfg.pipeline.skip_pose_estimate:
self.pose_estimator = get_pose_estimator(cfg)
else:
self.pose_estimator = None
if not cfg.pipeline.skip_lang_feature_extraction:
# load feature extractor
if cfg.feature_extractor.type == "open-seg":
self.lseg = None
self.sem_ae = Autoencoder()
self.sem_ae.cuda()
elif cfg.feature_extractor.type == "lseg":
                from cogvideox_interpolation.lseg import LSegFeatureExtractor
                self.lseg = LSegFeatureExtractor.from_pretrained(cfg.lseg.model_path)
self.lseg.to(cfg.lseg.device, dtype=torch.float32).eval()
self.sem_ae = VQModel(
in_channels=512,
out_channels=512,
latent_channels=4,
norm_num_groups=2,
block_out_channels=[256, 64, 16],
down_block_types=["DownEncoderBlock2D"] * 3,
up_block_types=["UpDecoderBlock2D"] * 3,
layers_per_block=1,
norm_type="spatial",
num_vq_embeddings=1024,
)
self.sem_ae.load_state_dict(load_file(cfg.ae.model_path))
self.sem_ae.to(cfg.ae.device, dtype=torch.float32).eval()
self.img_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255),
transforms.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
),
]
)
else:
self.lseg = None
self.sem_ae = None
self.img_transform = None
def generate_lang_features_with_openseg(self):
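        """Extract OpenSeg features, fit a compression autoencoder, and write per-pixel maps.

        After extraction, a small autoencoder is trained to compress the pooled
        OpenSeg features to 3 dimensions; the best checkpoint is then used to
        encode each image's segment features, which are scattered back into
        per-pixel maps saved as ``NNNN_f.npy`` under ``lang_features_dim3``.
        """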
extract_with_openseg(self.cfg)
logging.info("Done feature extraction.")
num_epochs = 400
os.makedirs(os.path.join(self.cfg.pipeline.data_path, "ckpt"), exist_ok=True)
save_path = os.path.join(self.cfg.pipeline.data_path, "lang_features")
train_dataset = Autoencoder_dataset(save_path)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=512,
shuffle=True,
num_workers=32,
drop_last=False
)
test_loader = DataLoader(
dataset=train_dataset,
batch_size=512,
shuffle=False,
num_workers=32,
drop_last=False
)
optimizer = torch.optim.Adam(self.sem_ae.parameters(), lr=1e-4)
pbar = tqdm(range(num_epochs))
best_eval_loss = 100.0
best_epoch = 0
for epoch in pbar:
self.sem_ae.train()
for idx, feature in enumerate(train_loader):
data = feature.to("cuda")
outputs_dim3 = self.sem_ae.encode(data)
outputs = self.sem_ae.decode(outputs_dim3)
l2loss = l2_loss(outputs, data)
cosloss = cos_loss(outputs, data)
loss = l2loss + cosloss * 0.001
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch > 300:
eval_loss = 0.0
self.sem_ae.eval()
for idx, feature in enumerate(test_loader):
data = feature.to("cuda")
with torch.no_grad():
outputs = self.sem_ae(data)
loss = l2_loss(outputs, data) + cos_loss(outputs, data)
eval_loss += loss * len(feature)
eval_loss = eval_loss / len(train_dataset)
print("eval_loss:{:.8f}".format(eval_loss))
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
best_epoch = epoch
torch.save(self.sem_ae.state_dict(), os.path.join(self.cfg.pipeline.data_path, "ckpt", "best_ckpt.pth"))
            pbar.set_postfix({"Loss": f"{loss.item():.7f}"})
print(f"best_epoch: {best_epoch}")
print("best_loss: {:.8f}".format(best_eval_loss))
# compress lang_feats with ae
        logging.info("Compressing language features with best ckpt...")
best_state_dict = torch.load(os.path.join(self.cfg.pipeline.data_path, "ckpt", "best_ckpt.pth"), weights_only=False)
self.sem_ae.load_state_dict(best_state_dict)
# check device
orig_lang_feat_names = sorted(glob.glob(os.path.join(save_path, "*.npy")))
dim3_save_path = os.path.join(self.cfg.pipeline.data_path, "lang_features_dim3")
with torch.no_grad():
for idx, orig_lang_feat_name in enumerate(orig_lang_feat_names):
orig_lang_feat = torch.from_numpy(np.load(orig_lang_feat_name)).cuda()
mask = np.load(os.path.join(dim3_save_path, str(idx+1).zfill(4)+"_s.npy"))
# check dtype
lang_feat = self.sem_ae.encode(orig_lang_feat).detach().cpu().numpy()
full_lang_feat = np.zeros((3, mask.shape[0], mask.shape[1]))
curr_id = 0
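                # scatter each segment's compressed 3-dim code back onto all
                # pixels belonging to that segment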
for color_id in range(-1, mask.max() + 1):
if not mask[mask == color_id].shape[-1]:
continue
full_lang_feat[:, mask == color_id] = lang_feat[curr_id][:, None]
curr_id += 1
np.save(os.path.join(dim3_save_path, str(idx+1).zfill(4)+"_f.npy"), full_lang_feat)
def generate_lang_features_with_lseg(self):
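        """Extract per-pixel LSeg features and compress them with the pretrained VQ autoencoder.

        Each input image is resized to 640x480, passed through LSeg, encoded to a
        4-channel latent map, and saved as ``<name>_f.npy`` under ``lang_features_dim4``.
        """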
from cogvideox_interpolation.lseg import LSegFeatureExtractor
imgs_path = os.path.join(self.cfg.pipeline.data_path, "input")
img_names = list(
filter(
lambda x: x.endswith("png") or x.endswith("jpg"), os.listdir(imgs_path)
)
)
save_path = os.path.join(self.cfg.pipeline.data_path, "lang_features_dim4")
os.makedirs(save_path, exist_ok=True)
for img_name in tqdm(img_names):
img_path = os.path.join(imgs_path, img_name)
img = cv2.imread(img_path)
resolution = (640, 480)
img = cv2.resize(img, resolution)
frame_embed = self.img_transform(torch.from_numpy(img).permute(2, 0, 1)).to(
self.cfg.lseg.device, dtype=torch.float32
)[None, ...]
lseg_features = self.lseg.extract_features(frame_embed)
if lseg_features.device != self.sem_ae.device:
lseg_features = lseg_features.to("cpu").to(self.sem_ae.device)
z = self.sem_ae.encode(lseg_features).latents # [1, 4, 240, 320]
np.save(
os.path.join(save_path, f"{img_name.split('.')[0]}_f.npy"),
z.detach().cpu().numpy(),
)
def select_valid_data(self):
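        """Keep an evenly spaced subset of views and archive the originals.

        Moves ``camera``, ``input``, ``lang_features_dim3`` and ``normal`` into a
        ``raw`` directory, then copies back ``chunk_num * keep_num_per_chunk``
        evenly spaced views, renumbering the files from 0001.
        """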
cfg = self.cfg
curr_data_path = cfg.pipeline.data_path
raw_data_path = os.path.join(curr_data_path, "raw")
os.makedirs(raw_data_path, exist_ok=True)
dirs_to_move = ["camera", "input", "lang_features_dim3", "normal"]
orig_view_nums = len(os.listdir(os.path.join(curr_data_path, "camera")))
indexs = np.linspace(0, orig_view_nums-1, cfg.pipeline.chunk_num * cfg.pipeline.keep_num_per_chunk)
indexs = indexs.astype(np.int32).tolist()
cfg.pipeline.selected_idxs = indexs
for dir_to_move in dirs_to_move:
shutil.move(os.path.join(curr_data_path, dir_to_move), raw_data_path)
src_dir = os.path.join(raw_data_path, dir_to_move)
tar_dir = os.path.join(curr_data_path, dir_to_move)
os.makedirs(tar_dir, exist_ok=True)
file_lst = sorted(os.listdir(src_dir))
file_suffix = file_lst[0].split(".")[-1]
if dir_to_move == "lang_features_dim3":
f_file_lst = [file_lst[2 * idx] for idx in cfg.pipeline.selected_idxs]
s_file_lst = [file_lst[2 * idx + 1] for idx in cfg.pipeline.selected_idxs]
for file_idx in range(len(f_file_lst)):
shutil.copy(
os.path.join(src_dir, f_file_lst[file_idx]),
os.path.join(tar_dir, f"{file_idx+1:04d}_f.{file_suffix}"),
)
shutil.copy(
os.path.join(src_dir, s_file_lst[file_idx]),
os.path.join(tar_dir, f"{file_idx+1:04d}_s.{file_suffix}"),
)
else:
file_lst = [file_lst[idx] for idx in cfg.pipeline.selected_idxs]
for file_idx, file_name in enumerate(file_lst):
shutil.copy(
os.path.join(src_dir, file_name),
os.path.join(tar_dir, f"{file_idx+1:04d}.{file_suffix}"),
)
def preprocess(self):
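        """Run the configured preprocessing stages in order."""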
if not self.cfg.pipeline.skip_video_process:
logging.info("Processing input videos...")
self.video_processor.video_process()
if not self.cfg.pipeline.skip_pose_estimate:
logging.info("Estimating poses...")
self.pose_estimator.get_poses()
if not self.cfg.pipeline.skip_lang_feature_extraction:
logging.info("Generating language features...")
if self.cfg.feature_extractor.type == "lseg":
self.generate_lang_features_with_lseg()
elif self.cfg.feature_extractor.type == "open-seg":
self.generate_lang_features_with_openseg()
if self.cfg.pipeline.selection:
logging.info("Selecting views with higher confidence...")
self.select_valid_data()
        logging.info("All preprocessing done!")
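
# Example usage (a minimal sketch; the config path and layout below are
# assumptions inferred from the attributes accessed in this module, not the
# project's actual schema):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/preprocess.yaml")  # hypothetical config file
#   preprocessor = Preprocessor(cfg)
#   preprocessor.preprocess()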