Transcendental-Programmer committed on
Commit
8daf03a
·
1 Parent(s): 7bcd76c

fix: bug fixes and improvements

Browse files
Files changed (3) hide show
  1. app.py +29 -0
  2. faceforge_api/main.py +159 -4
  3. faceforge_ui/app.py +148 -13
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import logging
4
+
5
+ # Configure logging
6
# Route every log record (DEBUG and above) to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger("faceforge_app")

# Put the directory containing this file first on sys.path so the
# `faceforge_ui` package import below resolves.
_PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, _PROJECT_ROOT)

try:
    logger.info("Starting FaceForge app")
    # The Gradio Blocks object is built at import time by the UI module.
    from faceforge_ui.app import demo

    # Only start the server when this file is executed as a script.
    if __name__ == "__main__":
        logger.info("Launching Gradio interface")
        demo.launch(server_name="0.0.0.0")
except Exception as exc:
    # Log the failure (traceback at DEBUG level), then let it propagate.
    logger.critical(f"Failed to start app: {exc}")
    import traceback
    logger.debug(traceback.format_exc())
    raise
faceforge_api/main.py CHANGED
@@ -1,13 +1,168 @@
1
- from fastapi import FastAPI
2
  from fastapi.responses import JSONResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  app = FastAPI()
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  @app.get("/")
7
  def read_root():
 
8
  return {"message": "FaceForge API is running"}
9
 
10
  @app.post("/generate")
11
- def generate_image():
12
- # Placeholder for image generation logic
13
- return JSONResponse(content={"status": "success", "image": None})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
  from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
+ from typing import List, Optional
6
+ import numpy as np
7
+ import base64
8
+ import logging
9
+ import sys
10
+ import traceback
11
+ import io
12
+ from PIL import Image
13
+
14
+ from faceforge_core.latent_explorer import LatentSpaceExplorer
15
+ from faceforge_core.attribute_directions import LatentDirectionFinder
16
+ from faceforge_core.custom_loss import attribute_preserving_loss
17
+
18
+ # Configure logging
19
# Route every log record (DEBUG and above) to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger("faceforge_api")
25
+
26
+ # --- Models for API ---
27
+
28
class PointIn(BaseModel):
    # Request-side description of a single point: its text plus an
    # optional encoding vector and an optional position.
    # NOTE(review): not referenced by any endpoint visible in this file.

    # Text associated with the point (required).
    text: str
    # Optional encoding vector for the point.
    encoding: Optional[List[float]] = None
    # Optional position; presumably [x, y] - confirm against callers.
    xy_pos: Optional[List[float]] = None
32
+
33
class GenerateRequest(BaseModel):
    # Body of POST /generate.

    # One entry per point to add to the explorer (required).
    prompts: List[str]
    # Optional positions, matched to `prompts` by index; prompts without
    # a matching entry are added with no position.
    positions: Optional[List[List[float]]] = None
    # Sampling mode, forwarded as-is to `explorer.sample_encoding`.
    mode: str = "distance"
    # Position to sample at; the handler uses [0.0, 0.0] when omitted.
    player_pos: Optional[List[float]] = None
38
+
39
class ManipulateRequest(BaseModel):
    # Body of POST /manipulate: the handler computes
    # `encoding + alpha * direction`.

    # Vector to be shifted.
    encoding: List[float]
    # Vector added to `encoding` after scaling by `alpha`.
    direction: List[float]
    # Scale factor applied to `direction`.
    alpha: float
43
+
44
class AttributeDirectionRequest(BaseModel):
    # Body of POST /attribute_direction.

    # Latent vectors handed to LatentDirectionFinder (one list per sample).
    latents: List[List[float]]
    # When present, the handler uses the classifier-based direction;
    # when omitted it falls back to PCA.
    labels: Optional[List[int]] = None
    # Number of PCA components requested on the PCA path.
    n_components: Optional[int] = 10
48
+
49
+ # --- FastAPI app ---
50
 
