# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import numpy as np
import yaml
import torch
from cog import BasePredictor, Input, Path

from utils import dict2namespace, load_img, saveImage, seed_everything
from models import instructir
from text.models import LanguageModel, LMHead

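# Disable HuggingFace tokenizers parallelism to avoid fork-related warnings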
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

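        # Bundled checkpoints: the LM head weights and the InstructIR image model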
        LM_MODEL = "models/lm_instructir-7d.pt"
        MODEL_NAME = "models/im_instructir-7d.pt"
        device = torch.device("cpu")

        with open("configs/eval5d.yml", "r") as f:
            config = yaml.safe_load(f)

        cfg = dict2namespace(config)

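        # Deterministic cuDNN kernels keep GPU runs reproducible (no-op on CPU)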
        torch.backends.cudnn.deterministic = True
        self.model = instructir.create_model(
            input_channels=cfg.model.in_ch,
            width=cfg.model.width,
            enc_blks=cfg.model.enc_blks,
            middle_blk_num=cfg.model.middle_blk_num,
            dec_blks=cfg.model.dec_blks,
            txtdim=cfg.model.textdim,
        )

        self.model = self.model.to(device)
        print("IMAGE MODEL CKPT:", MODEL_NAME)
        self.model.load_state_dict(
            torch.load(MODEL_NAME, map_location="cpu"), strict=True
        )

        # Initialize the language model and its classification head
        LMODEL = cfg.llm.model
        self.language_model = LanguageModel(model=LMODEL)
        self.lm_head = LMHead(
            embedding_dim=cfg.llm.model_dim,
            hidden_dim=cfg.llm.embd_dim,
            num_classes=cfg.llm.nclasses,
        )
        self.lm_head = self.lm_head.to(device)

        print("LMHEAD MODEL CKPT:", LM_MODEL)
        self.lm_head.load_state_dict(
            torch.load(LM_MODEL, map_location="cpu"), strict=True
        )
        print("Loaded weights!")

    def predict(
        self,
        image: Path = Input(description="Input image."),
        prompt: str = Input(description="Input prompt."),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(SEED=seed)

        # CUDA bookkeeping only applies when a GPU is present; the model runs
        # on CPU here, so guard the calls so CPU-only builds don't error
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        image = load_img(str(image))
        out_image = process_img(
            image, prompt, self.language_model, self.model, self.lm_head
        )

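        # Cog returns files to the caller by path; write the result to /tmp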
        out_path = "/tmp/out.png"
        saveImage(out_path, out_image)

        return Path(out_path)


def process_img(image, prompt, language_model, model, lm_head):
    """
    Given an image and a prompt, we run InstructIR to restore the image following the human prompt.
    image: RGB image as numpy array normalized to [0,1]
    prompt: plain python string,

    returns the restored image as numpy array.
    """

    # Convert the HxWxC numpy image into a 1xCxHxW float tensor
    y = torch.Tensor(image).permute(2, 0, 1).unsqueeze(0)

    # Get the text embedding (and the predicted degradation class, unused here)
    lm_embd = language_model(prompt)
    text_embd, deg_pred = lm_head(lm_embd)

    # Forward pass: Paper Figure 2. No gradients are needed at inference time.
    with torch.no_grad():
        x_hat = model(y, text_embd)

    # Convert the restored image x_hat back to a HxWxC numpy array in [0, 1]
    restored_img = x_hat[0].permute(1, 2, 0).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0.0, 1.0)
    return restored_img
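

# A minimal local smoke-test sketch (Cog itself only calls Predictor.setup and
# Predictor.predict). It assumes the checkpoints referenced in setup() are in
# place; "input.png" and the prompt string below are hypothetical examples.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    restored = predictor.predict(
        image=Path("input.png"),
        prompt="Please remove the noise from this photo",
        seed=42,
    )
    print("Restored image saved to:", restored)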