File size: 7,127 Bytes
64066bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719b0c2
64066bb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import gradio as gr
from gradio_client import Client
import os
import requests
from datetime import datetime
import time

# Configuration
# BACKEND_URL: base URL of the GPU inference server; empty string means "not configured".
BACKEND_URL = os.environ.get("BACKEND_URL", "").strip()
CHECK_INTERVAL = 60  # Check backend status every 60 seconds

class BackendConnection:
    """Holds a cached gradio_client.Client for the GPU backend and rate-limits health probes.

    State is mutated in place by check_backend(); callers read `client`,
    `is_connected`, and `error_message` after a check.
    """

    def __init__(self):
        self.client = None        # gradio_client.Client once a probe succeeds
        self.last_check = 0       # epoch seconds of the most recent probe
        self.is_connected = False
        self.error_message = ""

    def check_backend(self):
        """Return True if the backend is reachable, probing at most every CHECK_INTERVAL seconds.

        Side effects: refreshes self.client, self.is_connected, and
        self.error_message to reflect the latest probe result.
        """
        current_time = time.time()

        # Serve the cached verdict while it is fresh and we still hold a client.
        # A disconnected state is re-probed on every call so recovery is fast.
        if current_time - self.last_check < CHECK_INTERVAL and self.client:
            return self.is_connected

        self.last_check = current_time

        if not BACKEND_URL:
            self.error_message = "Backend URL not configured. Please set the BACKEND_URL secret."
            self.is_connected = False
            return False

        try:
            # Probe the backend's health endpoint before paying for client setup.
            response = requests.get(f"{BACKEND_URL}/health", timeout=5)
            if response.status_code == 200:
                # Healthy: (re)create the client for subsequent predict() calls.
                self.client = Client(BACKEND_URL)
                self.is_connected = True
                self.error_message = ""
                return True
            # Fix: a non-200 response previously left stale connection state
            # (is_connected/client/error_message) untouched. Mark disconnected.
            self.error_message = f"GPU server health check failed with HTTP {response.status_code}"
            self.is_connected = False
            self.client = None
        except Exception as e:
            self.error_message = f"Cannot connect to GPU server: {str(e)}"
            self.is_connected = False
            self.client = None

        return False

# Global connection manager
# Shared by the request handler and the build-time status check below.
backend = BackendConnection()

def process_with_backend(file_obj, webcam_img, model_type, conf_thresh, max_dets, task_type):
    """Proxy one inference request to the remote GPU backend's /predict API.

    Returns the backend's 6-tuple of Gradio updates on success, or a local
    6-tuple that hides all media components and carries an error string.
    """

    def _error_outputs(message):
        # Hide every media slot and surface `message` in the status position.
        return (
            gr.update(visible=False),             # raw_img_file
            gr.update(visible=False),             # raw_vid_file
            gr.update(visible=False),             # raw_img_webcam
            gr.update(value=None, visible=True),  # processed_img
            gr.update(visible=False),             # processed_vid
            message,
        )

    if not backend.check_backend():
        return _error_outputs(f"❌ GPU Server Offline: {backend.error_message}")

    try:
        # Forward all inputs untouched; the backend owns the actual inference.
        return backend.client.predict(
            file_obj,
            webcam_img,
            model_type,
            conf_thresh,
            max_dets,
            task_type,
            api_name="/predict",
        )
    except Exception as e:
        return _error_outputs(f"❌ Processing Error: {str(e)}")

