mlbench123 commited on
Commit
74e6395
·
verified ·
1 Parent(s): f06bc30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1007 -0
app.py CHANGED
@@ -0,0 +1,1007 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Union, Tuple
4
+ from PIL import Image
5
+ import ezdxf.units
6
+ import numpy as np
7
+ import torch
8
+ from torchvision import transforms
9
+ from ultralytics import YOLOWorld, YOLO
10
+ from ultralytics.engine.results import Results
11
+ from ultralytics.utils.plotting import save_one_box
12
+ from transformers import AutoModelForImageSegmentation
13
+ import cv2
14
+ import ezdxf
15
+ import gradio as gr
16
+ import gc
17
+ from scalingtestupdated import calculate_scaling_factor
18
+ from scipy.interpolate import splprep, splev
19
+ from scipy.ndimage import gaussian_filter1d
20
+ import json
21
+ import time
22
+ import signal
23
+ from shapely.ops import unary_union
24
+ from shapely.geometry import MultiPolygon, GeometryCollection, Polygon, Point
25
+ from u2netp import U2NETP
26
+ import logging
27
+ import shutil
28
+
29
+ # Initialize logging
30
+ logging.basicConfig(level=logging.INFO)
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Create cache directory for models
34
+ CACHE_DIR = os.path.join(os.path.dirname(__file__), ".cache")
35
+ os.makedirs(CACHE_DIR, exist_ok=True)
36
+
37
+ # Paper size configurations (in mm)
38
+ PAPER_SIZES = {
39
+ "A4": {"width": 210, "height": 297},
40
+ "A3": {"width": 297, "height": 420},
41
+ "US Letter": {"width": 215.9, "height": 279.4}
42
+ }
43
+
44
+ # Custom Exception Classes
45
+ class TimeoutReachedError(Exception):
46
+ pass
47
+
48
+ class BoundaryOverlapError(Exception):
49
+ pass
50
+
51
+ class TextOverlapError(Exception):
52
+ pass
53
+
54
+ class PaperNotDetectedError(Exception):
55
+ """Raised when the paper cannot be detected in the image"""
56
+ pass
57
+
58
+ class MultipleObjectsError(Exception):
59
+ """Raised when multiple objects are detected on the paper"""
60
+ def __init__(self, message="Multiple objects detected. Please place only a single object on the paper."):
61
+ super().__init__(message)
62
+
63
+ class NoObjectDetectedError(Exception):
64
+ """Raised when no object is detected on the paper"""
65
+ def __init__(self, message="No object detected on the paper. Please ensure an object is placed on the paper."):
66
+ super().__init__(message)
67
+
68
+ class FingerCutOverlapError(Exception):
69
+ """Raised when finger cuts overlap with existing geometry"""
70
+ def __init__(self, message="There was an overlap with fingercuts... Please try again to generate dxf."):
71
+ super().__init__(message)
72
+
73
+ # Global model variables for lazy loading
74
+ paper_detector_global = None
75
+ u2net_global = None
76
+ birefnet = None
77
+
78
+ # Model paths
79
+ paper_model_path = os.path.join(CACHE_DIR, "paper_detector.pt") # You'll need to train/provide this
80
+ u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
81
+
82
+ # Device configuration
83
+ device = "cpu"
84
+ torch.set_float32_matmul_precision(["high", "highest"][0])
85
+
86
+ def ensure_model_files():
87
+ """Ensure model files are available in cache directory"""
88
+ if not os.path.exists(paper_model_path):
89
+ if os.path.exists("paper_detector.pt"):
90
+ shutil.copy("paper_detector.pt", paper_model_path)
91
+ else:
92
+ logger.warning("paper_detector.pt model file not found - using fallback detection")
93
+
94
+ if not os.path.exists(u2net_model_path):
95
+ if os.path.exists("u2netp.pth"):
96
+ shutil.copy("u2netp.pth", u2net_model_path)
97
+ else:
98
+ raise FileNotFoundError("u2netp.pth model file not found")
99
+
100
+ ensure_model_files()
101
+
102
+ # Lazy loading functions
103
+ def get_paper_detector():
104
+ """Lazy load paper detector model"""
105
+ global paper_detector_global
106
+ if paper_detector_global is None:
107
+ logger.info("Loading paper detector model...")
108
+ if os.path.exists(paper_model_path):
109
+ paper_detector_global = YOLO(paper_model_path)
110
+ else:
111
+ # Fallback to generic object detection for paper-like rectangles
112
+ logger.warning("Using fallback paper detection")
113
+ paper_detector_global = None
114
+ logger.info("Paper detector loaded successfully")
115
+ return paper_detector_global
116
+
117
+ def get_u2net():
118
+ """Lazy load U2NETP model"""
119
+ global u2net_global
120
+ if u2net_global is None:
121
+ logger.info("Loading U2NETP model...")
122
+ u2net_global = U2NETP(3, 1)
123
+ u2net_global.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
124
+ u2net_global.to(device)
125
+ u2net_global.eval()
126
+ logger.info("U2NETP model loaded successfully")
127
+ return u2net_global
128
+
129
+ def load_birefnet_model():
130
+ """Load BiRefNet model from HuggingFace"""
131
+ return AutoModelForImageSegmentation.from_pretrained(
132
+ 'ZhengPeng7/BiRefNet',
133
+ trust_remote_code=True
134
+ )
135
+
136
+ def get_birefnet():
137
+ """Lazy load BiRefNet model"""
138
+ global birefnet
139
+ if birefnet is None:
140
+ logger.info("Loading BiRefNet model...")
141
+ birefnet = load_birefnet_model()
142
+ birefnet.to(device)
143
+ birefnet.eval()
144
+ logger.info("BiRefNet model loaded successfully")
145
+ return birefnet
146
+
147
+ def detect_paper_contour(image: np.ndarray) -> Tuple[np.ndarray, float]:
148
+ """
149
+ Detect paper in the image using contour detection as fallback
150
+ Returns the paper contour and estimated scaling factor
151
+ """
152
+ # Convert to grayscale
153
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
154
+
155
+ # Apply Gaussian blur
156
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
157
+
158
+ # Edge detection
159
+ edges = cv2.Canny(blurred, 50, 150)
160
+
161
+ # Find contours
162
+ contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
163
+
164
+ # Filter contours by area and aspect ratio to find paper-like rectangles
165
+ paper_contours = []
166
+ min_area = (image.shape[0] * image.shape[1]) * 0.1 # At least 10% of image
167
+
168
+ for contour in contours:
169
+ area = cv2.contourArea(contour)
170
+ if area > min_area:
171
+ # Approximate contour to polygon
172
+ epsilon = 0.02 * cv2.arcLength(contour, True)
173
+ approx = cv2.approxPolyDP(contour, epsilon, True)
174
+
175
+ # Check if it's roughly rectangular (4 corners)
176
+ if len(approx) >= 4:
177
+ # Calculate bounding rectangle
178
+ rect = cv2.boundingRect(approx)
179
+ aspect_ratio = rect[2] / rect[3] # width / height
180
+
181
+ # Check if aspect ratio matches common paper ratios
182
+ # A4: 1.414, A3: 1.414, US Letter: 1.294
183
+ if 0.7 < aspect_ratio < 1.8: # Allow some tolerance
184
+ paper_contours.append((contour, area, aspect_ratio))
185
+
186
+ if not paper_contours:
187
+ raise PaperNotDetectedError("Could not detect paper in the image")
188
+
189
+ # Select the largest paper-like contour
190
+ paper_contours.sort(key=lambda x: x[1], reverse=True)
191
+ best_contour = paper_contours[0][0]
192
+
193
+ return best_contour, 0.0 # Return 0.0 as placeholder scaling factor
194
+
195
+ def detect_paper_bounds(image: np.ndarray, paper_size: str) -> Tuple[np.ndarray, float]:
196
+ """
197
+ Detect paper bounds in the image and calculate scaling factor
198
+ """
199
+ try:
200
+ paper_detector = get_paper_detector()
201
+
202
+ if paper_detector is not None:
203
+ # Use trained model if available
204
+ results = paper_detector.predict(image, conf=0.5)
205
+ if not results or len(results) == 0 or len(results[0].boxes) == 0:
206
+ logger.warning("Model detection failed, using fallback contour detection")
207
+ return detect_paper_contour(image)
208
+
209
+ # Get the largest detected paper
210
+ boxes = results[0].cpu().boxes.xyxy
211
+ largest_box = None
212
+ max_area = 0
213
+
214
+ for box in boxes:
215
+ x_min, y_min, x_max, y_max = box
216
+ area = (x_max - x_min) * (y_max - y_min)
217
+ if area > max_area:
218
+ max_area = area
219
+ largest_box = box
220
+
221
+ if largest_box is None:
222
+ raise PaperNotDetectedError("No paper detected by model")
223
+
224
+ # Convert box to contour-like format
225
+ x_min, y_min, x_max, y_max = map(int, largest_box)
226
+ paper_contour = np.array([
227
+ [[x_min, y_min]],
228
+ [[x_max, y_min]],
229
+ [[x_max, y_max]],
230
+ [[x_min, y_max]]
231
+ ])
232
+
233
+ else:
234
+ # Use fallback contour detection
235
+ paper_contour, _ = detect_paper_contour(image)
236
+
237
+ # Calculate scaling factor based on paper size
238
+ scaling_factor = calculate_paper_scaling_factor(paper_contour, paper_size)
239
+
240
+ return paper_contour, scaling_factor
241
+
242
+ except Exception as e:
243
+ logger.error(f"Error in paper detection: {e}")
244
+ raise PaperNotDetectedError(f"Failed to detect paper: {str(e)}")
245
+
246
+ def calculate_paper_scaling_factor(paper_contour: np.ndarray, paper_size: str) -> float:
247
+ """
248
+ Calculate scaling factor based on detected paper dimensions
249
+ """
250
+ # Get paper dimensions
251
+ paper_dims = PAPER_SIZES[paper_size]
252
+ expected_width_mm = paper_dims["width"]
253
+ expected_height_mm = paper_dims["height"]
254
+
255
+ # Calculate bounding rectangle of paper contour
256
+ rect = cv2.boundingRect(paper_contour)
257
+ detected_width_px = rect[2]
258
+ detected_height_px = rect[3]
259
+
260
+ # Calculate scaling factors for both dimensions
261
+ scale_x = expected_width_mm / detected_width_px
262
+ scale_y = expected_height_mm / detected_height_px
263
+
264
+ # Use average of both scales
265
+ scaling_factor = (scale_x + scale_y) / 2
266
+
267
+ logger.info(f"Paper detection: {detected_width_px}x{detected_height_px} px -> {expected_width_mm}x{expected_height_mm} mm")
268
+ logger.info(f"Calculated scaling factor: {scaling_factor:.4f} mm/px")
269
+
270
+ return scaling_factor
271
+
272
+ def validate_single_object(mask: np.ndarray, paper_contour: np.ndarray) -> None:
273
+ """
274
+ Validate that only a single object is present on the paper
275
+ """
276
+ # Create a mask for the paper area
277
+ paper_mask = np.zeros(mask.shape[:2], dtype=np.uint8)
278
+ cv2.fillPoly(paper_mask, [paper_contour], 255)
279
+
280
+ # Apply paper mask to object mask
281
+ masked_objects = cv2.bitwise_and(mask, paper_mask)
282
+
283
+ # Find contours of objects within paper bounds
284
+ contours, _ = cv2.findContours(masked_objects, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
285
+
286
+ # Filter out very small contours (noise)
287
+ min_area = 1000 # Minimum area threshold
288
+ significant_contours = [c for c in contours if cv2.contourArea(c) > min_area]
289
+
290
+ if len(significant_contours) == 0:
291
+ raise NoObjectDetectedError()
292
+ elif len(significant_contours) > 1:
293
+ raise MultipleObjectsError()
294
+
295
+ logger.info(f"Single object validated: {len(significant_contours)} significant contour(s) found")
296
+
297
+ def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
298
+ """Remove background using U2NETP model"""
299
+ try:
300
+ u2net_model = get_u2net()
301
+
302
+ image_pil = Image.fromarray(image)
303
+ transform_u2netp = transforms.Compose([
304
+ transforms.Resize((320, 320)),
305
+ transforms.ToTensor(),
306
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
307
+ ])
308
+
309
+ input_tensor = transform_u2netp(image_pil).unsqueeze(0).to(device)
310
+
311
+ with torch.no_grad():
312
+ outputs = u2net_model(input_tensor)
313
+
314
+ pred = outputs[0]
315
+ pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
316
+ pred_np = pred.squeeze().cpu().numpy()
317
+ pred_np = cv2.resize(pred_np, (image_pil.width, image_pil.height))
318
+ pred_np = (pred_np * 255).astype(np.uint8)
319
+
320
+ return pred_np
321
+ except Exception as e:
322
+ logger.error(f"Error in U2NETP background removal: {e}")
323
+ raise
324
+
325
+ def remove_bg(image: np.ndarray) -> np.ndarray:
326
+ """Remove background using BiRefNet model for main objects"""
327
+ try:
328
+ birefnet_model = get_birefnet()
329
+
330
+ transform_image = transforms.Compose([
331
+ transforms.Resize((1024, 1024)),
332
+ transforms.ToTensor(),
333
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
334
+ ])
335
+
336
+ image_pil = Image.fromarray(image)
337
+ input_images = transform_image(image_pil).unsqueeze(0).to(device)
338
+
339
+ with torch.no_grad():
340
+ preds = birefnet_model(input_images)[-1].sigmoid().cpu()
341
+ pred = preds[0].squeeze()
342
+
343
+ pred_pil = transforms.ToPILImage()(pred)
344
+
345
+ scale_ratio = 1024 / max(image_pil.size)
346
+ scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
347
+
348
+ return np.array(pred_pil.resize(scaled_size))
349
+ except Exception as e:
350
+ logger.error(f"Error in BiRefNet background removal: {e}")
351
+ raise
352
+
353
+ def exclude_paper_area(mask: np.ndarray, paper_contour: np.ndarray, expansion_factor: float = 1.1) -> np.ndarray:
354
+ """
355
+ Remove paper area from the mask to focus only on objects
356
+ """
357
+ # Create paper mask with slight expansion to ensure complete removal
358
+ paper_mask = np.zeros(mask.shape[:2], dtype=np.uint8)
359
+
360
+ # Expand paper contour slightly
361
+ epsilon = expansion_factor * cv2.arcLength(paper_contour, True)
362
+ expanded_contour = cv2.approxPolyDP(paper_contour, epsilon, True)
363
+
364
+ cv2.fillPoly(paper_mask, [expanded_contour], 255)
365
+
366
+ # Invert paper mask and apply to object mask
367
+ paper_mask_inv = cv2.bitwise_not(paper_mask)
368
+ result_mask = cv2.bitwise_and(mask, paper_mask_inv)
369
+
370
+ return result_mask
371
+
372
+ def resample_contour(contour, edge_radius_px: int = 0):
373
+ """Resample contour with radius-aware smoothing and periodic handling."""
374
+ logger.info(f"Starting resample_contour with contour of shape {contour.shape}")
375
+
376
+ num_points = 1500
377
+ sigma = max(2, int(edge_radius_px) // 4)
378
+
379
+ if len(contour) < 4:
380
+ error_msg = f"Contour must have at least 4 points, but has {len(contour)} points."
381
+ logger.error(error_msg)
382
+ raise ValueError(error_msg)
383
+
384
+ try:
385
+ contour = contour[:, 0, :]
386
+ logger.debug(f"Reshaped contour to shape {contour.shape}")
387
+
388
+ if not np.array_equal(contour[0], contour[-1]):
389
+ contour = np.vstack([contour, contour[0]])
390
+
391
+ tck, u = splprep(contour.T, u=None, s=0, per=True)
392
+
393
+ u_new = np.linspace(u.min(), u.max(), num_points)
394
+ x_new, y_new = splev(u_new, tck, der=0)
395
+
396
+ if sigma > 0:
397
+ x_new = gaussian_filter1d(x_new, sigma=sigma, mode='wrap')
398
+ y_new = gaussian_filter1d(y_new, sigma=sigma, mode='wrap')
399
+
400
+ x_new[-1] = x_new[0]
401
+ y_new[-1] = y_new[0]
402
+
403
+ result = np.array([x_new, y_new]).T
404
+ logger.info(f"Completed resample_contour with result shape {result.shape}")
405
+ return result
406
+
407
+ except Exception as e:
408
+ logger.error(f"Error in resample_contour: {e}")
409
+ raise
410
+
411
+ def save_dxf_spline(inflated_contours, scaling_factor, height, finger_clearance=False):
412
+ """Save contours as DXF splines with optional finger cuts"""
413
+ doc = ezdxf.new(units=ezdxf.units.MM)
414
+ doc.header["$INSUNITS"] = ezdxf.units.MM
415
+ msp = doc.modelspace()
416
+ final_polygons_inch = []
417
+ finger_centers = []
418
+ original_polygons = []
419
+
420
+ # Scale correction factor
421
+ scale_correction = 1.079
422
+
423
+ for contour in inflated_contours:
424
+ try:
425
+ resampled_contour = resample_contour(contour)
426
+
427
+ points_inch = [(x * scaling_factor, (height - y) * scaling_factor)
428
+ for x, y in resampled_contour]
429
+
430
+ if len(points_inch) < 3:
431
+ continue
432
+
433
+ tool_polygon = build_tool_polygon(points_inch)
434
+ original_polygons.append(tool_polygon)
435
+
436
+ if finger_clearance:
437
+ try:
438
+ tool_polygon, center = place_finger_cut_adjusted(
439
+ tool_polygon, points_inch, finger_centers, final_polygons_inch
440
+ )
441
+ except FingerCutOverlapError:
442
+ tool_polygon = original_polygons[-1]
443
+
444
+ exterior_coords = polygon_to_exterior_coords(tool_polygon)
445
+ if len(exterior_coords) < 3:
446
+ continue
447
+
448
+ # Apply scale correction
449
+ corrected_coords = [(x * scale_correction, y * scale_correction) for x, y in exterior_coords]
450
+
451
+ msp.add_spline(corrected_coords, degree=3, dxfattribs={"layer": "TOOLS"})
452
+ final_polygons_inch.append(tool_polygon)
453
+
454
+ except ValueError as e:
455
+ logger.warning(f"Skipping contour: {e}")
456
+
457
+ dxf_filepath = os.path.join("./outputs", "out.dxf")
458
+ doc.saveas(dxf_filepath)
459
+ return dxf_filepath, final_polygons_inch, original_polygons
460
+
461
+ def build_tool_polygon(points_inch):
462
+ """Build a polygon from inch-converted points"""
463
+ return Polygon(points_inch)
464
+
465
+ def polygon_to_exterior_coords(poly):
466
+ """Extract exterior coordinates from polygon"""
467
+ logger.info(f"Starting polygon_to_exterior_coords with input geometry type: {poly.geom_type}")
468
+
469
+ try:
470
+ if poly.geom_type == "GeometryCollection" or poly.geom_type == "MultiPolygon":
471
+ logger.debug(f"Performing unary_union on {poly.geom_type}")
472
+ unified = unary_union(poly)
473
+ if unified.is_empty:
474
+ logger.warning("unary_union produced an empty geometry; returning empty list")
475
+ return []
476
+
477
+ if unified.geom_type == "GeometryCollection" or unified.geom_type == "MultiPolygon":
478
+ largest = None
479
+ max_area = 0.0
480
+ for g in getattr(unified, "geoms", []):
481
+ if hasattr(g, "area") and g.area > max_area and hasattr(g, "exterior"):
482
+ max_area = g.area
483
+ largest = g
484
+ if largest is None:
485
+ logger.warning("No valid Polygon found in unified geometry; returning empty list")
486
+ return []
487
+ poly = largest
488
+ else:
489
+ poly = unified
490
+
491
+ if not hasattr(poly, "exterior") or poly.exterior is None:
492
+ logger.warning("Input geometry has no exterior ring; returning empty list")
493
+ return []
494
+
495
+ raw_coords = list(poly.exterior.coords)
496
+ total = len(raw_coords)
497
+ logger.info(f"Extracted {total} raw exterior coordinates")
498
+
499
+ if total == 0:
500
+ return []
501
+
502
+ # Subsample coordinates to at most 100 points
503
+ max_pts = 100
504
+ if total > max_pts:
505
+ step = total // max_pts
506
+ sampled = [raw_coords[i] for i in range(0, total, step)]
507
+ if sampled[-1] != raw_coords[-1]:
508
+ sampled.append(raw_coords[-1])
509
+ logger.info(f"Downsampled perimeter from {total} to {len(sampled)} points")
510
+ return sampled
511
+ else:
512
+ return raw_coords
513
+
514
+ except Exception as e:
515
+ logger.error(f"Error in polygon_to_exterior_coords: {e}")
516
+ return []
517
+
518
+ def place_finger_cut_adjusted(
519
+ tool_polygon: Polygon,
520
+ points_inch: list,
521
+ existing_centers: list,
522
+ all_polygons: list,
523
+ circle_diameter: float = 25.4,
524
+ min_gap: float = 0.5,
525
+ max_attempts: int = 100
526
+ ) -> Tuple[Polygon, tuple]:
527
+ """Place finger cuts with collision avoidance"""
528
+ logger.info(f"Starting place_finger_cut_adjusted with {len(points_inch)} input points")
529
+
530
+ def fallback_solution():
531
+ logger.warning("Using fallback approach for finger cut placement")
532
+ fallback_center = points_inch[len(points_inch) // 2]
533
+ r = circle_diameter / 2.0
534
+ fallback_circle = Point(fallback_center).buffer(r, resolution=32)
535
+ try:
536
+ union_poly = tool_polygon.union(fallback_circle)
537
+ except Exception as e:
538
+ logger.warning(f"Fallback union failed ({e}); trying buffer-union fallback")
539
+ union_poly = tool_polygon.buffer(0).union(fallback_circle.buffer(0))
540
+
541
+ existing_centers.append(fallback_center)
542
+ logger.info(f"Fallback finger cut placed at {fallback_center}")
543
+ return union_poly, fallback_center
544
+
545
+ r = circle_diameter / 2.0
546
+ needed_center_dist = circle_diameter + min_gap
547
+
548
+ raw_perimeter = polygon_to_exterior_coords(tool_polygon)
549
+ if not raw_perimeter:
550
+ logger.warning("No valid exterior coords found; using fallback immediately")
551
+ return fallback_solution()
552
+
553
+ if len(raw_perimeter) > 100:
554
+ step = len(raw_perimeter) // 100
555
+ perimeter_coords = raw_perimeter[::step]
556
+ logger.info(f"Subsampled perimeter from {len(raw_perimeter)} to {len(perimeter_coords)} points")
557
+ else:
558
+ perimeter_coords = raw_perimeter[:]
559
+
560
+ indices = list(range(len(perimeter_coords)))
561
+ np.random.shuffle(indices)
562
+ logger.debug(f"Shuffled perimeter indices for candidate order")
563
+
564
+ start_time = time.time()
565
+ timeout_secs = 5.0
566
+
567
+ attempts = 0
568
+ try:
569
+ while attempts < max_attempts:
570
+ if time.time() - start_time > timeout_secs - 0.1:
571
+ logger.warning(f"Approaching timeout after {attempts} attempts")
572
+ return fallback_solution()
573
+
574
+ for idx in indices:
575
+ if time.time() - start_time > timeout_secs - 0.05:
576
+ logger.warning("Timeout during candidate-point loop")
577
+ return fallback_solution()
578
+
579
+ cx, cy = perimeter_coords[idx]
580
+ for dx, dy in [(0, 0), (-min_gap/2, 0), (min_gap/2, 0), (0, -min_gap/2), (0, min_gap/2)]:
581
+ candidate_center = (cx + dx, cy + dy)
582
+
583
+ # Check distance to existing finger centers
584
+ too_close_finger = any(
585
+ np.hypot(candidate_center[0] - ex, candidate_center[1] - ey)
586
+ < needed_center_dist
587
+ for (ex, ey) in existing_centers
588
+ )
589
+ if too_close_finger:
590
+ continue
591
+
592
+ # Build candidate circle
593
+ candidate_circle = Point(candidate_center).buffer(r, resolution=32)
594
+
595
+ # Must overlap ≥30% with this polygon
596
+ try:
597
+ inter_area = tool_polygon.intersection(candidate_circle).area
598
+ except Exception:
599
+ continue
600
+
601
+ if inter_area < 0.3 * candidate_circle.area:
602
+ continue
603
+
604
+ # Must not intersect other polygons
605
+ invalid = False
606
+ for other_poly in all_polygons:
607
+ if other_poly.equals(tool_polygon):
608
+ continue
609
+ if other_poly.buffer(min_gap).intersects(candidate_circle) or \
610
+ other_poly.buffer(min_gap).touches(candidate_circle):
611
+ invalid = True
612
+ break
613
+ if invalid:
614
+ continue
615
+
616
+ # Union and return
617
+ try:
618
+ union_poly = tool_polygon.union(candidate_circle)
619
+ if union_poly.geom_type == "MultiPolygon" and len(union_poly.geoms) > 1:
620
+ continue
621
+ if union_poly.equals(tool_polygon):
622
+ continue
623
+ except Exception:
624
+ continue
625
+
626
+ existing_centers.append(candidate_center)
627
+ logger.info(f"Finger cut placed successfully at {candidate_center} after {attempts} attempts")
628
+ return union_poly, candidate_center
629
+
630
+ attempts += 1
631
+ if attempts >= (max_attempts // 2) and (time.time() - start_time) > timeout_secs * 0.8:
632
+ logger.warning(f"Approaching timeout (attempt {attempts})")
633
+ return fallback_solution()
634
+
635
+ logger.warning(f"No valid spot after {max_attempts} attempts, using fallback")
636
+ return fallback_solution()
637
+
638
+ except Exception as e:
639
+ logger.error(f"Error in place_finger_cut_adjusted: {e}")
640
+ return fallback_solution()
641
+
642
+ def extract_outlines(binary_image: np.ndarray) -> Tuple[np.ndarray, list]:
643
+ """Extract outlines from binary image"""
644
+ contours, _ = cv2.findContours(
645
+ binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
646
+ )
647
+ outline_image = np.full_like(binary_image, 255)
648
+ return outline_image, contours
649
+
650
+ def round_edges(mask: np.ndarray, radius_mm: float, scaling_factor: float) -> np.ndarray:
651
+ """Round mask edges using contour smoothing"""
652
+ if radius_mm <= 0 or scaling_factor <= 0:
653
+ return mask
654
+
655
+ radius_px = max(1, int(radius_mm / scaling_factor))
656
+
657
+ if np.count_nonzero(mask) < 500:
658
+ return cv2.dilate(cv2.erode(mask, np.ones((3,3))), np.ones((3,3)))
659
+
660
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
661
+ contours = [c for c in contours if cv2.contourArea(c) > 100]
662
+ smoothed_contours = []
663
+
664
+ for contour in contours:
665
+ try:
666
+ resampled = resample_contour(contour, radius_px)
667
+ resampled = resampled.astype(np.int32).reshape((-1, 1, 2))
668
+ smoothed_contours.append(resampled)
669
+ except Exception as e:
670
+ logger.warning(f"Error smoothing contour: {e}")
671
+ smoothed_contours.append(contour)
672
+
673
+ rounded = np.zeros_like(mask)
674
+ cv2.drawContours(rounded, smoothed_contours, -1, 255, thickness=cv2.FILLED)
675
+
676
+ return rounded
677
+
678
+ def cleanup_memory():
679
+ """Clean up memory after processing"""
680
+ if torch.cuda.is_available():
681
+ torch.cuda.empty_cache()
682
+ gc.collect()
683
+ logger.info("Memory cleanup completed")
684
+
685
+ def cleanup_models():
686
+ """Unload models to free memory"""
687
+ global paper_detector_global, u2net_global, birefnet
688
+ if paper_detector_global is not None:
689
+ del paper_detector_global
690
+ paper_detector_global = None
691
+ if u2net_global is not None:
692
+ del u2net_global
693
+ u2net_global = None
694
+ if birefnet is not None:
695
+ del birefnet
696
+ birefnet = None
697
+ cleanup_memory()
698
+
699
+ def make_square(img: np.ndarray):
700
+ """Make the image square by padding"""
701
+ height, width = img.shape[:2]
702
+ max_dim = max(height, width)
703
+
704
+ pad_height = (max_dim - height) // 2
705
+ pad_width = (max_dim - width) // 2
706
+
707
+ pad_height_extra = max_dim - height - 2 * pad_height
708
+ pad_width_extra = max_dim - width - 2 * pad_width
709
+
710
+ if len(img.shape) == 3:
711
+ padded = np.pad(
712
+ img,
713
+ (
714
+ (pad_height, pad_height + pad_height_extra),
715
+ (pad_width, pad_width + pad_width_extra),
716
+ (0, 0),
717
+ ),
718
+ mode="edge",
719
+ )
720
+ else:
721
+ padded = np.pad(
722
+ img,
723
+ (
724
+ (pad_height, pad_height + pad_height_extra),
725
+ (pad_width, pad_width + pad_width_extra),
726
+ ),
727
+ mode="edge",
728
+ )
729
+
730
+ return padded
731
+
732
+ def predict_with_paper(image, paper_size, offset, offset_unit, edge_radius, finger_clearance=False):
733
+ """Main prediction function using paper as reference"""
734
+
735
+ if offset_unit == "inches":
736
+ offset *= 25.4
737
+
738
+ if edge_radius is None or edge_radius == 0:
739
+ edge_radius = 0.0001
740
+
741
+ if offset < 0:
742
+ raise gr.Error("Offset Value Can't be negative")
743
+
744
+ try:
745
+ # Detect paper bounds and calculate scaling factor
746
+ paper_contour, scaling_factor = detect_paper_bounds(image, paper_size)
747
+ logger.info(f"Paper detected with scaling factor: {scaling_factor:.4f} mm/px")
748
+
749
+ except PaperNotDetectedError as e:
750
+ return (
751
+ None, None, None, None,
752
+ f"Error: {str(e)}"
753
+ )
754
+ except Exception as e:
755
+ raise gr.Error(f"Error processing image: {str(e)}")
756
+
757
+ try:
758
+ # Remove background from main objects
759
+ orig_size = image.shape[:2]
760
+ objects_mask = remove_bg(image)
761
+ processed_size = objects_mask.shape[:2]
762
+
763
+ # Resize mask to match original image
764
+ objects_mask = cv2.resize(objects_mask, (image.shape[1], image.shape[0]))
765
+
766
+ # Remove paper area from mask to focus only on objects
767
+ objects_mask = exclude_paper_area(objects_mask, paper_contour)
768
+
769
+ # Validate single object
770
+ validate_single_object(objects_mask, paper_contour)
771
+
772
+ except (MultipleObjectsError, NoObjectDetectedError) as e:
773
+ return (
774
+ None, None, None, None,
775
+ f"Error: {str(e)}"
776
+ )
777
+ except Exception as e:
778
+ raise gr.Error(f"Error in object detection: {str(e)}")
779
+
780
+ # Apply edge rounding if specified
781
+ if edge_radius > 0:
782
+ rounded_mask = round_edges(objects_mask, edge_radius, scaling_factor)
783
+ else:
784
+ rounded_mask = objects_mask.copy()
785
+
786
+ # Apply dilation for offset
787
+ if offset > 0:
788
+ offset_pixels = (float(offset) / scaling_factor) * 2 + 1 if scaling_factor else 1
789
+ kernel = np.ones((int(offset_pixels), int(offset_pixels)), np.uint8)
790
+ dilated_mask = cv2.dilate(rounded_mask, kernel)
791
+ else:
792
+ dilated_mask = rounded_mask.copy()
793
+
794
+ # Save original dilated mask for output
795
+ Image.fromarray(dilated_mask).save("./outputs/scaled_mask_original.jpg")
796
+ dilated_mask_orig = dilated_mask.copy()
797
+
798
+ # Extract contours
799
+ outlines, contours = extract_outlines(dilated_mask)
800
+
801
+ try:
802
+ # Generate DXF
803
+ dxf, finger_polygons, original_polygons = save_dxf_spline(
804
+ contours,
805
+ scaling_factor,
806
+ processed_size[0],
807
+ finger_clearance=(finger_clearance == "On")
808
+ )
809
+ except FingerCutOverlapError as e:
810
+ raise gr.Error(str(e))
811
+
812
+ # Create annotated image
813
+ shrunked_img_contours = image.copy()
814
+
815
+ if finger_clearance == "On":
816
+ outlines = np.full_like(dilated_mask, 255)
817
+ for poly in finger_polygons:
818
+ try:
819
+ coords = np.array([
820
+ (int(x / scaling_factor), int(processed_size[0] - y / scaling_factor))
821
+ for x, y in poly.exterior.coords
822
+ ], np.int32).reshape((-1, 1, 2))
823
+
824
+ cv2.drawContours(shrunked_img_contours, [coords], -1, (0, 255, 0), thickness=2)
825
+ cv2.drawContours(outlines, [coords], -1, 0, thickness=2)
826
+ except Exception as e:
827
+ logger.warning(f"Failed to draw finger cut: {e}")
828
+ continue
829
+ else:
830
+ outlines = np.full_like(dilated_mask, 255)
831
+ cv2.drawContours(shrunked_img_contours, contours, -1, (0, 255, 0), thickness=2)
832
+ cv2.drawContours(outlines, contours, -1, 0, thickness=2)
833
+
834
+ # Draw paper bounds on annotated image
835
+ cv2.drawContours(shrunked_img_contours, [paper_contour], -1, (255, 0, 0), thickness=3)
836
+
837
+ # Add paper size text
838
+ paper_text = f"Paper: {paper_size}"
839
+ cv2.putText(shrunked_img_contours, paper_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
840
+
841
+ cleanup_models()
842
+
843
+ return (
844
+ shrunked_img_contours,
845
+ outlines,
846
+ dxf,
847
+ dilated_mask_orig,
848
+ f"Scale: {scaling_factor:.4f} mm/px | Paper: {paper_size}"
849
+ )
850
+
851
+ def predict_full_paper(image, paper_size, enable_fillet, fillet_value_mm, enable_finger_cut, selected_outputs):
852
+ """
853
+ Full prediction function with paper reference and flexible outputs
854
+ Returns DXF + conditionally selected additional outputs
855
+ """
856
+ radius = fillet_value_mm if enable_fillet == "On" else 0
857
+ finger_flag = "On" if enable_finger_cut == "On" else "Off"
858
+
859
+ # Always get all outputs from predict_with_paper
860
+ ann, outlines, dxf_path, mask, scale_info = predict_with_paper(
861
+ image,
862
+ paper_size,
863
+ offset=0, # No offset for now, can be added as parameter later
864
+ offset_unit="mm",
865
+ edge_radius=radius,
866
+ finger_clearance=finger_flag,
867
+ )
868
+
869
+ # Return based on selected outputs
870
+ return (
871
+ dxf_path, # Always return DXF
872
+ ann if "Annotated Image" in selected_outputs else None,
873
+ outlines if "Outlines" in selected_outputs else None,
874
+ mask if "Mask" in selected_outputs else None,
875
+ scale_info # Always return scaling info
876
+ )
877
+
878
+ # Gradio Interface
879
+ if __name__ == "__main__":
880
+ os.makedirs("./outputs", exist_ok=True)
881
+
882
+ with gr.Blocks(title="Paper-Based DXF Generator", theme=gr.themes.Soft()) as demo:
883
+ gr.Markdown("""
884
+ # Paper-Based DXF Generator
885
+
886
+ Upload an image with a single object placed on paper (A4, A3, or US Letter).
887
+ The paper serves as a size reference for accurate DXF generation.
888
+
889
+ **Instructions:**
890
+ 1. Place a single object on paper
891
+ 2. Select the correct paper size
892
+ 3. Configure options as needed
893
+ 4. Click Submit to generate DXF
894
+ """)
895
+
896
+ with gr.Row():
897
+ with gr.Column():
898
+ input_image = gr.Image(
899
+ label="Input Image (Object on Paper)",
900
+ type="numpy",
901
+ height=400
902
+ )
903
+
904
+ paper_size = gr.Radio(
905
+ choices=["A4", "A3", "US Letter"],
906
+ value="A4",
907
+ label="Paper Size",
908
+ info="Select the paper size used in your image"
909
+ )
910
+
911
+ with gr.Group():
912
+ gr.Markdown("### Edge Rounding")
913
+ enable_fillet = gr.Radio(
914
+ choices=["On", "Off"],
915
+ value="Off",
916
+ label="Enable Edge Rounding",
917
+ interactive=True
918
+ )
919
+
920
+ fillet_value_mm = gr.Slider(
921
+ minimum=0,
922
+ maximum=20,
923
+ step=1,
924
+ value=5,
925
+ label="Edge Radius (mm)",
926
+ visible=False,
927
+ interactive=True
928
+ )
929
+
930
+ with gr.Group():
931
+ gr.Markdown("### Finger Cuts")
932
+ enable_finger_cut = gr.Radio(
933
+ choices=["On", "Off"],
934
+ value="Off",
935
+ label="Enable Finger Cuts",
936
+ info="Add circular cuts for easier handling"
937
+ )
938
+
939
+ output_options = gr.CheckboxGroup(
940
+ choices=["Annotated Image", "Outlines", "Mask"],
941
+ value=[],
942
+ label="Additional Outputs",
943
+ info="DXF is always included"
944
+ )
945
+
946
+ submit_btn = gr.Button("Generate DXF", variant="primary", size="lg")
947
+
948
+ with gr.Column():
949
+ with gr.Group():
950
+ gr.Markdown("### Generated Files")
951
+ dxf_file = gr.File(label="DXF File", file_types=[".dxf"])
952
+ scale_info = gr.Textbox(label="Scaling Information", interactive=False)
953
+
954
+ with gr.Group():
955
+ gr.Markdown("### Preview Images")
956
+ output_image = gr.Image(label="Annotated Image", visible=False)
957
+ outlines_image = gr.Image(label="Outlines", visible=False)
958
+ mask_image = gr.Image(label="Mask", visible=False)
959
+
960
+ # Dynamic visibility updates
961
+ def toggle_fillet(choice):
962
+ return gr.update(visible=(choice == "On"))
963
+
964
+ def update_outputs_visibility(selected):
965
+ return [
966
+ gr.update(visible="Annotated Image" in selected),
967
+ gr.update(visible="Outlines" in selected),
968
+ gr.update(visible="Mask" in selected)
969
+ ]
970
+
971
+ # Event handlers
972
+ enable_fillet.change(
973
+ fn=toggle_fillet,
974
+ inputs=enable_fillet,
975
+ outputs=fillet_value_mm
976
+ )
977
+
978
+ output_options.change(
979
+ fn=update_outputs_visibility,
980
+ inputs=output_options,
981
+ outputs=[output_image, outlines_image, mask_image]
982
+ )
983
+
984
+ submit_btn.click(
985
+ fn=predict_full_paper,
986
+ inputs=[
987
+ input_image,
988
+ paper_size,
989
+ enable_fillet,
990
+ fillet_value_mm,
991
+ enable_finger_cut,
992
+ output_options
993
+ ],
994
+ outputs=[dxf_file, output_image, outlines_image, mask_image, scale_info]
995
+ )
996
+
997
+ # Example gallery
998
+ with gr.Row():
999
+ gr.Markdown("""
1000
+ ### Tips for Best Results:
1001
+ - Ensure good lighting and clear paper edges
1002
+ - Place object completely on the paper
1003
+ - Avoid shadows that might interfere with detection
1004
+ - Use high contrast between object and paper
1005
+ """)
1006
+
1007
+ demo.launch(share=True)