File size: 4,733 Bytes
b7f710c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20cf96a
b7f710c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import html
import numbers
import os
import sys
import tempfile

from PIL import Image

# Append the path to the inference script's directory
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src', 'slimface', 'inference')))

from end2end_inference import inference_and_confirm

def _resolve_path(value):
    """Return ``value.name`` if the object has a ``name`` attribute, else ``value`` unchanged."""
    return value.name if hasattr(value, 'name') else value


def _error_html(message):
    """Wrap an already-safe message in the error-message container."""
    return f'<div class="error-message">{message}</div>'


def _format_score(value):
    """Format a real number with four decimals; anything else becomes 'N/A'."""
    # numbers.Real is a superset of (int, float); other types that register
    # themselves as Real are accepted too instead of being shown as N/A.
    if isinstance(value, numbers.Real):
        return f'{float(value):.4f}'
    return 'N/A'


def _render_result_card(idx, result):
    """Build the HTML card for one inference result (a mapping read via ``.get``)."""
    person_name = html.escape(str(result.get('predicted_class', 'N/A')))
    confidence_str = _format_score(result.get('confidence', 'N/A'))
    similarity_str = _format_score(result.get('similarity', 'N/A'))

    confirmed = result.get('confirmed', 'N/A')
    if confirmed is True:
        confirmed_class = 'confirmed-true'
    elif confirmed is False:
        confirmed_class = 'confirmed-false'
    else:
        confirmed_class = ''
    confirmed_str = html.escape(str(confirmed)) if confirmed is not None else 'N/A'

    return ''.join([
        '<div class="result-card">',
        f'<h3 class="result-title">Result {idx}</h3>',
        f'<div class="result-item"><span class="label">Person Name</span><span class="value">{person_name}</span></div>',
        f'<div class="result-item"><span class="label">Confidence</span><span class="value">{confidence_str}</span></div>',
        f'<div class="result-item"><span class="label">Similarity with<br> Reference Image</span><span class="value">{similarity_str}</span></div>',
        f'<div class="result-item"><span class="label">Confirmed Person</span><span class="value {confirmed_class}">{confirmed_str}</span></div>',
        '</div>',
    ])


def _render_results_html(results):
    """Build the full results container, or an inline error when ``results`` is empty."""
    parts = [
        '<div class="results-container">',
        '<h2 class="result-title">Inference Results</h2>',
    ]
    if not results:
        parts.append('<div class="error-message">No results returned from inference.</div>')
    else:
        parts.extend(_render_result_card(idx, result) for idx, result in enumerate(results, 1))
    parts.append('</div>')
    return ''.join(parts)


def run_inference(image, reference_dict_path, index_to_class_mapping_path, model_path,
                  edgeface_model_path="ckpts/idiap/edgeface_base.pt",
                  algorithm="yolo", accelerator="auto", resolution=224, similarity_threshold=0.6):
    """Run end-to-end inference on an uploaded image and return the result as HTML.

    The image is written to a uniquely named temporary JPEG under ``temp_data``
    next to this file, handed to ``inference_and_confirm`` through an
    argparse-style namespace, and removed again before returning.

    Args:
        image: Object exposing ``save(path)`` (a PIL image is expected). ``None``
            yields an error message.
        reference_dict_path: Path, or object with a ``name`` attribute holding the path.
        index_to_class_mapping_path: Path, or object with a ``name`` attribute.
        model_path: Path, or object with a ``name`` attribute.
        edgeface_model_path: Path, or object with a ``name`` attribute.
        algorithm: Forwarded unchanged as ``args.algorithm``.
        accelerator: Forwarded unchanged as ``args.accelerator``.
        resolution: Forwarded unchanged as ``args.resolution``.
        similarity_threshold: Forwarded unchanged as ``args.similarity_threshold``.

    Returns:
        An HTML string: either the rendered results or a single
        ``error-message`` div. Failures while saving the image or running
        inference are reported in the returned HTML rather than raised.
    """
    # Validate image input
    if image is None:
        return _error_html('Error: No image provided. Please upload an image.')

    args = argparse.Namespace(
        unknown_image_path=None,  # filled in once the temporary file exists
        reference_dict_path=_resolve_path(reference_dict_path),
        index_to_class_mapping_path=_resolve_path(index_to_class_mapping_path),
        model_path=_resolve_path(model_path),
        edgeface_model_path=_resolve_path(edgeface_model_path),
        algorithm=algorithm,
        accelerator=accelerator,
        resolution=resolution,
        similarity_threshold=similarity_threshold,
    )

    # Validate required files before touching the disk so that a rejected
    # request never leaves a temporary image behind.
    if not all([args.reference_dict_path, args.index_to_class_mapping_path, args.model_path]):
        return _error_html(
            'Error: Please provide all required files '
            '(reference dict, index-to-class mapping, and model).'
        )

    temp_image_path = None
    try:
        try:
            temp_dir = os.path.join(os.path.dirname(__file__), "temp_data")
            os.makedirs(temp_dir, exist_ok=True)
            # A unique name per call keeps concurrent requests from
            # overwriting or deleting each other's image.
            fd, temp_image_path = tempfile.mkstemp(
                prefix="temp_image_", suffix=".jpg", dir=temp_dir
            )
            os.close(fd)

            # JPEG cannot store alpha or palette modes; convert those first.
            if getattr(image, 'mode', 'RGB') not in ('RGB', 'L') and hasattr(image, 'convert'):
                image = image.convert('RGB')
            image.save(temp_image_path)
        except Exception as e:
            return _error_html(f'Error saving image: {html.escape(str(e))}')

        args.unknown_image_path = temp_image_path

        # Call the inference function from end2end_inference.py
        results = inference_and_confirm(args)
        return _render_results_html(results)

    except Exception as e:
        # Boundary handler for the UI: surface the failure as HTML.
        return _error_html(f'Error during inference: {html.escape(str(e))}')

    finally:
        # Clean up temporary image
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.remove(temp_image_path)