File size: 2,014 Bytes
6342ac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import torch
import numpy as np
from PIL import Image
import yaml
import os
from utils import load_img, plot_all, dict2namespace, seed_everything
from models import instructir
from text.models import LanguageModel, LMHead

# Setup: everything below runs once at import time so that each request
# handled by `process` reuses the already-loaded weights.
SEED = 42
seed_everything(SEED=SEED)
# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the config is named "5d" while both checkpoints are named
# "7d" -- confirm this pairing is intentional.
CONFIG = "configs/eval5d.yml"
LM_MODEL = "models/lm_instructir-7d.pt"
MODEL_NAME = "models/im_instructir-7d.pt"

# Load config (explicit encoding so parsing does not depend on the locale)
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)

# Load image model
model = instructir.create_model(
    input_channels=cfg.model.in_ch,
    width=cfg.model.width,
    enc_blks=cfg.model.enc_blks,
    middle_blk_num=cfg.model.middle_blk_num,
    dec_blks=cfg.model.dec_blks,
    txtdim=cfg.model.textdim
).to(device)
model.load_state_dict(torch.load(MODEL_NAME, map_location=device), strict=True)
model.eval()

# Load language model
language_model = LanguageModel(model=cfg.llm.model)
lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
# Put the head on the same device as the image model: its output is fed
# straight into `model`, and its checkpoint is mapped to `device` below.
# NOTE(review): presumably LanguageModel places its own output on the
# matching device -- confirm against text/models.py.
lm_head = lm_head.to(device)
lm_head.load_state_dict(torch.load(LM_MODEL, map_location=device), strict=True)
lm_head.eval()

def process(image, prompt):
    """Restore ``image`` according to the text instruction ``prompt``.

    Args:
        image: PIL image delivered by the Gradio image input, or ``None``
            when the user submitted without uploading anything.
        prompt: Instruction text passed to the language model.

    Returns:
        A PIL image built from the model output after clipping to [0, 1]
        and scaling to 8-bit.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image to restore.")

    image = image.convert("RGB")
    image = np.array(image).astype(np.float32) / 255.0
    # HWC -> CHW, then add a leading batch axis.
    y = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pure inference: skip autograd bookkeeping so no graph is retained
    # in memory for each request.
    with torch.no_grad():
        lm_embd = language_model(prompt)
        text_embd, _ = lm_head(lm_embd)
        x_hat = model(y, text_embd)

    # First batch item, CHW -> HWC, back on the CPU as a NumPy array.
    restored = x_hat[0].permute(1, 2, 0).cpu().detach().numpy()
    restored = np.clip(restored, 0., 1.)
    restored = (restored * 255).astype(np.uint8)
    return Image.fromarray(restored)

# Gradio UI: an image plus a free-text prompt go in, one image comes out.
_image_input = gr.Image(type="pil")
_prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the restoration...")
_image_output = gr.Image(type="pil")

interface = gr.Interface(
    title="swiftlens: Prompt-Guided Image Restoration",
    fn=process,
    inputs=[_image_input, _prompt_input],
    outputs=_image_output,
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()