File size: 7,847 Bytes
afb8be0
 
8daf03a
 
 
 
 
 
 
 
a8da4e0
afb8be0
8daf03a
 
 
 
 
afb8be0
8daf03a
 
a8da4e0
 
 
 
8daf03a
0f25023
 
 
8daf03a
 
 
659182f
8daf03a
 
 
 
 
 
 
659182f
8daf03a
 
 
 
 
 
 
 
a8da4e0
 
8daf03a
 
0f25023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8da4e0
8daf03a
 
a8da4e0
 
 
8daf03a
a8da4e0
 
 
 
 
659182f
a8da4e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8daf03a
a8da4e0
659182f
8daf03a
 
 
0f25023
 
 
 
8daf03a
 
 
a8da4e0
659182f
8daf03a
659182f
a8da4e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb8be0
a8da4e0
afb8be0
a8da4e0
8daf03a
a8da4e0
 
 
 
 
 
8daf03a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import gradio as gr
import requests
import numpy as np
from PIL import Image
import io
import base64
import logging
import sys
import traceback
import os
import json

# Configure logging.
# The level defaults to DEBUG (the historical hard-coded value) but can be
# changed without editing code by setting the LOG_LEVEL environment variable
# to a standard level name such as INFO or WARNING.
_LOG_LEVEL_NAME = os.environ.get("LOG_LEVEL", "DEBUG").upper()
_LOG_LEVEL = getattr(logging, _LOG_LEVEL_NAME, None)
_LOG_LEVEL_VALID = isinstance(_LOG_LEVEL, int)
if not _LOG_LEVEL_VALID:
    # Unknown names (or non-level attributes of the logging module) fall back to DEBUG.
    _LOG_LEVEL = logging.DEBUG

