# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
import yaml | |
import os | |
from utils import load_img, plot_all, dict2namespace, seed_everything | |
from models import instructir | |
from text.models import LanguageModel, LMHead | |
# --- Setup -----------------------------------------------------------------
# Fixed seed, applied through the project's seed helper (presumably seeds the
# Python / NumPy / torch RNGs — see utils.seed_everything).
SEED = 42
seed_everything(SEED=SEED)

# Prefer a CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Paths are relative to the working directory the app is launched from.
CONFIG = "configs/eval5d.yml"
# NOTE(review): the config is the "5d" variant while both checkpoints are
# "7d" — confirm they are meant to be paired; a shape mismatch would make the
# strict state-dict loads raise at startup.
MODEL_NAME = "models/im_instructir-7d.pt"  # image restoration network weights
LM_MODEL = "models/lm_instructir-7d.pt"  # language-head weights
# --- Load config -----------------------------------------------------------
# Parse the YAML eval config with the safe loader (plain data only, no
# arbitrary object construction) and wrap it with the project's
# dict2namespace so sections are read as attributes (cfg.model.width,
# cfg.llm.model, ...). The encoding is explicit so decoding does not depend
# on the host locale.
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)
# --- Load image model ------------------------------------------------------
# Build the restoration network from the `model` section of the config.
# `txtdim` is the size of the text embedding it is conditioned on.
model = instructir.create_model(
    input_channels=cfg.model.in_ch,
    width=cfg.model.width,
    enc_blks=cfg.model.enc_blks,
    middle_blk_num=cfg.model.middle_blk_num,
    dec_blks=cfg.model.dec_blks,
    txtdim=cfg.model.textdim,
)
model = model.to(device)

# strict=True: checkpoint keys must match the network's parameters exactly.
_im_state = torch.load(MODEL_NAME, map_location=device)
model.load_state_dict(_im_state, strict=True)
del _im_state  # the weights now live in `model`; drop the temporary dict

# Inference only: put the module in evaluation mode.
model.eval()
# --- Load language model ---------------------------------------------------
# Text encoder named in the config, plus a trained head. The head's first
# output is the text embedding that conditions the image model in process().
language_model = LanguageModel(model=cfg.llm.model)

# Move the head to `device`, as is done for the image model.
# map_location=device alone is not enough: load_state_dict copies the loaded
# tensors into the module's existing parameters, which would otherwise stay
# on the CPU. The head's output would then sit on a different device than
# `model` whenever CUDA is available.
lm_head = LMHead(
    embedding_dim=cfg.llm.model_dim,
    hidden_dim=cfg.llm.embd_dim,
    num_classes=cfg.llm.nclasses,
).to(device)
lm_head.load_state_dict(torch.load(LM_MODEL, map_location=device), strict=True)
lm_head.eval()
def process(image, prompt):
    """Restore ``image`` following the natural-language instruction ``prompt``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input. Gradio passes
            ``None`` when the user submits without uploading an image.
        prompt: Free-text description of the desired restoration.

    Returns:
        The restored image as an 8-bit ``PIL.Image`` built from the model
        output (values clipped to [0, 1] and scaled to 0-255).

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image to restore.")

    # Force 3 channels (drops alpha, expands grayscale) and scale to [0, 1].
    pixels = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    # HWC -> NCHW with a batch of one, on the image model's device.
    y = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)

    # Inference only: no_grad skips building an autograd graph (and the memory
    # it holds) on every request.
    with torch.no_grad():
        lm_embd = language_model(prompt)
        # Feed the head on whatever device its weights live on, then hand its
        # embedding to the image model on `device`. Otherwise the text path
        # and the image path can end up on different devices when CUDA is
        # available.
        head_device = next(lm_head.parameters()).device
        text_embd, _ = lm_head(lm_embd.to(head_device))  # 2nd output unused
        x_hat = model(y, text_embd.to(device))

    # First (only) batch item, CHW -> HWC, back to host memory.
    restored = x_hat[0].permute(1, 2, 0).cpu().numpy()
    restored = np.clip(restored, 0.0, 1.0)
    return Image.fromarray((restored * 255).astype(np.uint8))
# --- Gradio UI -------------------------------------------------------------
# Inputs: a PIL image and a free-text instruction. Output: the restored
# PIL image produced by process(). Components are created in the same order
# as they are listed in the interface.
_input_image = gr.Image(type="pil")
_prompt_box = gr.Textbox(label="Prompt", placeholder="Describe the restoration...")
_output_image = gr.Image(type="pil")

interface = gr.Interface(
    fn=process,
    inputs=[_input_image, _prompt_box],
    outputs=_output_image,
    title="swiftlens: Prompt-Guided Image Restoration",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()