File size: 4,362 Bytes
eb24fe7
 
 
 
 
 
8f6b83a
 
eb24fe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70d5e84
 
 
fd18b9b
 
70d5e84
 
 
 
 
fd18b9b
 
70d5e84
 
 
 
 
 
 
 
 
 
bfbcea8
70d5e84
 
 
 
 
 
 
 
 
 
c4bf26d
 
 
bf6140a
 
 
eb24fe7
c4bf26d
 
 
eb24fe7
c4bf26d
 
bf6140a
c4bf26d
 
 
 
 
 
eb24fe7
c4bf26d
 
 
8f6b83a
70d5e84
 
 
 
 
 
 
 
 
 
 
 
 
 
2998cfe
70d5e84
 
 
 
 
f24ea11
70d5e84
 
 
c4bf26d
70d5e84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# detectron2 is installed from its GitHub repository on first start when the
# environment does not already provide it.
try:
    import detectron2
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use the running interpreter's pip so the package lands in the
    # environment this process imports from; check_call raises if the
    # install fails instead of silently continuing to the detectron2 imports.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # The package appeared after interpreter start-up; drop stale import
    # finder caches so later `from detectron2 import ...` statements see it.
    importlib.invalidate_caches()

import cv2

from matplotlib.pyplot import axis
import gradio as gr
import requests
import numpy as np
from torch import nn
import requests

import torch

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog


# Checkpoints offered in the UI. The "cfg" and "metadata" slots start empty and
# are populated per entry by the configuration loop that follows.
# NOTE(review): the "(2-class)" / "(4-class)" suffixes in the display names do
# not match the 9-entry class lists; the names are left unchanged because they
# are the values exposed by the model selector.
_MINERAL_CLASSES = ("minerales", "Afs", "Amp", "Bt", "Ms", "Ol", "Pl", "Px", "Qz")

models = [
    dict(
        name=display_name,
        model_path=weights_url,
        classes=list(_MINERAL_CLASSES),  # separate list object per model
        cfg=None,
        metadata=None,
    )
    for display_name, weights_url in (
        (
            "Version 1 (2-class)",
            "https://huggingface.co/stalyn314/10xmineralmodel/resolve/main/xplx10x_d2.pth",
        ),
        (
            "Version 2 (4-class)",
            "https://huggingface.co/stalyn314/10xmineralmodel/resolve/main/10xmodel_d2.pth",
        ),
    )
]

# Reverse lookup: selector display name -> index into `models`.
model_name_to_id = dict(zip((entry["name"] for entry in models), range(len(models))))

def _configure_model(entry, use_cpu):
    """Fill in the "cfg" and "metadata" slots of one ``models`` entry.

    The config is loaded from the local
    ``configs/detectron2/mask_rcnn_X_101_32x8d_FPN_3x.yaml`` and pointed at the
    entry's checkpoint URL and class count. The class names are registered in
    ``MetadataCatalog`` under the entry's display name so the Visualizer can
    label predictions. When ``use_cpu`` is true the model is forced onto CPU.
    """
    cfg = get_cfg()
    entry["cfg"] = cfg
    cfg.merge_from_file("./configs/detectron2/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(entry["classes"])
    cfg.MODEL.WEIGHTS = entry["model_path"]

    metadata = MetadataCatalog.get(entry["name"])
    entry["metadata"] = metadata
    metadata.thing_classes = entry["classes"]

    if use_cpu:
        cfg.MODEL.DEVICE = "cpu"


for model in models:
    # Fall back to CPU inference when no CUDA device is available.
    _configure_model(model, use_cpu=not torch.cuda.is_available())


def _gallery_item_to_rgb(item):
    """Return the image held by one gallery item as an (H, W, 3) array.

    ``item`` is either a bare array or an ``(image, caption)`` pair whose
    caption is ignored. Grayscale images are replicated to three channels and
    a fourth (alpha) channel is dropped, so the result always has exactly
    three channels in the input's channel order.
    """
    # Gradio 4+ galleries hand over (image, caption) pairs rather than bare images.
    if isinstance(item, (tuple, list)):
        item = item[0]
    img = np.asarray(item)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]
    return img


