Spaces:
Sleeping
Sleeping
File size: 4,613 Bytes
eb24fe7 8f6b83a eb24fe7 70d5e84 fd18b9b 70d5e84 fd18b9b 70d5e84 bfbcea8 70d5e84 c4bf26d bf6140a eb24fe7 2a5fc50 c4bf26d eb24fe7 c4bf26d bf6140a c4bf26d 2a5fc50 c4bf26d eb24fe7 c4bf26d 2a5fc50 c4bf26d 8f6b83a 70d5e84 2998cfe 70d5e84 2a5fc50 70d5e84 c4bf26d 70d5e84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Install detectron2 on start-up if it is not importable yet (it is not on PyPI,
# so it is installed straight from the GitHub repository).
try:
    import detectron2
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use the interpreter running this script so the package lands in the same
    # environment; check_call raises if pip fails instead of silently continuing
    # to the detectron2 imports below.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # The package was created after the interpreter started; make sure the
    # import system notices it.
    importlib.invalidate_caches()
import cv2
from matplotlib.pyplot import axis
import gradio as gr
import requests
import numpy as np
from torch import nn
import requests
import torch
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Class names predicted by both checkpoints (same order for both models).
# NOTE(review): the "(2-class)" / "(4-class)" display names look inherited from
# another demo -- both entries list the same nine classes; confirm intended names.
_MINERAL_CLASSES = ("minerales", "Afs", "Amp", "Bt", "Ms", "Ol", "Pl", "Px", "Qz")

# Hugging Face download prefix shared by every checkpoint below.
_WEIGHTS_BASE_URL = "https://huggingface.co/stalyn314/10xmineralmodel/resolve/main/"


def _model_entry(display_name, weights_file):
    """Return a registry entry; ``cfg`` and ``metadata`` are filled in later."""
    return {
        "name": display_name,
        "model_path": _WEIGHTS_BASE_URL + weights_file,
        "classes": list(_MINERAL_CLASSES),
        "cfg": None,
        "metadata": None,
    }


models = [
    _model_entry("Version 1 (2-class)", "xplx10x_d2.pth"),
    _model_entry("Version 2 (4-class)", "10xmodel_d2.pth"),
]

# Display name -> index into ``models`` (used to resolve the UI radio choice).
model_name_to_id = dict(zip((entry["name"] for entry in models), range(len(models))))
def _configure_model(entry):
    """Fill in ``entry["cfg"]`` and ``entry["metadata"]`` for one registry entry.

    The Detectron2 config is loaded from the Mask R-CNN X101-FPN 3x YAML (path
    relative to the working directory), sized to the entry's class count and
    pointed at the entry's checkpoint URL. The class names are registered in
    ``MetadataCatalog`` under the entry's display name. When CUDA is not
    available the model is forced onto the CPU; otherwise the config's device
    setting is left untouched.
    """
    entry["cfg"] = cfg = get_cfg()
    cfg.merge_from_file("./configs/detectron2/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(entry["classes"])
    cfg.MODEL.WEIGHTS = entry["model_path"]

    entry["metadata"] = metadata = MetadataCatalog.get(entry["name"])
    metadata.thing_classes = entry["classes"]

    if torch.cuda.is_available():
        return
    cfg.MODEL.DEVICE = "cpu"


for model in models:
    _configure_model(model)
def _gallery_item_to_rgb(item):
    """Return the image held by one gallery entry as an (H, W, 3) array.

    ``gr.Gallery`` (Gradio 4+) hands each entry over as an ``(image, caption)``
    tuple, so the image is unwrapped first. A 2-D (grayscale) image is
    replicated to three channels and a fourth (alpha) channel is dropped,
    because the predictor and the Visualizer both work on three channels.
    """
    if isinstance(item, tuple):
        item = item[0]
    arr = np.asarray(item)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.ndim == 3 and arr.shape[2] == 4:
        arr = arr[:, :, :3]
    return arr


def inference(image, min_score, model_name):
    """Run the selected Detectron2 model on one or more images.

    Args:
        image: the Gallery value -- a list of ``(image, caption)`` tuples or
            bare RGB arrays -- a single RGB array, or ``None`` when nothing
            was uploaded.
        min_score: detection score threshold (``SCORE_THRESH_TEST``).
        model_name: display name of an entry in ``models``.

    Returns:
        A list with one RGB visualization per input image (empty for no input).

    Raises:
        KeyError: if ``model_name`` is not a known model.
    """
    if image is None:
        return []
    items = image if isinstance(image, list) else [image]
    if not items:
        return []

    entry = models[model_name_to_id[model_name]]
    # Work on a copy so concurrent requests with different thresholds do not
    # overwrite each other's setting in the shared, module-level config.
    cfg = entry["cfg"].clone()
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = min_score
    # NOTE(review): the predictor (model build + weight load) is recreated on
    # every request; consider caching one predictor per model if latency matters.
    predictor = DefaultPredictor(cfg)

    results = []
    for item in items:
        rgb = _gallery_item_to_rgb(item)
        # DefaultPredictor expects BGR input, while Visualizer expects RGB.
        outputs = predictor(rgb[:, :, ::-1])
        visualizer = Visualizer(rgb, entry["metadata"], scale=1.2)
        drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
        # get_image() already returns RGB, which is what gr.Gallery displays.
        results.append(drawn.get_image())
    return results
# Markdown rendered above (title, description) and below (footer) the UI.
# Keep the description in sync with ``models``: names, checkpoints and classes.
title = "# Detectron2 Mineral Model Demo"
description = """
This demo introduces an interactive playground for our trained Detectron2 models.
Currently, two models are supported. Both load their weights from the
[stalyn314/10xmineralmodel](https://huggingface.co/stalyn314/10xmineralmodel) repository
and predict the same nine classes: *minerales*, *Afs*, *Amp*, *Bt*, *Ms*, *Ol*, *Pl*, *Px* and *Qz*.
* **Version 1 (2-class)**: uses the `xplx10x_d2.pth` checkpoint.
* **Version 2 (4-class)**: uses the `10xmodel_d2.pth` checkpoint.
"""
footer = "Made in Munich with ❤️ and 🥨."
# Assemble the Gradio interface. The image gallery lives in the "From Image"
# tab; the threshold slider, model selector, output gallery and Submit button
# sit below it at the Blocks level, followed by the footer.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tab("From Image"):
        image_input = gr.Gallery(
            label="Input Images",
            type="numpy",
            elem_id="input_image",
        )

    min_score = gr.Slider(
        label="Minimum score",
        minimum=0.0,
        maximum=1.0,
        value=0.5,
    )
    model_choices = [entry["name"] for entry in models]
    model_name = gr.Radio(
        label="Select Detectron2 model",
        choices=model_choices,
        value=model_choices[0],
    )
    output_gallery = gr.Gallery(
        label="Output Images",
        type="numpy",
        elem_id="output_images",
    )
    inference_button = gr.Button("Submit")
    # Submit runs inference(images, threshold, model name) and shows the
    # annotated images in the output gallery.
    inference_button.click(
        inference,
        inputs=[image_input, min_score, model_name],
        outputs=output_gallery,
    )

    gr.Markdown(footer)

demo.launch()
#gr.Interface(
# inference,
# [gr.inputs.Textbox(label="Image URL", placeholder="https://api.digitale-sammlungen.de/iiif/image/v2/bsb10483966_00008/full/500,/0/default.jpg"),
# gr.inputs.Image(type="numpy", label="Input Image"),
# gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Minimum score"),
# gr.Radio(choices=[model["name"] for model in models], value=models[0]["name"], label="Select Detectron2 model"),
# ],
# gr.outputs.Image(type="pil", label="Output"),
# title=title,
# description=description,
# article=article,
# examples=[]).launch()
|