# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os

import numpy as np
import yaml
import torch
from cog import BasePredictor, Input, Path

from utils import *
from models import instructir
from text.models import LanguageModel, LMHead

# Avoid tokenizer fork warnings/deadlocks from HuggingFace tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class Predictor(BasePredictor):
    """Cog predictor wrapping InstructIR: restores an image following a text instruction."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        Loads three components onto CPU:
          * the InstructIR image-restoration network (``self.model``),
          * a frozen language model producing prompt embeddings (``self.language_model``),
          * a small head mapping LM embeddings to the text embedding / degradation
            class used by InstructIR (``self.lm_head``).
        """
        LM_MODEL = "models/lm_instructir-7d.pt"
        MODEL_NAME = "models/im_instructir-7d.pt"
        # The whole pipeline runs on CPU; keep this in mind below when touching CUDA APIs.
        device = torch.device("cpu")

        # safe_load avoids arbitrary-object construction from the YAML config.
        with open(os.path.join("configs/eval5d.yml"), "r") as f:
            config = yaml.safe_load(f)
        cfg = dict2namespace(config)

        torch.backends.cudnn.deterministic = True
        self.model = instructir.create_model(
            input_channels=cfg.model.in_ch,
            width=cfg.model.width,
            enc_blks=cfg.model.enc_blks,
            middle_blk_num=cfg.model.middle_blk_num,
            dec_blks=cfg.model.dec_blks,
            txtdim=cfg.model.textdim,
        )
        self.model = self.model.to(device)
        print("IMAGE MODEL CKPT:", MODEL_NAME)
        self.model.load_state_dict(
            torch.load(MODEL_NAME, map_location="cpu"), strict=True
        )

        # Initialize the LanguageModel class
        LMODEL = cfg.llm.model
        self.language_model = LanguageModel(model=LMODEL)
        self.lm_head = LMHead(
            embedding_dim=cfg.llm.model_dim,
            hidden_dim=cfg.llm.embd_dim,
            num_classes=cfg.llm.nclasses,
        )
        # NOTE: the head stays on CPU; move it here (.to(device)) if a GPU is ever used.
        print("LMHEAD MODEL CKPT:", LM_MODEL)
        self.lm_head.load_state_dict(
            torch.load(LM_MODEL, map_location="cpu"), strict=True
        )
        print("Loaded weights!")

    def predict(
        self,
        image: Path = Input(description="Input image."),
        prompt: str = Input(description="Input prompt."),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed",
            default=None,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Args:
            image: path to the input image file.
            prompt: human instruction describing the restoration to perform.
            seed: optional RNG seed; a random 16-bit seed is drawn when omitted.

        Returns:
            Path to the restored image written as PNG.
        """
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(SEED=seed)

        # FIX: the model is pinned to CPU in setup(); calling CUDA memory APIs
        # unconditionally can raise on CUDA-less builds (reset_peak_memory_stats
        # requires an available CUDA context). Guard them.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        image = load_img(str(image))
        out_image = process_img(
            image, prompt, self.language_model, self.model, self.lm_head
        )
        out_path = "/tmp/out.png"
        saveImage(out_path, out_image)
        return Path(out_path)


def process_img(image, prompt, language_model, model, lm_head):
    """
    Given an image and a prompt, we run InstructIR to restore the image
    following the human prompt.

    Args:
        image: RGB image as numpy array normalized to [0,1], shape (H, W, 3).
        prompt: plain python string with the restoration instruction.
        language_model: callable mapping the prompt to an embedding tensor.
        model: the InstructIR network; takes (image_batch, text_embedding).
        lm_head: maps the LM embedding to (text_embedding, degradation_class).

    Returns:
        The restored image as a numpy array in [0, 1], shape (H, W, 3).
    """
    # Convert the HWC numpy image to a 1xCxHxW float tensor.
    y = torch.Tensor(image).permute(2, 0, 1).unsqueeze(0)

    # Get the text embedding (and predicted degradation class).
    # NOTE: embeddings stay on CPU; add .to(device) here if a GPU is ever used.
    lm_embd = language_model(prompt)
    text_embd, deg_pred = lm_head(lm_embd)

    # Forward pass: Paper Figure 2
    x_hat = model(y, text_embd)

    # Convert the restored image back into an HWC numpy array clipped to [0, 1].
    restored_img = x_hat[0].permute(1, 2, 0).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0.0, 1.0)
    return restored_img