Transcendental-Programmer
committed on
Commit
·
8daf03a
1
Parent(s):
7bcd76c
fix: bug fixes and improvements
Browse files- app.py +29 -0
- faceforge_api/main.py +159 -4
- faceforge_ui/app.py +148 -13
app.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import logging
|
4 |
+
|
5 |
+
# Configure logging: DEBUG and above is written to stdout with timestamps.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("faceforge_app")

# Put the directory containing this file first on the import path so the
# faceforge_ui import below resolves.
_project_root = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, _project_root)

try:
    logger.info("Starting FaceForge app")

    # The import happens at module level (not only under the __main__ guard)
    # so that `demo` is available to anything importing this module.
    from faceforge_ui.app import demo

    # Only start the server when this file is executed directly.
    if __name__ == "__main__":
        logger.info("Launching Gradio interface")
        demo.launch(server_name="0.0.0.0")
except Exception as e:
    # Log the failure (summary at CRITICAL, traceback at DEBUG), then let the
    # exception propagate so the process still exits with an error.
    logger.critical(f"Failed to start app: {e}")
    import traceback
    logger.debug(traceback.format_exc())
    raise
|
faceforge_api/main.py
CHANGED
@@ -1,13 +1,168 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
from fastapi.responses import JSONResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
app = FastAPI()
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
@app.get("/")
|
7 |
def read_root():
|
|
|
8 |
return {"message": "FaceForge API is running"}
|
9 |
|
10 |
@app.post("/generate")
|
11 |
-
def generate_image():
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Request
|
2 |
from fastapi.responses import JSONResponse
|
3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
4 |
+
from pydantic import BaseModel
|
5 |
+
from typing import List, Optional
|
6 |
+
import numpy as np
|
7 |
+
import base64
|
8 |
+
import logging
|
9 |
+
import sys
|
10 |
+
import traceback
|
11 |
+
import io
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
from faceforge_core.latent_explorer import LatentSpaceExplorer
|
15 |
+
from faceforge_core.attribute_directions import LatentDirectionFinder
|
16 |
+
from faceforge_core.custom_loss import attribute_preserving_loss
|
17 |
+
|
18 |
+
# Configure logging: DEBUG and above is written to stdout with timestamps.
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("faceforge_api")
|
25 |
+
|
26 |
+
# --- Models for API ---
|
27 |
+
|
28 |
+
# A single labelled point: its text, plus an optional encoding vector and an
# optional position. Not referenced by any endpoint visible in this chunk.
class PointIn(BaseModel):
    text: str
    encoding: Optional[List[float]] = None  # omitted -> None
    xy_pos: Optional[List[float]] = None  # presumably [x, y]; TODO confirm
|
32 |
+
|
33 |
+
# Request body for POST /generate.
class GenerateRequest(BaseModel):
    prompts: List[str]  # one explorer point is added per prompt
    # positions[i] is used for prompts[i]; prompts without a matching entry
    # get no explicit position.
    positions: Optional[List[List[float]]] = None
    mode: str = "distance"  # forwarded unchanged to explorer.sample_encoding
    player_pos: Optional[List[float]] = None  # handler falls back to [0.0, 0.0]
|
38 |
+
|
39 |
+
# Request body for POST /manipulate; the handler computes
# encoding + alpha * direction.
class ManipulateRequest(BaseModel):
    encoding: List[float]  # starting vector
    direction: List[float]  # vector added to the encoding, scaled by alpha
    alpha: float  # scale factor applied to direction
|
43 |
+
|
44 |
+
# Request body for POST /attribute_direction.
class AttributeDirectionRequest(BaseModel):
    latents: List[List[float]]  # converted to a 2-D array by the handler
    # When given, the handler uses classifier-based direction finding;
    # when None it falls back to PCA.
    labels: Optional[List[int]] = None
    n_components: Optional[int] = 10  # only read on the PCA path
|
48 |
+
|
49 |
+
# --- FastAPI app ---

app = FastAPI()