# Create the interface.
# NOTE(review): this Blocks body runs once at import time, so the page layout is
# frozen to whatever backend.check_backend() returned at startup; a server that
# comes online later requires a restart to show the full UI.
# NOTE(review): several string literals below (🐡, 🟑, 🟒, πŸš€, πŸ”΄) look like
# UTF-8 emoji misdecoded through another codepage — confirm intended glyphs
# before changing, as they are runtime strings.
with gr.Blocks(theme=gr.themes.Soft(), title="PrimateFace Demo") as demo:
    # Header with status
    with gr.Row():
        gr.Markdown("# 🐡 PrimateFace Detection, Pose Estimation, and Gaze Demo")
        status_indicator = gr.Markdown("🟑 Checking GPU server status...", elem_id="status")
    
    gr.Markdown("""
    This demo showcases state-of-the-art primate face analysis including detection, 
    68-point facial landmarks, and gaze estimation across multiple primate species.
    """)
    
    # Check initial status
    if backend.check_backend():
        # NOTE(review): assigning .value during Blocks construction presumably
        # sets the component's initial value — confirm against the Gradio
        # version in use; post-launch updates would need an event handler.
        status_indicator.value = "🟒 GPU Server: Online"
        
        # Load the actual interface from backend
        try:
            # Get the interface configuration from backend
            # Two-column layout: inputs (file upload / webcam) on the left,
            # processed results on the right.
            with gr.Row():
                with gr.Column(scale=1):
                    # Input section
                    with gr.Tabs():
                        with gr.TabItem("Upload File"):
                            input_file = gr.File(label="Upload Image or Video", file_types=["image", "video"])
                            # Previews start hidden; process_with_backend toggles visibility.
                            display_raw_image_file = gr.Image(label="Preview", visible=False)
                            display_raw_video_file = gr.Video(label="Preview", visible=False)
                        
                        with gr.TabItem("Webcam"):
                            gr.Markdown("Click on the feed or press Enter to capture")
                            input_webcam = gr.Image(sources=["webcam"], type="pil")
                            display_raw_image_webcam = gr.Image(label="Captured", visible=False)
                    
                    clear_button = gr.Button("Clear All")
                
                with gr.Column(scale=1):
                    # Output section
                    gr.Markdown("### Processed Output")
                    display_processed_image = gr.Image(label="Result", visible=False)
                    display_processed_video = gr.Video(label="Result", visible=False)
                    error_message = gr.Markdown(visible=False)
            
            # Controls
            submit_button = gr.Button("πŸš€ Detect Faces", variant="primary", size="lg")
            
            with gr.Accordion("Advanced Settings", open=False):
                # Single-choice model selector kept hidden; only MMDetection is offered.
                model_choice = gr.Radio(
                    choices=["MMDetection"], 
                    value="MMDetection",
                    label="Model",
                    visible=False
                )
                task_type = gr.Dropdown(
                    choices=["Face Detection", "Face Pose Estimation", "Gaze Estimation [experimental]"],
                    value="Face Detection",
                    label="Task"
                )
                conf_threshold = gr.Slider(0.05, 0.95, 0.25, step=0.05, label="Confidence Threshold")
                max_detections = gr.Slider(1, 10, 3, step=1, label="Max Detections")
            
            # Wire up the interface
            # Output order must match the 6-tuple returned by process_with_backend.
            submit_button.click(
                process_with_backend,
                inputs=[input_file, input_webcam, model_choice, conf_threshold, max_detections, task_type],
                outputs=[display_raw_image_file, display_raw_video_file, display_raw_image_webcam, 
                        display_processed_image, display_processed_video, error_message]
            )
            
        except Exception as e:
            gr.Markdown(f"❌ Failed to load interface: {str(e)}")
    
    else:
        # Offline fallback page: status banner plus explanatory text, no controls.
        status_indicator.value = "πŸ”΄ GPU Server: Offline"
        gr.Markdown(f"""
        ### πŸ”΄ GPU Server is Currently Offline
        
        {backend.error_message}
        
        The demo requires a GPU server for processing. Please check back later.
        
        **Note:** The server may be temporarily unavailable for maintenance.
        """)
    
    # Add info section regardless of status
    gr.Markdown("""
    ---
    ### Technical Details
    - **Detection**: MMDetection (Cascade R-CNN R101-FPN)
    - **Pose**: MMPose (HRNet-W18) with 68 facial landmarks  
    - **Gaze**: Gazelle (DINOv2-based) experimental gaze estimation
    - **GPU**: NVIDIA GPU with CUDA support
    
    ### About
    Created by Felipe Parodi | [Project GitHub](https://github.com/KordingLab/PrimateFace) | [Personal](https://github.com/felipe-parodi)
    """)

# Entry point: serve the Gradio app when executed as a script.
if __name__ == "__main__":
    demo.launch()