File size: 3,945 Bytes
afb8be0
 
8daf03a
 
 
 
 
 
 
 
afb8be0
8daf03a
 
 
 
 
afb8be0
8daf03a
 
 
 
 
 
 
659182f
8daf03a
 
 
 
 
 
 
659182f
8daf03a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659182f
 
 
 
 
 
 
 
8daf03a
 
659182f
8daf03a
659182f
8daf03a
659182f
8daf03a
 
 
659182f
8daf03a
 
 
659182f
8daf03a
659182f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb8be0
 
8daf03a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import requests
import numpy as np
from PIL import Image
import io
import base64
import logging
import sys
import traceback
import os

# Configure logging: records at DEBUG and above go to stdout in a
# "time - logger - level - message" format. (basicConfig is a no-op if the
# root logger already has handlers.)
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("faceforge_ui")

# API configuration: base URL of the generation backend, overridable via the
# API_URL environment variable. A trailing slash is stripped because request
# URLs are built as f"{API_URL}/generate"; "http://host/" would otherwise
# yield "http://host//generate".
API_URL = os.environ.get("API_URL", "http://localhost:8000").rstrip("/")
logger.info("Using API URL: %s", API_URL)

def _channel(coord):
    """Map a player coordinate to an 8-bit colour channel value.

    ``int(coord * 128) + 128`` maps the slider range [-1.0, 1.0] to [0, 256];
    256 is not a valid uint8 channel, so the result is clamped to [0, 255].
    """
    return max(0, min(255, int(coord * 128) + 128))


def _decode_image(img_bytes, player_x, player_y):
    """Convert the base64-decoded ``image`` payload into a PIL image.

    Tries, in order:
      1. an encoded image file that PIL can identify (e.g. PNG or JPEG);
      2. a raw 256x256 RGB pixel buffer;
      3. a solid-colour 256x256 placeholder whose red/blue channels follow
         the player position, so the UI still shows something when the
         backend returns data that cannot be decoded.
    """
    try:
        img = Image.open(io.BytesIO(img_bytes))
        img.load()  # Image.open is lazy; force decoding so bad data fails here
        return img
    except OSError:
        # PIL.UnidentifiedImageError and truncated-file errors are OSErrors.
        logger.debug("Payload is not an encoded image file; trying raw RGB")
    try:
        return Image.frombytes("RGB", (256, 256), img_bytes)
    except ValueError:
        # frombytes raises ValueError when the buffer is too short for 256x256 RGB.
        logger.warning(
            "Could not decode API image (%d bytes); using placeholder",
            len(img_bytes),
        )
    return Image.new("RGB", (256, 256), (_channel(player_x), 100, _channel(player_y)))


def generate_image(prompts, mode, player_x, player_y):
    """Generate an image based on prompts and player position.

    Args:
        prompts: Comma-separated prompt string; blank entries are ignored.
        mode: Sampling mode, forwarded to the API unchanged.
        player_x: Player X coordinate (the UI slider spans -1.0 to 1.0).
        player_y: Player Y coordinate (the UI slider spans -1.0 to 1.0).

    Returns:
        ``(image, status)``: ``image`` is a PIL image, or ``None`` on failure.
        ``status`` is a human-readable message for the UI. Any ``Exception``
        is reported through ``status`` rather than raised.
    """
    try:
        logger.debug(
            "Generating image with prompts: %s, mode: %s, position: (%s, %s)",
            prompts, mode, player_x, player_y,
        )

        prompt_list = [p.strip() for p in prompts.split(",") if p.strip()]
        if not prompt_list:
            logger.warning("No valid prompts provided")
            return None, "No valid prompts provided"

        x, y = float(player_x), float(player_y)
        req = {"prompts": prompt_list, "mode": mode, "player_pos": [x, y]}

        try:
            resp = requests.post(f"{API_URL}/generate", json=req, timeout=30)
            if not resp.ok:
                logger.error("API returned HTTP %s", resp.status_code)
                return None, f"API error: {resp.status_code}"
            # Kept inside this try on purpose: newer requests versions (2.27+)
            # raise their own JSONDecodeError, a RequestException subclass.
            data = resp.json()
        except requests.exceptions.RequestException as e:
            logger.error("Request failed: %s", e)
            return None, f"Request failed: {e}"

        # isinstance guard: for a str/list body, `"image" in data` would be a
        # substring/element test and data["image"] would then raise TypeError.
        if not isinstance(data, dict) or "image" not in data:
            return None, "No image in API response"

        try:
            img_bytes = base64.b64decode(data["image"])
            img = _decode_image(img_bytes, x, y)
        except Exception as e:
            logger.exception("Error decoding image")
            return None, f"Error decoding image: {e}"
        return img, "Image generated successfully"

    except Exception as e:
        # Top-level boundary of a UI callback: report the error, keep the app alive.
        logger.exception("Unexpected error: %s", e)
        return None, f"Error: {e}"

# A deliberately minimal gr.Interface (kept simple to avoid schema issues).
# Inputs map positionally onto generate_image(prompts, mode, player_x,
# player_y); outputs map onto its (image, status) return tuple.
_ui_inputs = [
    gr.Textbox(
        value="A photo of a cat, A photo of a dog",
        label="Prompts (comma-separated)",
    ),
    gr.Radio(["distance", "circle"], label="Sampling Mode", value="distance"),
    gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Player X"),
    gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Player Y"),
]
_ui_outputs = [
    gr.Image(type="pil", label="Generated Image"),
    gr.Textbox(label="Status"),
]
demo = gr.Interface(
    fn=generate_image,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="FaceForge Latent Space Explorer",
    description="Interactively explore and edit faces in latent space.",
    allow_flagging="never",
)

if __name__ == "__main__":
    logger.info("Starting Gradio app")
    try:
        # Check if we're running in Hugging Face Spaces
        if "SPACE_ID" in os.environ:
            logger.info("Running in Hugging Face Space")
            demo.launch(server_name="0.0.0.0", share=False)
        else:
            logger.info("Running locally")
            demo.launch(server_name="0.0.0.0", share=False)
    except Exception as e:
        logger.critical(f"Failed to launch Gradio app: {e}")
        logger.debug(traceback.format_exc())