# CORS: allow browser requests from any origin, with any method and header.
# Credentials are deliberately disabled: a wildcard origin must not be combined
# with credentialed requests (it would let any website send the user's cookies
# to this API), and none of the endpoints visible here read cookies or
# authorization headers. To enable credentials, list explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global explorer instance, shared by every request this process handles.
explorer = LatentSpaceExplorer()
|
64 |
+
|
65 |
+
# Error handling middleware
|
66 |
+
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Turn any exception escaping the downstream handlers into a JSON 500.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that passes the request on and returns its response.

    Returns:
        The downstream response, or a 500 ``JSONResponse`` if an exception
        was raised while producing it.
    """
    try:
        response = await call_next(request)
    except Exception as e:
        logger.error(f"Unhandled exception: {str(e)}")
        logger.debug(traceback.format_exc())
        # NOTE(review): the exception text is returned to the client.
        payload = {"detail": "Internal server error", "error": str(e)}
        return JSONResponse(status_code=500, content=payload)
    return response
|
77 |
+
|
78 |
# Root endpoint: returns a fixed message confirming the API process is up.
@app.get("/")
def read_root():
    logger.debug("Root endpoint called")
    return {"message": "FaceForge API is running"}
|
82 |
|
83 |
@app.post("/generate")
def generate_image(req: GenerateRequest):
    # Rebuilds the shared explorer's points from the request, samples an
    # encoding at the player position, and returns a base64-encoded PNG.
    # Both the per-prompt encodings and the image are random placeholders.
    # Any failure is logged and reported as HTTP 500.
    try:
        logger.debug(f"Generate image request: {req}")

        # Start from an empty point set on the shared explorer.
        explorer.points = []

        for i, prompt in enumerate(req.prompts):
            logger.debug(f"Processing prompt {i}: {prompt}")

            # Stub: random 512-dim vector standing in for a real encoding.
            encoding = np.random.randn(512)

            # Use the i-th position when one was supplied for this prompt.
            if req.positions and i < len(req.positions):
                xy_pos = req.positions[i]
            else:
                xy_pos = None
            logger.debug(f"Position for prompt {i}: {xy_pos}")

            explorer.add_point(prompt, encoding, xy_pos)

        # The player position defaults to the origin.
        player_pos = [0.0, 0.0] if req.player_pos is None else req.player_pos
        logger.debug(f"Player position: {player_pos}")

        logger.debug(f"Sampling with mode: {req.mode}")
        # The sampled encoding is not consumed yet; the image below is mock data.
        sampled = explorer.sample_encoding(tuple(player_pos), mode=req.mode)

        # Stub: random 256x256 RGB image.
        pixels = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)

        logger.debug("Converting image to base64")
        buffer = io.BytesIO()
        Image.fromarray(pixels).save(buffer, format="PNG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        logger.debug("Image generated successfully")
        return {"status": "success", "image": img_b64}

    except Exception as e:
        logger.error(f"Error in generate_image: {str(e)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
|
133 |
+
|
134 |
+
@app.post("/manipulate")
def manipulate(req: ManipulateRequest):
    """Shift an encoding along a direction: ``encoding + alpha * direction``.

    Args:
        req: Request carrying ``encoding``, ``direction`` and ``alpha``.

    Returns:
        dict: ``{"manipulated_encoding": [...]}`` with the shifted vector.

    Raises:
        HTTPException: 400 when ``encoding`` and ``direction`` differ in
            length; 500 for any other failure.
    """
    try:
        logger.debug(f"Manipulate request: {req}")

        # Reject mismatched vectors up front: NumPy would otherwise either
        # broadcast a length-1 direction silently or fail with a 500.
        if len(req.encoding) != len(req.direction):
            raise HTTPException(
                status_code=400,
                detail=(
                    "encoding and direction must have the same length "
                    f"(got {len(req.encoding)} and {len(req.direction)})"
                ),
            )

        encoding = np.array(req.encoding)
        direction = np.array(req.direction)
        manipulated = encoding + req.alpha * direction
        logger.debug("Manipulation successful")
        return {"manipulated_encoding": manipulated.tolist()}
    except HTTPException:
        # Keep the intended status code; the catch-all below would turn it
        # into a 500.
        raise
    except Exception as e:
        logger.error(f"Error in manipulate: {str(e)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
|
147 |
+
|
148 |
+
@app.post("/attribute_direction")
def attribute_direction(req: AttributeDirectionRequest):
    # Finds direction(s) in the supplied latents: classifier-based when labels
    # are given, PCA otherwise. Any failure is logged and reported as HTTP 500.
    try:
        logger.debug(f"Attribute direction request: {req}")
        finder = LatentDirectionFinder(np.array(req.latents))

        # Without labels, fall back to PCA.
        if req.labels is None:
            logger.debug(f"Using PCA with {req.n_components} components")
            components, explained = finder.pca_direction(n_components=req.n_components)
            logger.debug("PCA completed successfully")
            return {
                "components": components.tolist(),
                "explained_variance": explained.tolist(),
            }

        logger.debug("Using classifier-based direction finding")
        direction = finder.classifier_direction(req.labels)
        logger.debug("Direction found successfully")
        return {"direction": direction.tolist()}
    except Exception as e:
        logger.error(f"Error in attribute_direction: {str(e)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
|
faceforge_ui/app.py
CHANGED
@@ -1,19 +1,154 @@
|
|
1 |
import gradio as gr
|
2 |
import requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
iface = gr.Interface(
|
11 |
-
fn=generate_image,
|
12 |
-
inputs=[],
|
13 |
-
outputs="image",
|
14 |
-
title="FaceForge Latent Space Explorer",
|
15 |
-
description="Interactively explore and edit faces in latent space."
|
16 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
if __name__ == "__main__":
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
6 |
+
import base64
|
7 |
+
import logging
|
8 |
+
import sys
|
9 |
+
import traceback
|
10 |
+
import os
|
11 |
|
12 |
+
# Configure logging: DEBUG and above is written to stdout with timestamps.
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("faceforge_ui")

# API configuration: base URL of the backend, overridable via the API_URL
# environment variable.
API_URL = os.getenv("API_URL", "http://localhost:8000")
logger.info(f"Using API URL: {API_URL}")
|
23 |
+
|
24 |
+
def generate_image(prompts, mode, player_x, player_y):
    """
    Generate an image based on prompts and player position.

    Sends the parsed prompts, sampling mode and player position to the
    ``/generate`` endpoint at ``API_URL`` and decodes the base64-encoded
    image file found under the ``"image"`` key of the JSON response.

    Args:
        prompts: Comma-separated list of prompts
        mode: Sampling mode ('distance' or 'circle')
        player_x: X-coordinate of player position
        player_y: Y-coordinate of player position

    Returns:
        PIL.Image or None: Generated image or None if generation failed
        (no usable prompt, request error, non-OK status, missing or
        undecodable image). Failures are logged, never raised.
    """
    try:
        logger.debug(f"Generating image with prompts: {prompts}, mode: {mode}, position: ({player_x}, {player_y})")

        # Parse prompts, dropping empty entries such as those from "a,,b".
        prompt_list = [p.strip() for p in prompts.split(",") if p.strip()]
        if not prompt_list:
            logger.warning("No valid prompts provided")
            return None

        logger.debug(f"Parsed prompts: {prompt_list}")

        # Prepare request
        req = {
            "prompts": prompt_list,
            "mode": mode,
            "player_pos": [float(player_x), float(player_y)],
        }

        logger.debug(f"Sending request to API: {req}")

        # Make API call
        try:
            resp = requests.post(f"{API_URL}/generate", json=req, timeout=30)
            logger.debug(f"API response status: {resp.status_code}")

            if not resp.ok:
                logger.error(f"API error: {resp.status_code} - {resp.text}")
                return None

            data = resp.json()
            logger.debug("Successfully received API response")

            if "image" not in data:
                logger.warning("No image in API response")
                return None

            try:
                # The payload is an encoded image file (e.g. PNG), not raw
                # pixel data, so it must be parsed with Image.open rather
                # than Image.frombytes. convert() forces the lazy decode to
                # happen here, where decoding errors are handled.
                img_bytes = base64.b64decode(data["image"])
                img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                logger.debug("Successfully decoded image")
                return img
            except Exception as e:
                logger.error(f"Error decoding image: {e}")
                logger.debug(traceback.format_exc())
                return None

        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed: {e}")
            logger.debug(traceback.format_exc())
            return None

    except Exception as e:
        # Last-resort guard so a failure never propagates into the UI callback.
        logger.error(f"Unexpected error: {e}")
        logger.debug(traceback.format_exc())
        return None
|
94 |
+
|
95 |
+
# Create Gradio interface: inputs in the left column, outputs in the right.
logger.info("Initializing Gradio interface")
with gr.Blocks() as demo:
    gr.Markdown("# FaceForge Latent Space Explorer")

    with gr.Row():
        # Input controls.
        with gr.Column():
            prompts = gr.Textbox(
                value="A photo of a cat, A photo of a dog",
                label="Prompts (comma-separated)",
                info="Enter prompts separated by commas",
            )
            mode = gr.Radio(
                choices=["distance", "circle"],
                value="distance",
                label="Sampling Mode",
                info="Choose how to sample the latent space",
            )
            player_x = gr.Slider(-1.0, 1.0, value=0.0, label="Player X")
            player_y = gr.Slider(-1.0, 1.0, value=0.0, label="Player Y")
            btn = gr.Button("Generate")

        # Outputs: the image and a read-only status line.
        with gr.Column():
            img = gr.Image(label="Generated Image")
            status = gr.Textbox(label="Status", interactive=False)

    def on_generate_click(prompts, mode, player_x, player_y):
        """Run generate_image and return ``[image_or_None, status_message]``."""
        try:
            logger.info("Generate button clicked")
            result = generate_image(prompts, mode, player_x, player_y)
            if result is None:
                return [None, "Failed to generate image. Check logs for details."]
            return [result, "Image generated successfully"]
        except Exception as e:
            logger.error(f"Error in generate button handler: {e}")
            logger.debug(traceback.format_exc())
            return [None, f"Error: {str(e)}"]

    btn.click(
        fn=on_generate_click,
        inputs=[prompts, mode, player_x, player_y],
        outputs=[img, status],
    )

    # Show an initial status message when the page loads.
    demo.load(lambda: "Ready to generate images", outputs=status)
|
141 |
|
142 |
if __name__ == "__main__":
    logger.info("Starting Gradio app")
    try:
        # Hugging Face Spaces is detected via the SPACE_ID environment
        # variable. Only the log message differs; the launch arguments are
        # the same in both environments.
        if "SPACE_ID" in os.environ:
            logger.info("Running in Hugging Face Space")
        else:
            logger.info("Running locally")
        demo.launch(server_name="0.0.0.0", share=False)
    except Exception as e:
        # Launch failures are logged but not re-raised.
        logger.critical(f"Failed to launch Gradio app: {e}")
        logger.debug(traceback.format_exc())
|