51
app = FastAPI()

# Allow cross-origin requests from any origin.
# Credentials are deliberately NOT allowed: a wildcard origin combined with
# allow_credentials=True is a CORS misconfiguration (browsers reject
# "Access-Control-Allow-Origin: *" on credentialed requests, and mirroring
# the Origin header instead would let any site make credentialed calls).
# This API uses no cookies or auth headers, so nothing depends on it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global explorer instance, shared by all requests handled by this process.
explorer = LatentSpaceExplorer()
64
+
65
+ # Error handling middleware
66
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Convert any exception escaping the downstream handlers into a JSON 500.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request down the chain.

    Returns:
        The downstream response, or a 500 JSONResponse carrying a generic
        detail message plus the exception text.
    """
    try:
        response = await call_next(request)
    except Exception as exc:
        logger.error(f"Unhandled exception: {str(exc)}")
        logger.debug(traceback.format_exc())
        # NOTE(review): the exception text is echoed back to the client.
        return JSONResponse(
            content={"detail": "Internal server error", "error": str(exc)},
            status_code=500,
        )
    return response
77
+
78
@app.get("/")
def read_root():
    """Return a fixed message indicating that the API is up."""
    logger.debug("Root endpoint called")
    payload = {"message": "FaceForge API is running"}
    return payload
82
 
83
@app.post("/generate")
def generate_image(req: GenerateRequest):
    """Rebuild the explorer's points from the request and return a mock image.

    Each prompt is added to the shared explorer with a random 512-element
    stub encoding and its position (if one was supplied at the same index).
    The explorer is then sampled at the player position, and a random
    256x256 RGB image is returned as a base64-encoded PNG.

    Args:
        req: Prompts, optional per-prompt positions, sampling mode and
            optional player position.

    Returns:
        dict: {"status": "success", "image": <base64 PNG string>}.

    Raises:
        HTTPException: 500 with the exception text if anything fails.
    """
    try:
        logger.debug(f"Generate image request: {req}")

        # Every request starts from an empty point set.
        # NOTE(review): `explorer` is module-global, so this reset is
        # visible to any other request using it at the same time.
        explorer.points = []

        supplied_positions = req.positions or []
        for idx, text in enumerate(req.prompts):
            logger.debug(f"Processing prompt {idx}: {text}")

            # Stub encoding: random values stand in for a real model.
            latent = np.random.randn(512)

            # Prompts beyond the supplied positions get no position.
            xy = supplied_positions[idx] if idx < len(supplied_positions) else None
            logger.debug(f"Position for prompt {idx}: {xy}")

            explorer.add_point(text, latent, xy)

        player_pos = req.player_pos if req.player_pos is not None else [0.0, 0.0]
        logger.debug(f"Player position: {player_pos}")

        # The sampled encoding is not used by the mock image below; the
        # call is kept so that sampling errors still surface.
        logger.debug(f"Sampling with mode: {req.mode}")
        sampled_encoding = explorer.sample_encoding(tuple(player_pos), mode=req.mode)

        # Mock image: uniform random noise scaled to 0..255.
        pixels = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)

        # Serialise as PNG, then base64-encode the PNG file bytes.
        logger.debug("Converting image to base64")
        png_buffer = io.BytesIO()
        Image.fromarray(pixels).save(png_buffer, format="PNG")
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

        logger.debug("Image generated successfully")
        return {"status": "success", "image": encoded_png}

    except Exception as exc:
        logger.error(f"Error in generate_image: {str(exc)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(exc))
133
+
134
@app.post("/manipulate")
def manipulate(req: ManipulateRequest):
    """Shift an encoding along a direction: ``encoding + alpha * direction``.

    Args:
        req: The encoding, the direction and the scale factor ``alpha``.

    Returns:
        dict: {"manipulated_encoding": <list of floats>}.

    Raises:
        HTTPException: 400 if ``encoding`` and ``direction`` differ in
            length; 500 with the exception text for any other failure.
    """
    try:
        logger.debug(f"Manipulate request: {req}")

        # Reject mismatched lengths up front. Without this check NumPy
        # would either silently broadcast a length-1 vector or raise a
        # ValueError that was reported to the client as a server error.
        if len(req.encoding) != len(req.direction):
            raise HTTPException(
                status_code=400,
                detail=(
                    "encoding and direction must have the same length "
                    f"(got {len(req.encoding)} and {len(req.direction)})"
                ),
            )

        encoding = np.array(req.encoding)
        direction = np.array(req.direction)
        manipulated = encoding + req.alpha * direction
        logger.debug("Manipulation successful")
        return {"manipulated_encoding": manipulated.tolist()}
    except HTTPException:
        # Client errors raised above must not be re-wrapped as 500s.
        raise
    except Exception as e:
        logger.error(f"Error in manipulate: {str(e)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
147
+
148
@app.post("/attribute_direction")
def attribute_direction(req: AttributeDirectionRequest):
    """Find attribute direction(s) for the supplied latents.

    With labels, the classifier-based direction is returned; without
    labels, PCA components and their explained variance are returned.

    Args:
        req: Latents, optional labels and the PCA component count.

    Returns:
        dict: {"direction": [...]} on the classifier path, or
        {"components": [...], "explained_variance": [...]} on the PCA path.

    Raises:
        HTTPException: 500 with the exception text if anything fails.
    """
    try:
        logger.debug(f"Attribute direction request: {req}")
        finder = LatentDirectionFinder(np.array(req.latents))

        # No labels supplied: fall back to PCA.
        if req.labels is None:
            logger.debug(f"Using PCA with {req.n_components} components")
            components, explained = finder.pca_direction(n_components=req.n_components)
            logger.debug("PCA completed successfully")
            return {
                "components": components.tolist(),
                "explained_variance": explained.tolist(),
            }

        logger.debug("Using classifier-based direction finding")
        direction = finder.classifier_direction(req.labels)
        logger.debug("Direction found successfully")
        return {"direction": direction.tolist()}
    except Exception as exc:
        logger.error(f"Error in attribute_direction: {str(exc)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(exc))
faceforge_ui/app.py CHANGED
@@ -1,19 +1,154 @@
1
  import gradio as gr
2
  import requests
 
 
 
 
 
 
 
 
3
 
4
- def generate_image():
5
- response = requests.post("http://localhost:8000/generate")
6
- if response.ok:
7
- return None # Placeholder: should return image from backend
8
- return None
9
-
10
- iface = gr.Interface(
11
- fn=generate_image,
12
- inputs=[],
13
- outputs="image",
14
- title="FaceForge Latent Space Explorer",
15
- description="Interactively explore and edit faces in latent space."
16
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  if __name__ == "__main__":
19
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import requests
3
+ import numpy as np
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ import logging
8
+ import sys
9
+ import traceback
10
+ import os
11
 
12
+ # Configure logging
13
# Route every log record (DEBUG and above) to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger("faceforge_ui")

# Base URL of the backend; overridable through the API_URL environment
# variable, defaulting to a local server on port 8000.
API_URL = os.getenv("API_URL", "http://localhost:8000")
logger.info(f"Using API URL: {API_URL}")
23
+
24
def generate_image(prompts, mode, player_x, player_y):
    """
    Generate an image based on prompts and player position.

    Sends the parsed prompts, sampling mode and player position to the
    backend's ``/generate`` endpoint and decodes the returned image.

    Args:
        prompts: Comma-separated list of prompts
        mode: Sampling mode ('distance' or 'circle')
        player_x: X-coordinate of player position
        player_y: Y-coordinate of player position

    Returns:
        PIL.Image or None: Generated image or None if generation failed
        (no valid prompts, request error, non-OK response, missing image,
        or undecodable image data). Failures are logged, never raised.
    """
    try:
        logger.debug(f"Generating image with prompts: {prompts}, mode: {mode}, position: ({player_x}, {player_y})")

        # Parse prompts, dropping empty entries left by stray commas.
        prompt_list = [p.strip() for p in prompts.split(",") if p.strip()]
        if not prompt_list:
            logger.warning("No valid prompts provided")
            return None

        logger.debug(f"Parsed prompts: {prompt_list}")

        # Prepare request
        req = {
            "prompts": prompt_list,
            "mode": mode,
            "player_pos": [float(player_x), float(player_y)]
        }

        logger.debug(f"Sending request to API: {req}")

        # Make API call
        try:
            resp = requests.post(f"{API_URL}/generate", json=req, timeout=30)
            logger.debug(f"API response status: {resp.status_code}")

            if not resp.ok:
                logger.error(f"API error: {resp.status_code} - {resp.text}")
                return None

            data = resp.json()
            logger.debug("Successfully received API response")

            # Treat a null/empty image the same as a missing one.
            img_b64 = data.get("image")
            if not img_b64:
                logger.warning("No image in API response")
                return None

            try:
                img_bytes = base64.b64decode(img_b64)
                # The backend sends an encoded PNG file, not raw pixel
                # data, so it must be parsed with Image.open;
                # Image.frombytes expects raw width*height*3 bytes and
                # rejects PNG data.
                img = Image.open(io.BytesIO(img_bytes))
                # Force decoding now so corrupt data is reported here.
                img.load()
                logger.debug("Successfully decoded image")
                return img
            except Exception as e:
                logger.error(f"Error decoding image: {e}")
                logger.debug(traceback.format_exc())
                return None

        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed: {e}")
            logger.debug(traceback.format_exc())
            return None

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        logger.debug(traceback.format_exc())
        return None
94
+
95
+ # Create Gradio interface
96
logger.info("Initializing Gradio interface")
with gr.Blocks() as demo:
    gr.Markdown("# FaceForge Latent Space Explorer")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            prompts = gr.Textbox(
                label="Prompts (comma-separated)",
                value="A photo of a cat, A photo of a dog",
                info="Enter prompts separated by commas",
            )
            mode = gr.Radio(
                choices=["distance", "circle"],
                value="distance",
                label="Sampling Mode",
                info="Choose how to sample the latent space",
            )
            player_x = gr.Slider(-1.0, 1.0, value=0.0, label="Player X")
            player_y = gr.Slider(-1.0, 1.0, value=0.0, label="Player Y")
            btn = gr.Button("Generate")

        # Right column: outputs.
        with gr.Column():
            img = gr.Image(label="Generated Image")
            status = gr.Textbox(label="Status", interactive=False)

    def on_generate_click(prompts, mode, player_x, player_y):
        """Run generate_image and return [image, status text] for the UI."""
        try:
            logger.info("Generate button clicked")
            generated = generate_image(prompts, mode, player_x, player_y)
        except Exception as exc:
            logger.error(f"Error in generate button handler: {exc}")
            logger.debug(traceback.format_exc())
            return [None, f"Error: {str(exc)}"]

        if generated is None:
            return [None, "Failed to generate image. Check logs for details."]
        return [generated, "Image generated successfully"]

    btn.click(
        fn=on_generate_click,
        inputs=[prompts, mode, player_x, player_y],
        outputs=[img, status],
    )

    # Show an initial status message when the page loads.
    demo.load(lambda: "Ready to generate images", outputs=status)
141
 
142
if __name__ == "__main__":
    logger.info("Starting Gradio app")
    try:
        # SPACE_ID in the environment is taken to mean Hugging Face Spaces.
        # Only the log message differs; the launch call is the same.
        if "SPACE_ID" in os.environ:
            logger.info("Running in Hugging Face Space")
        else:
            logger.info("Running locally")
        demo.launch(server_name="0.0.0.0", share=False)
    except Exception as exc:
        # Launch failures are logged and not re-raised.
        logger.critical(f"Failed to launch Gradio app: {exc}")
        logger.debug(traceback.format_exc())