logging.basicConfig(
    level=_LOG_LEVEL,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("faceforge_ui")
if not _LOG_LEVEL_VALID:
    logger.warning(f"Unknown LOG_LEVEL {_LOG_LEVEL_NAME!r}; falling back to DEBUG")

# Apply the same verbosity to Gradio's internal loggers (DEBUG by default,
# which helps when debugging the UI/API integration).
logging.getLogger("gradio").setLevel(_LOG_LEVEL)
logging.getLogger("gradio_client").setLevel(_LOG_LEVEL)

# API configuration.
# The default "/api" is a relative path intended for deployments such as HF
# Spaces, where the UI and the API are served by the same server. `requests`
# rejects scheme-less URLs, so a relative value has to be turned into an
# absolute URL before any request is sent. For local development with separate
# servers, set API_URL to an absolute URL, e.g. http://localhost:8000.
API_URL = os.environ.get("API_URL", "/api")
logger.info(f"Using API URL: {API_URL}")

def _placeholder_image(player_x, player_y):
    """Build a 256x256 solid-colour stand-in image derived from the player position.

    Used whenever no real image is available: mock mode, an undecodable API
    payload, or an unreachable API. The red channel follows ``player_x`` and
    the blue channel follows ``player_y``. Each maps roughly [-1, 1] onto
    [0, 255] and is clamped explicitly, so positions at the slider limits stay
    in the valid byte range.
    """
    def channel(value):
        return max(0, min(255, int(float(value) * 128) + 128))

    return Image.new("RGB", (256, 256), (channel(player_x), 100, channel(player_y)))


def _api_base_url():
    """Return the API base URL in absolute form, without a trailing slash.

    ``requests`` cannot send to a scheme-less URL (it raises ``MissingSchema``),
    so a relative ``API_URL`` such as the default "/api" is joined onto a
    loopback origin. That origin is ``API_HOST`` if set, otherwise
    ``http://127.0.0.1:<GRADIO_SERVER_PORT, default 7860>``.
    """
    if "://" in API_URL:
        return API_URL.rstrip("/")
    # NOTE(review): assumes the API is served on this machine at the same port
    # as the Gradio UI ("both UI and API run on the same server"); set API_HOST
    # or an absolute API_URL if that does not hold for a deployment.
    host = os.environ.get("API_HOST") or (
        f"http://127.0.0.1:{os.environ.get('GRADIO_SERVER_PORT', '7860')}"
    )
    path = API_URL.strip("/")
    return f"{host.rstrip('/')}/{path}" if path else host.rstrip("/")


def generate_image(prompts, mode, player_x, player_y):
    """Generate an image based on prompts and player position.

    Args:
        prompts: Comma-separated prompt string; blank entries are ignored.
            ``None`` is treated as an empty string.
        mode: Sampling mode, forwarded to the API unchanged (the UI offers
            "distance" and "circle").
        player_x: Horizontal player position (the UI slider spans -1..1).
        player_y: Vertical player position (the UI slider spans -1..1).

    Returns:
        A ``(image, status)`` tuple matching the Gradio outputs. ``image`` is a
        PIL image or ``None``; ``status`` is a human-readable message. A local
        placeholder image is returned in mock mode, when the API cannot be
        reached, or when its image payload cannot be decoded. Errors are
        reported through ``status`` and never raised.
    """
    try:
        logger.debug(f"Generating image with prompts: {prompts}, mode: {mode}, position: ({player_x}, {player_y})")

        prompt_list = [p.strip() for p in (prompts or "").split(",") if p.strip()]
        if not prompt_list:
            logger.warning("No valid prompts provided")
            return None, "No valid prompts provided"

        req = {
            "prompts": prompt_list,
            "mode": mode,
            "player_pos": [float(player_x), float(player_y)]
        }
        logger.debug(f"Request payload: {json.dumps(req)}")

        # For debugging/testing, skip the API entirely when MOCK_API=true.
        if API_URL == "/api" and os.environ.get("MOCK_API", "false").lower() == "true":
            logger.debug("Using mock API response")
            return _placeholder_image(player_x, player_y), "Image generated using mock API"

        url = f"{_api_base_url()}/generate"
        logger.debug(f"Making request to: {url}")
        try:
            resp = requests.post(url, json=req, timeout=30)
        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed: {e}")
            logger.debug("Falling back to test image")
            return _placeholder_image(player_x, player_y), f"API connection failed (using test image): {str(e)}"
        logger.debug(f"API response status: {resp.status_code}")

        if not resp.ok:
            logger.error(f"API error: {resp.status_code}, {resp.text[:500]}")
            return None, f"API error: {resp.status_code}"

        # Extract the base64 image payload; any malformed response lands here.
        try:
            data = resp.json()
            logger.debug(f"API response structure: {list(data.keys())}")
            if "image" not in data:
                logger.warning("No image field in API response")
                return None, "No image in API response"
            img_b64 = data["image"]
            logger.debug(f"Image base64 length: {len(img_b64)}")
            img_bytes = base64.b64decode(img_b64)
        except Exception as e:
            logger.error(f"Error parsing API response: {e}")
            logger.debug(f"Raw response: {resp.text[:500]}")
            return None, f"Error parsing API response: {str(e)}"

        try:
            img = Image.open(io.BytesIO(img_bytes))
            # Image.open only reads the header; force the full decode here so
            # corrupt data is caught now rather than later inside Gradio.
            img.load()
        except Exception as e:
            logger.error(f"Failed to decode image from bytes: {e}, creating test image")
            # Report the fallback honestly instead of claiming success.
            return _placeholder_image(player_x, player_y), f"API image could not be decoded (using test image): {str(e)}"

        logger.debug(f"Image decoded successfully: {img.size} {img.mode}")
        return img, "Image generated successfully"

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        logger.debug(traceback.format_exc())
        return None, f"Error: {str(e)}"

def create_demo():
    """Build the FaceForge Latent Space Explorer UI and return it as ``gr.Blocks``.

    Only basic built-in components are used (no custom schemas) to avoid
    schema issues. The left column holds the inputs: the prompt text, the
    sampling mode, the X/Y position sliders and a Generate button. A wider
    right column shows the generated image and a status line. Clicking
    Generate calls ``generate_image`` with the four inputs.
    """
    def position_slider(label):
        # Both position sliders share the same [-1, 1] range, centre and step.
        return gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.1, label=label)

    with gr.Blocks(title="FaceForge Latent Space Explorer") as demo:
        gr.Markdown("# FaceForge Latent Space Explorer")
        gr.Markdown("Interactively explore and edit faces in latent space.")

        with gr.Row():
            # Left side: controls.
            with gr.Column(scale=3):
                prompts_box = gr.Textbox(
                    lines=2,
                    value="A photo of a cat, A photo of a dog",
                    label="Prompts (comma-separated)",
                )
                sampling_mode = gr.Radio(
                    label="Sampling Mode",
                    value="distance",
                    choices=["distance", "circle"],
                )
                x_slider = position_slider("Player X")
                y_slider = position_slider("Player Y")
                run_button = gr.Button("Generate")

            # Right side: results.
            with gr.Column(scale=5):
                image_view = gr.Image(label="Generated Image")
                status_box = gr.Textbox(label="Status")

        run_button.click(
            fn=generate_image,
            inputs=[prompts_box, sampling_mode, x_slider, y_slider],
            outputs=[image_view, status_box],
        )

    return demo

# Only start if this file is run directly, not when imported
if __name__ == "__main__":
    logger.info("Starting Gradio app directly from app.py")
    try:
        # Print Gradio version for debugging
        logger.info(f"Gradio version: {gr.__version__}")

        demo = create_demo()

        # Hugging Face Spaces sets SPACE_ID. Launch settings are currently the
        # same in both environments, so only the log line differs.
        if "SPACE_ID" in os.environ:
            logger.info("Running in Hugging Face Space")
        else:
            logger.info("Running locally")
        demo.launch(server_name="0.0.0.0", share=False)
    except Exception as e:
        logger.critical(f"Failed to launch Gradio app: {e}")
        logger.debug(traceback.format_exc())
        # Exit non-zero so the process supervisor sees the failure instead of
        # a clean exit.
        sys.exit(1)