# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os

import numpy as np
import torch
import yaml
from cog import BasePredictor, Input, Path

from models import instructir
from text.models import LanguageModel, LMHead
from utils import *  # provides dict2namespace, seed_everything, load_img, saveImage

# Silence fork-parallelism warnings from the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the models into memory so that repeated predictions are efficient."""
        LM_MODEL = "models/lm_instructir-7d.pt"
        MODEL_NAME = "models/im_instructir-7d.pt"
        device = torch.device("cpu")

        with open("configs/eval5d.yml", "r") as f:
            config = yaml.safe_load(f)
        cfg = dict2namespace(config)

        torch.backends.cudnn.deterministic = True
        self.model = instructir.create_model(
            input_channels=cfg.model.in_ch,
            width=cfg.model.width,
            enc_blks=cfg.model.enc_blks,
            middle_blk_num=cfg.model.middle_blk_num,
            dec_blks=cfg.model.dec_blks,
            txtdim=cfg.model.textdim,
        )
        self.model = self.model.to(device)
        print("IMAGE MODEL CKPT:", MODEL_NAME)
        self.model.load_state_dict(
            torch.load(MODEL_NAME, map_location="cpu"), strict=True
        )
        # Initialize the language model and the head that maps its embeddings
        # to a text embedding plus a predicted degradation class.
        LMODEL = cfg.llm.model
        self.language_model = LanguageModel(model=LMODEL)
        self.lm_head = LMHead(
            embedding_dim=cfg.llm.model_dim,
            hidden_dim=cfg.llm.embd_dim,
            num_classes=cfg.llm.nclasses,
        )
        self.lm_head = self.lm_head.to(device)  # device is CPU here
        print("LMHEAD MODEL CKPT:", LM_MODEL)
        self.lm_head.load_state_dict(
            torch.load(LM_MODEL, map_location="cpu"), strict=True
        )
        print("Loaded weights!")

    def predict(
        self,
        image: Path = Input(description="Input image."),
        prompt: str = Input(description="Input prompt."),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.",
            default=None,
        ),
    ) -> Path:
        """Run a single prediction on the model."""
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(SEED=seed)

        # Guard the CUDA calls: this predictor runs on CPU, and
        # reset_peak_memory_stats can fail when CUDA is unavailable.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        image = load_img(str(image))
        out_image = process_img(
            image, prompt, self.language_model, self.model, self.lm_head
        )
        out_path = "/tmp/out.png"
        saveImage(out_path, out_image)
        return Path(out_path)


def process_img(image, prompt, language_model, model, lm_head):
    """
    Restore an image with InstructIR, following the human-written prompt.

    image:  RGB image as a numpy array normalized to [0, 1], shape (H, W, 3).
    prompt: plain Python string describing the desired restoration.
    Returns the restored image as a numpy array in [0, 1].
    """
    # Convert the HWC image in [0, 1] to a 1x3xHxW float tensor.
    y = torch.Tensor(image).permute(2, 0, 1).unsqueeze(0)

    # Get the text embedding (and the predicted degradation class);
    # everything stays on the CPU here.
    lm_embd = language_model(prompt)
    text_embd, deg_pred = lm_head(lm_embd)

    # Forward pass: Paper Figure 2. No gradients are needed at inference time.
    with torch.no_grad():
        x_hat = model(y, text_embd)

    # Convert the restored image x_hat back into an HWC numpy array.
    restored_img = x_hat[0].permute(1, 2, 0).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0.0, 1.0)
    return restored_img
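

# ---------------------------------------------------------------------------
# Optional local smoke test: a minimal sketch of how this predictor might be
# exercised outside of Cog. It is not part of the original interface, and
# "example.png" and the prompt below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()  # loads the image model, language model, and LM head
    result = predictor.predict(
        image=Path("example.png"),  # hypothetical input; any RGB image works
        prompt="Remove the noise from this photo",  # hypothetical prompt
        seed=42,
    )
    print("Restored image written to:", result)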