Colourize / app.py
dkescape's picture
Update app.py
ed0b9b7 verified
raw
history blame
2.79 kB
import os
import cv2
import tempfile
import numpy as np
import gradio as gr
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from pathlib import Path
# Initialize model
def load_model():
    """Create the DDColor image-colorization pipeline once and keep it global.

    The pipeline is stored in the module-level name ``img_colorization``,
    which ``inference`` reads. If that name already holds a pipeline the
    function returns immediately, so repeated calls do not rebuild the model.
    """
    global img_colorization
    # Idempotence guard: building the pipeline is expensive, and a second
    # call would otherwise discard and recreate it.
    if globals().get('img_colorization') is not None:
        return
    img_colorization = pipeline(
        Tasks.image_colorization,
        model='iic/cv_ddcolor_image-colorization',
        model_revision='v1.0.0',
    )
def inference(img):
    """Colorize an image with the global ``img_colorization`` pipeline.

    Args:
        img: A PIL image (any mode; it is converted to RGB) or a NumPy array.
            A 3- or 4-channel array has its first three channels reversed
            before the model call, i.e. it is treated as BGR(A) exactly as
            the previous implementation did; a 2-D array is treated as
            grayscale.
            NOTE(review): Gradio's ``type="numpy"`` mode presumably delivers
            RGB arrays - confirm the channel order if that mode is ever used.

    Returns:
        A ``(path, status)`` tuple: ``path`` is a ``pathlib.Path`` to the
        written ``colorized.png`` and ``status`` is the success message.

    Raises:
        gr.Error: If ``img`` is ``None`` or the result cannot be written.
    """
    if img is None:
        raise gr.Error("Please upload an image first")

    # This function can run before the ``__main__`` guard has called
    # load_model() - e.g. when cached examples are pre-computed while the UI
    # is being built, or when the module is imported instead of executed.
    if 'img_colorization' not in globals():
        load_model()

    if isinstance(img, np.ndarray):
        if img.ndim == 2:
            # Reversing the last axis of a 2-D array would mirror the image
            # horizontally instead of swapping channels.
            rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        else:
            # Channels 2,1,0: BGR -> RGB, dropping an alpha channel if present.
            rgb = img[..., 2::-1]
    else:
        # convert("RGB") also copes with grayscale ("L"), palette and RGBA
        # images, which a direct RGB2BGR conversion rejects.
        rgb = np.array(img.convert("RGB"))

    # Process image
    output = img_colorization(rgb)
    result = output['output_img'].astype(np.uint8)

    # Save result. mkdtemp() is used instead of TemporaryDirectory(): the
    # latter deletes the directory as soon as its ``with`` block exits, so the
    # returned path would no longer exist when the caller tries to read it.
    # The file is intentionally left on disk; cleanup is up to the host.
    out_dir = tempfile.mkdtemp(prefix='colorized_')
    out_path = os.path.join(out_dir, 'colorized.png')
    if not cv2.imwrite(out_path, result):
        # cv2.imwrite signals failure through its return value, not an exception.
        raise gr.Error("Failed to save the colorized image")
    return Path(out_path), "✅ Colorization completed successfully!"
# Create modern UI with Blocks
def _colorize_for_ui(img):
    """Run ``inference`` and fan its result out to the result widgets.

    Returns the result path for the preview image, the same path as a string
    for the download widget, and the status message.
    """
    result_path, message = inference(img)
    return result_path, str(result_path), message


with gr.Blocks(theme="soft", title="🎨 AI Color Restoration Studio") as demo:
    gr.Markdown("""
# 🎨 AI Color Restoration Studio
Transform your black & white photos into vibrant colorized versions using state-of-the-art AI!
Upload an image and watch as our deep learning model automatically adds natural colors.
""")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(
                label="Upload Monochrome Image",
                type="pil",
                height=400,
                sources=["upload"],
                interactive=True
            )
            submit_btn = gr.Button("✨ Colorize Image", variant="primary")
            clear_btn = gr.ClearButton()
        with gr.Column(scale=1):
            output_img = gr.Image(
                label="Colorized Result",
                type="pil",
                height=400,
                interactive=False
            )
            download_btn = gr.File(label="Download Result")
            status = gr.Textbox(label="Status", interactive=False)
    # Examples section
    gr.Examples(
        examples=[
            ["examples/1.jpg"],
            ["examples/2.jpg"],
            ["examples/3.jpg"]
        ],
        inputs=[input_img],
        outputs=[output_img, download_btn, status],
        fn=_colorize_for_ui,
        cache_examples=True
    )
    # Event handlers
    # download_btn is listed as an output so the "Download Result" widget
    # actually receives the generated file; previously nothing wrote to it.
    submit_btn.click(
        fn=_colorize_for_ui,
        inputs=[input_img],
        outputs=[output_img, download_btn, status]
    )
    # Clearing also resets the download widget so no stale file is offered.
    clear_btn.add([input_img, output_img, download_btn, status])
def _main():
    """Script entry point: load the model, then serve the app in debug mode."""
    load_model()
    demo.launch(debug=True)


if __name__ == "__main__":
    _main()