# swiftlenss/app.py — commit 6342ac4 ("Add application file") by Lawliet18, 2.01 kB.
import gradio as gr
import torch
import numpy as np
from PIL import Image
import yaml
import os
from utils import load_img, plot_all, dict2namespace, seed_everything
from models import instructir
from text.models import LanguageModel, LMHead
# --- Setup -------------------------------------------------------------------
# Seed the random number generators through the project helper before any
# model is constructed (presumably covers Python/NumPy/torch — see utils).
SEED = 42
seed_everything(SEED=SEED)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Config and pretrained checkpoints. Paths are resolved against this file's
# directory rather than the current working directory, so the app also starts
# when launched from elsewhere (e.g. `python path/to/app.py`).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CONFIG = os.path.join(BASE_DIR, "configs", "eval5d.yml")
LM_MODEL = os.path.join(BASE_DIR, "models", "lm_instructir-7d.pt")
MODEL_NAME = os.path.join(BASE_DIR, "models", "im_instructir-7d.pt")

# Load config. safe_load only constructs plain Python objects, so the YAML
# cannot execute code; dict2namespace gives attribute access (cfg.model.width).
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)
# --- Image restoration network -----------------------------------------------
# Architecture hyper-parameters come from the `model` section of the config;
# txtdim sizes the text-conditioning input the network accepts.
model = instructir.create_model(
    input_channels=cfg.model.in_ch,
    width=cfg.model.width,
    enc_blks=cfg.model.enc_blks,
    middle_blk_num=cfg.model.middle_blk_num,
    dec_blks=cfg.model.dec_blks,
    txtdim=cfg.model.textdim,
)
model = model.to(device)

# strict=True: checkpoint keys and network parameters/buffers must match
# exactly, otherwise load_state_dict raises.
model.load_state_dict(
    torch.load(MODEL_NAME, map_location=device),
    strict=True,
)
model.eval()  # inference behaviour for layers such as dropout / normalization
# --- Text side ---------------------------------------------------------------
# LanguageModel encodes the prompt (cfg.llm.model presumably names a pretrained
# text encoder — see text/models.py). LMHead maps that embedding to the
# conditioning vector passed to the image model; its second output (presumably
# class logits, given num_classes) is not used by this app.
language_model = LanguageModel(model=cfg.llm.model)
lm_head = LMHead(
    embedding_dim=cfg.llm.model_dim,
    hidden_dim=cfg.llm.embd_dim,
    num_classes=cfg.llm.nclasses,
)
lm_head.load_state_dict(torch.load(LM_MODEL, map_location=device), strict=True)
# load_state_dict copies values into the module's existing (CPU) parameters, so
# map_location alone does not relocate the head; place it on `device` explicitly,
# matching the image model.
lm_head = lm_head.to(device)
lm_head.eval()
def process(image, prompt):
    """Restore ``image`` following the natural-language instruction ``prompt``.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submitted without uploading one.
        prompt: Free-text description of the desired restoration.

    Returns:
        The restored image as an 8-bit ``PIL.Image``.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image to restore.")

    # H x W x 3 uint8 in [0, 255] -> 1 x 3 x H x W float32 in [0, 1].
    pixels = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    y = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)

    # Pure inference: no_grad avoids building an autograd graph (and keeping
    # its activations alive) on every request.
    with torch.no_grad():
        # The language model's output device is not controlled here (presumably
        # a tensor), so move it to wherever lm_head's weights live, then move the
        # resulting text embedding to the image model's device.
        lm_device = next(lm_head.parameters()).device
        lm_embd = language_model(prompt).to(lm_device)
        text_embd, _ = lm_head(lm_embd)
        x_hat = model(y, text_embd.to(device))

    restored = x_hat[0].permute(1, 2, 0).cpu().numpy()
    restored = np.clip(restored, 0.0, 1.0)
    # Round rather than truncate when quantizing, avoiding a systematic
    # half-level darkening of every pixel.
    restored = np.rint(restored * 255.0).astype(np.uint8)
    return Image.fromarray(restored)
# --- Gradio UI ---------------------------------------------------------------
# One uploaded image plus a free-text instruction in; the restored image out.
# Both image components use type="pil", matching what process() consumes/returns.
interface = gr.Interface(
    title="swiftlens: Prompt-Guided Image Restoration",
    fn=process,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(
            label="Prompt",
            placeholder="Describe the restoration...",
        ),
    ],
    outputs=gr.Image(type="pil"),
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()