# eval_instructir.py — InstructIR evaluation entry point.
import argparse
import gc
import json
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import yaml
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset

from models import instructir
from test import test_model
from text.models import LanguageModel, LMHead
from utils import *
def seed_everything(SEED=42):
    """Seed every RNG source (Python, NumPy, PyTorch CPU and CUDA) with SEED.

    The CUDA seeding calls are no-ops when no GPU is present. Note that
    cudnn.benchmark is switched on here: it autotunes convolution kernels
    for speed, trading away strict run-to-run determinism unless
    cudnn.deterministic is also set by the caller.
    """
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(SEED)
    torch.backends.cudnn.benchmark = True
if __name__=="__main__":
    # Evaluate an InstructIR restoration model (optionally text-conditioned)
    # over every testset declared in the YAML config.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/eval5d.yml', help='Path to config file')
    parser.add_argument('--model', type=str, default="models/im_instructir-7d.pt", help='Path to the image model weights')
    parser.add_argument('--lm', type=str, default="models/lm_instructir-7d.pt", help='Path to the language model weights')
    parser.add_argument('--promptify', type=str, default="simple_augment")
    parser.add_argument('--device', type=int, default=0, help="GPU device")
    parser.add_argument('--debug', action='store_true', help="Debug mode")
    parser.add_argument('--save', type=str, default='results/', help="Path to save the resultant images")
    args = parser.parse_args()

    SEED = 42
    seed_everything(SEED=SEED)
    # Force deterministic cudnn kernel selection for a reproducible eval.
    torch.backends.cudnn.deterministic = True

    GPU        = args.device
    DEBUG      = args.debug
    MODEL_NAME = args.model
    CONFIG     = args.config
    LM_MODEL   = args.lm
    SAVE_PATH  = args.save

    print ('CUDA GPU available: ', torch.cuda.is_available())
    # Only touch the CUDA device APIs when CUDA actually exists;
    # torch.cuda.set_device / get_device_name raise on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.set_device(f'cuda:{GPU}')
        print ('CUDA visible devices: ' + str(torch.cuda.device_count()))
        print ('CUDA current device: ', torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device()))
    device = torch.device(f'cuda:{GPU}' if torch.cuda.is_available() else "cpu")

    # parse config file
    with open(os.path.join(CONFIG), "r") as f:
        config = yaml.safe_load(f)
    cfg = dict2namespace(config)

    print (20*"****")
    print ("EVALUATION")
    print (MODEL_NAME, LM_MODEL, device, DEBUG, CONFIG, args.promptify)
    print (20*"****")

    ################### TESTING DATASET
    # Each testset is a [clean_path, degraded_path] pair. A restoration task
    # whose section is absent from the config is simply skipped (empty list);
    # a missing cfg attribute raises AttributeError on the namespace.
    TESTSETS = []
    dn_testsets = []
    rain_testsets = []

    # Denoising: one testset per (dataset, sigma) combination.
    try:
        for testset in cfg.test.dn_datasets:
            for sigma in cfg.test.dn_sigmas:
                noisy_testpath = os.path.join(cfg.test.dn_datapath, testset + f"_{sigma}")
                clean_testpath = os.path.join(cfg.test.dn_datapath, testset)
                dn_testsets.append([clean_testpath, noisy_testpath])
    except AttributeError:  # denoising section not in config
        dn_testsets = []

    # RAIN: paired input/target lists.
    try:
        for noisy_testpath, clean_testpath in zip(cfg.test.rain_inputs, cfg.test.rain_targets):
            rain_testsets.append([clean_testpath, noisy_testpath])
    except AttributeError:  # deraining section not in config
        rain_testsets = []

    # HAZE
    try:
        haze_testsets = [[cfg.test.haze_targets, cfg.test.haze_inputs]]
    except AttributeError:
        haze_testsets = []

    # BLUR
    try:
        blur_testsets = [[cfg.test.gopro_targets, cfg.test.gopro_inputs]]
    except AttributeError:
        blur_testsets = []

    # LOL (low-light enhancement)
    try:
        lol_testsets = [[cfg.test.lol_targets, cfg.test.lol_inputs]]
    except AttributeError:
        lol_testsets = []

    # MIT5K (photo retouching)
    try:
        mit_testsets = [[cfg.test.mit_targets, cfg.test.mit_inputs]]
    except AttributeError:
        mit_testsets = []

    TESTSETS += dn_testsets
    TESTSETS += rain_testsets
    TESTSETS += haze_testsets
    TESTSETS += blur_testsets
    TESTSETS += lol_testsets
    TESTSETS += mit_testsets

    print ("TOTAL TESTSET:", len(TESTSETS))
    print (20 * "----")

    ################### RESTORATION MODEL
    print ("Creating InstructIR")
    model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
                                    middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim)

    ################### LOAD IMAGE MODEL
    assert MODEL_NAME, "Model weights required for evaluation"
    print ("IMAGE MODEL CKPT:", MODEL_NAME)
    # Deserialize on CPU first so checkpoints saved on GPU load on any host,
    # then move the module to the selected device.
    model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"), strict=True)
    model = model.to(device)
    nparams = count_params (model)
    print ("Loaded weights!", nparams / 1e6)

    ################### LANGUAGE MODEL
    try:
        PROMPT_DB = cfg.llm.text_db
    except AttributeError:  # no prompt database configured
        PROMPT_DB = None

    if cfg.model.use_text:
        # Avoid tokenizer fork warnings/deadlocks in DataLoader workers.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        # Frozen language encoder + small trained head mapping prompt
        # embeddings to degradation classes.
        LMODEL = cfg.llm.model
        language_model = LanguageModel(model=LMODEL)
        lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
        lm_head = lm_head.to(device)
        lm_nparams = count_params (lm_head)
        print ("LMHEAD MODEL CKPT:", LM_MODEL)
        lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
        print ("Loaded weights!")
    else:
        LMODEL = None
        language_model = None
        lm_head = None
        lm_nparams = 0

    print (20 * "----")

    ################### TESTING !!
    from datasets import RefDegImage, augment_prompt, create_testsets

    # promptify maps a degradation name to the instruction text fed to the LM.
    if args.promptify == "simple_augment":
        promptify = augment_prompt
    elif args.promptify == "chatgpt":
        # Sample a random human-style prompt for the degradation from the DB.
        with open(cfg.llm.text_db) as f:
            prompts = json.load(f)
        def promptify(deg):
            return np.random.choice(prompts[deg])
    else:
        # Any other value is used verbatim as a fixed prompt.
        def promptify(deg):
            return args.promptify

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    test_datasets = create_testsets(TESTSETS, debug=True)
    test_model (model, language_model, lm_head, test_datasets, device, promptify, savepath=SAVE_PATH)