def inference(image, min_score, model_name):
    """Run the selected Detectron2 model on every image of the input gallery.

    Args:
        image: Value of the input ``gr.Gallery``: a list of images or of
            ``(image, caption)`` pairs, or ``None`` when the gallery is empty.
            Images are assumed to be RGB numpy arrays (``type="numpy"``).
        min_score: Detection score threshold, applied as
            ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        model_name: Display name of one of the entries in ``models``.

    Returns:
        A list with one RGB array per input image, showing the predicted
        instances drawn by detectron2's ``Visualizer`` at 1.2x scale. An empty
        list when there are no input images.

    Raises:
        KeyError: If ``model_name`` is not a known model name.
    """
    if image is None or len(image) == 0:
        return []

    model = models[model_name_to_id[model_name]]

    # Set the threshold on a private copy so the shared per-model config keeps
    # its defaults and concurrent requests cannot overwrite each other's value.
    cfg = model["cfg"].clone()
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = float(min_score)
    predictor = DefaultPredictor(cfg)

    results = []
    for item in image:
        rgb = _gallery_item_to_rgb(item)

        # DefaultPredictor takes a BGR image, while Visualizer expects RGB.
        outputs = predictor(rgb[:, :, ::-1])

        visualizer = Visualizer(rgb, model["metadata"], scale=1.2)
        drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
        # get_image() returns the rendered canvas in the RGB order it was given.
        results.append(drawn.get_image())

    return results

# Texts rendered as Markdown at the top and bottom of the page.
# NOTE(review): the title and footer still appear to come from the dbmdz demo
# template — confirm they are intended for this app.
title = "# DBMDZ Detectron2 Model Demo"
description = """
This demo introduces an interactive playground for our trained Detectron2 models.
Upload one or more images, choose a model and a minimum score, and press *Submit* to see the predicted instances drawn on each image.
Currently, two checkpoints from [stalyn314/10xmineralmodel](https://huggingface.co/stalyn314/10xmineralmodel) are supported:
* **Version 1 (2-class)**: `xplx10x_d2.pth`
* **Version 2 (4-class)**: `10xmodel_d2.pth`

Both models are configured with the same nine labels: *minerales*, *Afs*, *Amp*, *Bt*, *Ms*, *Ol*, *Pl*, *Px* and *Qz*.
"""
footer = "Made in Munich with ❤️ and 🥨."

# Assemble the web UI. Inside gr.Blocks, components are laid out in the order
# they are created, so the statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tab("From Image"):
        image_input = gr.Gallery(
            type="numpy",
            label="Input Images",
            elem_id="input_image",
        )

    min_score = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.5,
        label="Minimum score",
    )

    model_choices = [entry["name"] for entry in models]
    model_name = gr.Radio(
        choices=model_choices,
        value=model_choices[0],
        label="Select Detectron2 model",
    )

    output_gallery = gr.Gallery(label="Output Images", elem_id="output_images")

    inference_button = gr.Button("Submit")

    # Submit sends the gallery, threshold and model choice to `inference` and
    # shows the returned images in the output gallery.
    inference_button.click(
        fn=inference,
        inputs=[image_input, min_score, model_name],
        outputs=output_gallery,
    )

    gr.Markdown(footer)

demo.launch()

#gr.Interface(
#    inference,
#    [gr.inputs.Textbox(label="Image URL", placeholder="https://api.digitale-sammlungen.de/iiif/image/v2/bsb10483966_00008/full/500,/0/default.jpg"),
#     gr.inputs.Image(type="numpy", label="Input Image"),
#     gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Minimum score"),
#     gr.Radio(choices=[model["name"] for model in models], value=models[0]["name"], label="Select Detectron2 model"),
#    ], 
#    gr.outputs.Image(type="pil", label="Output"),
#    title=title,
#    description=description,
#    article=article,
#    examples=[]).launch()