Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| from gradio_app.inference import run_inference | |
| from gradio_app.components import ( | |
| list_reference_files, list_mapping_files, | |
| list_classifier_files, list_edgeface_files | |
| ) | |
| from gradio_app.project_info import ( | |
| CONTENT_DESCRIPTION, CONTENT_OUTTRO, | |
| CONTENT_IN_1, CONTENT_IN_2, | |
| CONTENT_OUT_1, CONTENT_OUT_2 | |
| ) | |
| from glob import glob | |
| import os | |
def create_image_io_row():
    """Build the row that holds the image upload widget and the results pane.

    Returns:
        tuple: ``(image_input, output)`` where ``image_input`` is the
        ``gr.Image`` upload component (configured with ``type="pil"``) and
        ``output`` is the ``gr.HTML`` component used to show inference results.
    """
    io_row = gr.Row(elem_classes=["image-io-row"])
    with io_row:
        # Components created inside the context are attached to this row.
        uploaded_image = gr.Image(label="Upload Image", type="pil")
        results_pane = gr.HTML(
            elem_classes=["results-container"],
            label="Inference Results",
        )
    return uploaded_image, results_pane
def create_model_settings_row():
    """Build the two-column row of model-file pickers and advanced settings.

    The first column groups four dropdowns for the model/reference files; the
    second groups the detection algorithm, accelerator, image resolution and
    similarity threshold controls.

    Returns:
        tuple: the eight components in this order: ``ref_dict``,
        ``index_map``, ``classifier_model``, ``edgeface_model``,
        ``algorithm``, ``accelerator``, ``resolution``,
        ``similarity_threshold``.
    """
    # Leading entry shown in every file dropdown ahead of the discovered files.
    placeholder = ["Select a file"]

    with gr.Row():
        # Left column: file selection.
        with gr.Column(), gr.Group(elem_classes=["section-group"]):
            gr.Markdown("### Model Files", elem_classes=["section-title"])
            ref_dict = gr.Dropdown(
                label="Reference Dict JSON",
                choices=placeholder + list_reference_files(),
                value="data/reference_data/reference_image_data.json",
            )
            index_map = gr.Dropdown(
                label="Index to Class Mapping JSON",
                choices=placeholder + list_mapping_files(),
                value="ckpts/index_to_class_mapping.json",
            )
            classifier_model = gr.Dropdown(
                label="Classifier Model (.pth)",
                choices=placeholder + list_classifier_files(),
                value="ckpts/SlimFace_efficientnet_b3_full_model.pth",
            )
            edgeface_model = gr.Dropdown(
                label="EdgeFace Model (.pt)",
                choices=placeholder + list_edgeface_files(),
                value="ckpts/idiap/edgeface_s_gamma_05.pt",
            )

        # Right column: inference tuning knobs.
        with gr.Column(), gr.Group(elem_classes=["section-group"]):
            gr.Markdown("### Advanced Settings", elem_classes=["section-title"])
            algorithm = gr.Dropdown(
                label="Detection Algorithm",
                choices=["yolo", "mtcnn", "retinaface"],
                value="yolo",
            )
            accelerator = gr.Dropdown(
                label="Accelerator",
                choices=["auto", "cpu", "cuda", "mps"],
                value="auto",
            )
            resolution = gr.Slider(
                label="Image Resolution",
                minimum=128,
                maximum=512,
                step=32,
                value=300,
            )
            similarity_threshold = gr.Slider(
                label="Similarity Threshold",
                minimum=0.1,
                maximum=1.0,
                step=0.05,
                value=0.3,
            )

    return (
        ref_dict,
        index_map,
        classifier_model,
        edgeface_model,
        algorithm,
        accelerator,
        resolution,
        similarity_threshold,
    )
# Load the local stylesheet once at import time. The path is relative, so it
# is resolved against the current working directory: start the app from the
# directory that contains ``apps/``. The context manager closes the file
# handle deterministically, and the explicit encoding keeps decoding
# independent of the platform's default locale.
with open("apps/gradio_app/static/styles.css", encoding="utf-8") as _css_file:
    CSS = _css_file.read()
def create_interface():
    """Create the Gradio interface for SlimFace.

    Assembles, top to bottom: the title and introductory content, the image
    input/output row, the model settings row, an example-image gallery, the
    "Run Inference" button wired to ``run_inference``, and the closing content.

    Returns:
        gr.Blocks: the assembled (not yet launched) application.
    """
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# SlimFace Demonstration")
        gr.Markdown(CONTENT_DESCRIPTION)
        gr.Markdown(CONTENT_IN_1)
        gr.HTML(CONTENT_IN_2)

        image_input, output = create_image_io_row()
        (
            ref_dict,
            index_map,
            classifier_model,
            edgeface_model,
            algorithm,
            accelerator,
            resolution,
            similarity_threshold,
        ) = create_model_settings_row()

        # Add example image gallery as a row of columns
        with gr.Group():
            gr.Markdown("### Example Images")
            # glob() yields matches in arbitrary, filesystem-dependent order;
            # sort so the gallery layout is the same on every machine.
            example_images = sorted(glob("apps/assets/examples/*.[jp][pn][gf]"))
            if example_images:
                with gr.Row(elem_classes=["example-row"]):
                    for img_path in example_images:
                        file_name = os.path.basename(img_path)
                        with gr.Column(min_width=120):
                            gr.Image(
                                value=img_path,
                                label=file_name,
                                type="filepath",
                                height=100,
                                elem_classes=["example-image"],
                            )
                            # Bind the path as a default argument so each
                            # button loads its own image rather than the last
                            # value of the loop variable.
                            gr.Button(f"Use {file_name}").click(
                                fn=lambda x=img_path: Image.open(x),
                                outputs=image_input,
                            )
            else:
                gr.Markdown("No example images found in apps/assets/examples/")

        with gr.Row():
            submit_btn = gr.Button(
                "Run Inference",
                variant="primary",
                elem_classes=["centered-button"],
            )

        # The order of ``inputs`` is the positional argument order passed to
        # run_inference.
        submit_btn.click(
            fn=run_inference,
            inputs=[
                image_input,
                ref_dict,
                index_map,
                classifier_model,
                edgeface_model,
                algorithm,
                accelerator,
                resolution,
                similarity_threshold,
            ],
            outputs=output,
        )

        gr.Markdown(CONTENT_OUTTRO)
        gr.HTML(CONTENT_OUT_1)
        gr.Markdown(CONTENT_OUT_2)
    return demo
def main():
    """Build the SlimFace Gradio interface and launch it."""
    create_interface().launch()


# Only start the app when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()