Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| from typing import List, Union, Tuple | |
| from PIL import Image | |
| import ezdxf.units | |
| import numpy as np | |
| import torch | |
| from torchvision import transforms | |
| from ultralytics import YOLOWorld, YOLO | |
| from ultralytics.engine.results import Results | |
| from ultralytics.utils.plotting import save_one_box | |
| from transformers import AutoModelForImageSegmentation | |
| import cv2 | |
| import ezdxf | |
| import gradio as gr | |
| import gc | |
| from scalingtestupdated import calculate_scaling_factor | |
| from scipy.interpolate import splprep, splev | |
| from scipy.ndimage import gaussian_filter1d | |
| import json | |
| import time | |
| import signal | |
| from shapely.ops import unary_union | |
| from shapely.geometry import MultiPolygon, GeometryCollection, Polygon, Point | |
| from u2netp import U2NETP | |
| import logging | |
| import shutil | |
# Module-wide logging: INFO level by default and one named logger for this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory next to this file where model weights are cached.
CACHE_DIR = os.path.join(os.path.dirname(__file__), ".cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Supported reference sheets in millimetres, listed in portrait orientation
# (width = short edge, height = long edge).
PAPER_SIZES = dict(
    [
        ("A4", {"width": 210, "height": 297}),
        ("A3", {"width": 297, "height": 420}),
        ("US Letter", {"width": 215.9, "height": 279.4}),
    ]
)
# Custom Exception Classes
class TimeoutReachedError(Exception):
    """Timeout error type for the pipeline."""


class BoundaryOverlapError(Exception):
    """Boundary-overlap error type for the pipeline."""


class TextOverlapError(Exception):
    """Text-overlap error type for the pipeline."""


class PaperNotDetectedError(Exception):
    """The reference paper sheet could not be located in the image."""


class MultipleObjectsError(Exception):
    """More than one significant object was found on the paper."""

    def __init__(self, message="Multiple objects detected. Please place only a single object on the paper."):
        Exception.__init__(self, message)


class NoObjectDetectedError(Exception):
    """No significant object was found on the paper."""

    def __init__(self, message="No object detected on the paper. Please ensure an object is placed on the paper."):
        Exception.__init__(self, message)


class FingerCutOverlapError(Exception):
    """A finger cut would overlap existing geometry."""

    def __init__(self, message="There was an overlap with fingercuts... Please try again to generate dxf."):
        Exception.__init__(self, message)
# Lazily-loaded model singletons; None means "not loaded yet".
paper_detector_global = u2net_global = birefnet = None

# Expected weight-file locations inside the cache directory.
# The paper detector weights must be trained/provided separately.
paper_model_path = os.path.join(CACHE_DIR, "paper_detector.pt")
u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")

# Inference device: everything runs on the CPU.
device = "cpu"
# "high" (instead of "highest") lets PyTorch pick faster float32 matmul algorithms.
torch.set_float32_matmul_precision("high")
def ensure_model_files():
    """Make sure the model weight files are present in CACHE_DIR.

    For each weight file whose cached copy is missing, a same-named file in
    the current working directory is copied into the cache.  A missing paper
    detector only logs a warning (the contour-based fallback is used later);
    a missing U2NETP checkpoint is fatal.

    Raises:
        FileNotFoundError: if ``u2netp.pth`` exists in neither location.
    """
    # (cached path, local file name, is the file mandatory?)
    required_files = (
        (paper_model_path, "paper_detector.pt", False),
        (u2net_model_path, "u2netp.pth", True),
    )
    for cached_path, local_name, mandatory in required_files:
        if os.path.exists(cached_path):
            continue
        if os.path.exists(local_name):
            shutil.copy(local_name, cached_path)
        elif mandatory:
            raise FileNotFoundError(f"{local_name} model file not found")
        else:
            logger.warning(f"{local_name} model file not found - using fallback detection")


# Runs at import time, so a missing U2NETP checkpoint fails fast.
ensure_model_files()
# Lazy loading functions
def get_paper_detector():
    """Return the YOLO paper detector, loading it on first use.

    Returns:
        The cached ``YOLO`` model, or ``None`` when ``paper_model_path`` does
        not exist (callers then fall back to contour-based detection).
        ``None`` is not cached, so the file check repeats on every call until
        the weights appear.
    """
    global paper_detector_global
    if paper_detector_global is None:
        logger.info("Loading paper detector model...")
        if os.path.exists(paper_model_path):
            paper_detector_global = YOLO(paper_model_path)
            # Only report success when a model was actually loaded.
            logger.info("Paper detector loaded successfully")
        else:
            # No trained weights: callers use contour-based paper detection.
            logger.warning("Using fallback paper detection")
    return paper_detector_global
def get_u2net():
    """Return the U2NETP saliency model, building it on the first call.

    The network (3 input channels, 1 output channel) is created, its weights
    are loaded from ``u2net_model_path`` (mapped to CPU), and it is moved to
    ``device`` and put into eval mode.  Only after all of that succeeds is
    it stored in the module-level ``u2net_global``.  Previously the global
    was assigned first, so a failed weight load left an untrained model
    cached and silently reused.
    """
    global u2net_global
    if u2net_global is None:
        logger.info("Loading U2NETP model...")
        model = U2NETP(3, 1)
        model.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
        model.to(device)
        model.eval()
        u2net_global = model
        logger.info("U2NETP model loaded successfully")
    return u2net_global
def load_birefnet_model():
    """Load the BiRefNet segmentation model via transformers' ``from_pretrained``.

    NOTE(review): ``trust_remote_code=True`` executes model code published in
    the ``ZhengPeng7/BiRefNet`` repository; consider pinning a revision.
    """
    repo_id = 'ZhengPeng7/BiRefNet'
    model = AutoModelForImageSegmentation.from_pretrained(repo_id, trust_remote_code=True)
    return model
def get_birefnet():
    """Return the BiRefNet model, loading it on the first call.

    The model is moved to ``device`` and switched to eval mode *before* it is
    cached in the module-level ``birefnet``.  A failure part-way through
    loading therefore leaves nothing half-initialised in the cache.
    """
    global birefnet
    if birefnet is None:
        logger.info("Loading BiRefNet model...")
        model = load_birefnet_model()
        model.to(device)
        model.eval()
        birefnet = model
        logger.info("BiRefNet model loaded successfully")
    return birefnet
def detect_paper_contour(image: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Detect the paper sheet with classical contour analysis (model-free fallback).

    Steps:
      1. Convert to grayscale and blur.
      2. Run Canny edge detection.
      3. Among the external contours, keep the largest one that covers more
         than 10% of the image, simplifies to a polygon with at least 4
         vertices, and has a paper-like bounding box (long/short side ratio
         below 1.8, in either orientation).

    Args:
        image: RGB/RGBA colour image or single-channel image array.

    Returns:
        ``(contour, 0.0)``: the OpenCV contour of the paper and a placeholder
        scaling factor.  Callers must compute the real scale themselves,
        e.g. with ``calculate_paper_scaling_factor``.

    Raises:
        PaperNotDetectedError: if no contour passes the filters.
    """
    # Images in this app arrive as RGB arrays (gr.Image numpy / PIL), so use
    # the RGB luminance weights rather than the BGR ones.
    if image.ndim != 3:
        gray = image
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blurred, 50, 150)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    min_area = (image.shape[0] * image.shape[1]) * 0.1  # at least 10% of the image
    paper_contours = []
    for contour in contours:
        area = cv2.contourArea(contour)
        if area <= min_area:
            continue
        # Approximate the contour with a polygon; paper should have >= 4 corners.
        epsilon = 0.02 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        if len(approx) < 4:
            continue
        _, _, width, height = cv2.boundingRect(approx)
        if width == 0 or height == 0:
            continue
        # Orientation-independent ratio (A4/A3 ~1.414, US Letter ~1.294).
        # The old width/height ratio nearly rejected portrait A4 (0.707 vs 0.7).
        aspect_ratio = max(width, height) / min(width, height)
        if aspect_ratio < 1.8:
            paper_contours.append((contour, area, aspect_ratio))

    if not paper_contours:
        raise PaperNotDetectedError("Could not detect paper in the image")

    # Largest paper-like contour wins (first one on ties).
    best_contour = max(paper_contours, key=lambda item: item[1])[0]
    return best_contour, 0.0
def _largest_box_contour(boxes) -> np.ndarray:
    """Convert the largest-area xyxy box into a 4-point int32 OpenCV contour."""
    largest_box = None
    max_area = 0
    for box in boxes:
        x_min, y_min, x_max, y_max = (float(v) for v in box)
        area = (x_max - x_min) * (y_max - y_min)
        if area > max_area:
            max_area = area
            largest_box = (x_min, y_min, x_max, y_max)
    if largest_box is None:
        raise PaperNotDetectedError("No paper detected by model")
    x_min, y_min, x_max, y_max = map(int, largest_box)
    # OpenCV geometry/drawing functions (boundingRect, fillPoly, drawContours)
    # require int32 points; a default int64 array is rejected.
    return np.array(
        [
            [[x_min, y_min]],
            [[x_max, y_min]],
            [[x_max, y_max]],
            [[x_min, y_max]],
        ],
        dtype=np.int32,
    )


def detect_paper_bounds(image: np.ndarray, paper_size: str) -> Tuple[np.ndarray, float]:
    """
    Locate the paper sheet and derive the mm-per-pixel scaling factor.

    The YOLO paper detector is used when its weights are available.
    ``detect_paper_contour`` is used otherwise, or when the model finds no
    box.  Whichever path produced the contour, the scale is always computed
    from it with ``calculate_paper_scaling_factor``.  Previously the
    "model found nothing" path returned a 0.0 placeholder scale.

    Args:
        image: input image array.
        paper_size: key of ``PAPER_SIZES``.

    Returns:
        ``(paper_contour, scaling_factor)``: an int32 OpenCV contour and the
        scale in mm/px.

    Raises:
        PaperNotDetectedError: on any detection failure; other errors are
        wrapped and chained.
    """
    try:
        paper_contour = None
        paper_detector = get_paper_detector()
        if paper_detector is not None:
            results = paper_detector.predict(image, conf=0.5)
            if not results or len(results[0].boxes) == 0:
                logger.warning("Model detection failed, using fallback contour detection")
            else:
                paper_contour = _largest_box_contour(results[0].cpu().boxes.xyxy)
        if paper_contour is None:
            paper_contour, _ = detect_paper_contour(image)

        scaling_factor = calculate_paper_scaling_factor(paper_contour, paper_size)
        return paper_contour, scaling_factor
    except PaperNotDetectedError as e:
        # Already the right type: re-raise without double-wrapping the message.
        logger.error(f"Error in paper detection: {e}")
        raise
    except Exception as e:
        logger.error(f"Error in paper detection: {e}")
        raise PaperNotDetectedError(f"Failed to detect paper: {str(e)}") from e
def calculate_paper_scaling_factor(paper_contour: np.ndarray, paper_size: str) -> float:
    """
    Compute the mm-per-pixel scale from the detected paper outline.

    The axis-aligned bounding box of ``paper_contour`` is matched against the
    nominal sheet size in ``PAPER_SIZES``.  The sheet may appear in portrait
    or landscape, so the shorter detected side is paired with the shorter
    paper edge and the longer side with the longer edge.  The two per-axis
    scales are then averaged.

    Args:
        paper_contour: OpenCV contour (int32 or float32 points).
        paper_size: key of ``PAPER_SIZES``.

    Returns:
        Scaling factor in mm/px.

    Raises:
        KeyError: if ``paper_size`` is not a key of ``PAPER_SIZES``.
        ValueError: if the contour's bounding box has zero width or height.
    """
    paper_dims = PAPER_SIZES[paper_size]
    short_mm, long_mm = sorted((paper_dims["width"], paper_dims["height"]))

    _, _, detected_width_px, detected_height_px = cv2.boundingRect(paper_contour)
    if detected_width_px <= 0 or detected_height_px <= 0:
        raise ValueError(
            f"Degenerate paper contour: {detected_width_px}x{detected_height_px} px"
        )
    short_px, long_px = sorted((detected_width_px, detected_height_px))

    # Orientation-matched per-axis scales, averaged.
    scaling_factor = (short_mm / short_px + long_mm / long_px) / 2

    logger.info(
        f"Paper detection: {detected_width_px}x{detected_height_px} px -> "
        f"{short_mm}x{long_mm} mm ({paper_size}, short/long sides matched)"
    )
    logger.info(f"Calculated scaling factor: {scaling_factor:.4f} mm/px")
    return scaling_factor
def validate_single_object(mask: np.ndarray, paper_contour: np.ndarray) -> None:
    """
    Ensure exactly one significant blob of ``mask`` lies inside the paper.

    The mask is clipped to the filled paper polygon.  External contours of
    the clipped mask (any non-zero pixel counts as foreground) with area
    above 1000 px are counted as objects.

    Raises:
        NoObjectDetectedError: if no such contour exists.
        MultipleObjectsError: if more than one exists.
    """
    inside_paper = np.zeros(mask.shape[:2], dtype=np.uint8)
    cv2.fillPoly(inside_paper, [paper_contour], 255)
    clipped = cv2.bitwise_and(mask, inside_paper)

    found, _ = cv2.findContours(clipped, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    noise_floor = 1000  # blobs at or below this area are treated as noise
    blobs = [blob for blob in found if cv2.contourArea(blob) > noise_floor]

    count = len(blobs)
    if count == 0:
        raise NoObjectDetectedError()
    if count > 1:
        raise MultipleObjectsError()
    logger.info(f"Single object validated: {count} significant contour(s) found")
def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
    """Predict a foreground saliency mask with U2NETP.

    Steps:
      1. Resize the image to 320x320 and apply ImageNet normalisation.
      2. Run the network and take its first output.
      3. Min-max normalise that output to [0, 1] and resize it back to the
         input's width/height.

    Returns:
        uint8 array with values in [0, 255].

    Errors are logged and re-raised.
    """
    try:
        model = get_u2net()
        pil_img = Image.fromarray(image)
        preprocess = transforms.Compose(
            [
                transforms.Resize((320, 320)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        batch = preprocess(pil_img).unsqueeze(0).to(device)

        with torch.no_grad():
            saliency = model(batch)[0]

        lowest, highest = saliency.min(), saliency.max()
        saliency = (saliency - lowest) / (highest - lowest + 1e-8)
        as_array = saliency.squeeze().cpu().numpy()
        resized = cv2.resize(as_array, (pil_img.width, pil_img.height))
        return (resized * 255).astype(np.uint8)
    except Exception as e:
        logger.error(f"Error in U2NETP background removal: {e}")
        raise
def remove_bg(image: np.ndarray) -> np.ndarray:
    """Predict a soft foreground mask with BiRefNet.

    Steps:
      1. Run the model on a 1024x1024, ImageNet-normalised copy of ``image``.
      2. Take the sigmoid of its last output and convert it to a PIL image.
      3. Resize so the longer side is 1024 px while keeping the input's
         aspect ratio.

    The returned array is therefore generally NOT the input's size;
    ``predict_with_paper`` resizes it back.  Errors are logged and re-raised.
    """
    try:
        model = get_birefnet()
        preprocess = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        pil_img = Image.fromarray(image)
        batch = preprocess(pil_img).unsqueeze(0).to(device)

        with torch.no_grad():
            probabilities = model(batch)[-1].sigmoid().cpu()

        mask_pil = transforms.ToPILImage()(probabilities[0].squeeze())
        width, height = pil_img.size
        ratio = 1024 / max(pil_img.size)
        target_size = (int(width * ratio), int(height * ratio))
        return np.array(mask_pil.resize(target_size))
    except Exception as e:
        logger.error(f"Error in BiRefNet background removal: {e}")
        raise
def exclude_paper_area(mask: np.ndarray, paper_contour: np.ndarray, expansion_factor: float = 1.1) -> np.ndarray:
    """
    Remove paper area from the mask to focus only on objects

    A filled polygon is built from a simplified version of ``paper_contour``.
    Those pixels are zeroed in ``mask`` by AND-ing with the inverted polygon
    mask.

    NOTE(review): ``expansion_factor`` is used as the ``approxPolyDP`` epsilon
    relative to the contour perimeter (1.1 -> 110% of the perimeter).
    ``approxPolyDP`` simplifies a curve; it does not expand it.  With an
    epsilon larger than the whole perimeter, the contour presumably collapses
    to a degenerate polygon, so very little (possibly only a line) is
    removed.  That may be what keeps the object pixels alive for the later
    ``validate_single_object`` check.  Confirm intent before "fixing" this.

    Args:
        mask: uint8 single-channel mask with the same height/width as the
            image (it is AND-ed with a uint8 mask of that shape).
        paper_contour: OpenCV contour of the detected paper.
        expansion_factor: multiplier on the contour perimeter that is used as
            the approxPolyDP epsilon.

    Returns:
        ``mask`` with the pixels covered by the simplified paper polygon set
        to 0.
    """
    # Blank mask that will hold the (simplified) paper polygon
    paper_mask = np.zeros(mask.shape[:2], dtype=np.uint8)
    # Simplify the paper contour (approxPolyDP does not expand; see NOTE above)
    epsilon = expansion_factor * cv2.arcLength(paper_contour, True)
    expanded_contour = cv2.approxPolyDP(paper_contour, epsilon, True)
    cv2.fillPoly(paper_mask, [expanded_contour], 255)
    # Invert paper mask and apply to object mask
    paper_mask_inv = cv2.bitwise_not(paper_mask)
    result_mask = cv2.bitwise_and(mask, paper_mask_inv)
    return result_mask
def resample_contour(contour, edge_radius_px: int = 0):
    """
    Resample a closed contour with a periodic spline and smooth it.

    Steps:
      1. Fit an interpolating periodic B-spline to the contour.
      2. Evaluate it at 1499 evenly spaced parameter values over the
         half-open interval.
      3. Smooth with a wrap-around Gaussian filter.
      4. Close the curve by repeating the first sample (1500 rows total).

    Evaluating the half-open interval means the 'wrap' filter no longer sees
    a duplicated seam point.

    Args:
        contour: OpenCV contour of shape ``(N, 1, 2)`` or a plain ``(N, 2)``
            point array; N must be at least 4.
        edge_radius_px: edge radius in pixels.  It only controls smoothing
            strength: ``sigma = max(2, edge_radius_px // 4)``.

    Returns:
        ``(1500, 2)`` float array whose last row equals its first.

    Raises:
        ValueError: if the contour has fewer than 4 points, or fewer than 3
            distinct points after removing consecutive duplicates.
        Errors from the spline fit are logged and re-raised.
    """
    logger.info(f"Starting resample_contour with contour of shape {contour.shape}")
    num_points = 1500
    sigma = max(2, int(edge_radius_px) // 4)
    if len(contour) < 4:
        error_msg = f"Contour must have at least 4 points, but has {len(contour)} points."
        logger.error(error_msg)
        raise ValueError(error_msg)
    try:
        # Accept both OpenCV (N, 1, 2) contours and plain (N, 2) arrays.
        points = np.asarray(contour).reshape(-1, 2)
        logger.debug(f"Reshaped contour to shape {points.shape}")

        # Drop consecutive duplicates: zero-length steps break the chord-length
        # parametrisation of the interpolating spline.
        is_new = np.ones(len(points), dtype=bool)
        is_new[1:] = np.any(points[1:] != points[:-1], axis=1)
        points = points[is_new]

        # Close the loop: with per=True, splprep expects last point == first.
        if not np.array_equal(points[0], points[-1]):
            points = np.vstack([points, points[0]])
        # splprep needs more data points than the spline degree (3).
        if len(points) < 4:
            raise ValueError(
                "Contour must have at least 3 distinct points after removing "
                f"duplicates, but has {len(points) - 1}."
            )

        tck, u = splprep(points.T, u=None, s=0, per=True)
        # Half-open sampling: each location on the closed curve appears once.
        u_new = np.linspace(u.min(), u.max(), num_points - 1, endpoint=False)
        x_new, y_new = splev(u_new, tck, der=0)
        x_new = gaussian_filter1d(x_new, sigma=sigma, mode='wrap')
        y_new = gaussian_filter1d(y_new, sigma=sigma, mode='wrap')

        # Re-close the curve explicitly -> num_points rows.
        result = np.column_stack(
            (np.append(x_new, x_new[0]), np.append(y_new, y_new[0]))
        )
        logger.info(f"Completed resample_contour with result shape {result.shape}")
        return result
    except Exception as e:
        logger.error(f"Error in resample_contour: {e}")
        raise
def save_dxf_spline(inflated_contours, scaling_factor, height, finger_clearance=False):
    """
    Write contours to ``./outputs/out.dxf`` as degree-3 splines on layer TOOLS.

    Each contour goes through these steps:
      1. Resampled/smoothed with ``resample_contour``.
      2. Converted from pixels to drawing units (``scaling_factor`` per
         pixel, y axis flipped about ``height``).
      3. Optionally merged with a finger-cut circle.
      4. Reduced with ``polygon_to_exterior_coords``.
      5. Multiplied by a fixed scale correction and added to modelspace.

    Contours that raise ``ValueError``, or that end up with fewer than
    3 points, are skipped.  The output directory is created if missing.

    Args:
        inflated_contours: iterable of OpenCV contours in image pixels.
        scaling_factor: drawing units per pixel (mm/px from paper detection).
        height: image height in pixels used for the y flip.
        finger_clearance: truthy to add finger cuts.

    Returns:
        ``(dxf_filepath, final_polygons, original_polygons)``: the written
        path, the (possibly finger-cut) shapely polygons that were emitted,
        and the polygons before finger cuts.  The returned polygons are NOT
        multiplied by the scale correction.
    """
    doc = ezdxf.new(units=ezdxf.units.MM)
    doc.header["$INSUNITS"] = ezdxf.units.MM
    msp = doc.modelspace()

    final_polygons = []
    finger_centers = []
    original_polygons = []
    # NOTE(review): empirical correction applied only to the DXF geometry.
    # With a paper-derived mm/px scale it enlarges parts by 7.9%; confirm it
    # is still wanted.
    scale_correction = 1.079

    for contour in inflated_contours:
        try:
            resampled_contour = resample_contour(contour)
            # Pixel -> drawing units; image y grows downwards, DXF y upwards.
            points = [
                (x * scaling_factor, (height - y) * scaling_factor)
                for x, y in resampled_contour
            ]
            if len(points) < 3:
                continue

            tool_polygon = build_tool_polygon(points)
            original_polygons.append(tool_polygon)

            if finger_clearance:
                try:
                    tool_polygon, _center = place_finger_cut_adjusted(
                        tool_polygon, points, finger_centers, final_polygons
                    )
                except FingerCutOverlapError:
                    tool_polygon = original_polygons[-1]

            exterior_coords = polygon_to_exterior_coords(tool_polygon)
            if len(exterior_coords) < 3:
                continue

            corrected_coords = [
                (x * scale_correction, y * scale_correction) for x, y in exterior_coords
            ]
            msp.add_spline(corrected_coords, degree=3, dxfattribs={"layer": "TOOLS"})
            final_polygons.append(tool_polygon)
        except ValueError as e:
            logger.warning(f"Skipping contour: {e}")

    # ./outputs is otherwise only created by the __main__ entry point.
    output_dir = "./outputs"
    os.makedirs(output_dir, exist_ok=True)
    dxf_filepath = os.path.join(output_dir, "out.dxf")
    doc.saveas(dxf_filepath)
    return dxf_filepath, final_polygons, original_polygons
def build_tool_polygon(points_inch):
    """Wrap the converted outline points in a shapely ``Polygon`` and return it."""
    polygon = Polygon(points_inch)
    return polygon
def polygon_to_exterior_coords(poly, max_points: int = 100):
    """
    Return the exterior ring of ``poly`` as a list of coordinate tuples.

    Multi-part inputs (``MultiPolygon`` / ``GeometryCollection``) are merged
    with ``unary_union``.  If the result still has several parts, the
    largest-area part that has an exterior ring is used.

    Rings longer than ``max_points`` are subsampled with a uniform stride
    chosen so that at most ``max_points`` samples are kept.  The ring's
    final (closing) coordinate is appended if the stride skipped it.

    Args:
        poly: shapely geometry.
        max_points: sample budget for the returned ring (default 100).

    Returns:
        List of coordinate tuples.  It is empty (after logging) for empty or
        ring-less geometry, or on any error.
    """
    logger.info(f"Starting polygon_to_exterior_coords with input geometry type: {poly.geom_type}")
    try:
        if poly.geom_type in ("GeometryCollection", "MultiPolygon"):
            logger.debug(f"Performing unary_union on {poly.geom_type}")
            unified = unary_union(poly)
            if unified.is_empty:
                logger.warning("unary_union produced an empty geometry; returning empty list")
                return []
            if unified.geom_type in ("GeometryCollection", "MultiPolygon"):
                largest = None
                max_area = 0.0
                for g in getattr(unified, "geoms", []):
                    if hasattr(g, "area") and g.area > max_area and hasattr(g, "exterior"):
                        max_area = g.area
                        largest = g
                if largest is None:
                    logger.warning("No valid Polygon found in unified geometry; returning empty list")
                    return []
                poly = largest
            else:
                poly = unified

        if not hasattr(poly, "exterior") or poly.exterior is None:
            logger.warning("Input geometry has no exterior ring; returning empty list")
            return []

        raw_coords = list(poly.exterior.coords)
        total = len(raw_coords)
        logger.info(f"Extracted {total} raw exterior coordinates")
        if total == 0:
            return []

        limit = max(1, int(max_points))
        if total <= limit:
            return raw_coords

        # Ceiling division keeps at most `limit` samples.  Floor division gave
        # step == 1 (no downsampling at all) for anything below 2 * limit points.
        step = -(-total // limit)
        sampled = raw_coords[::step]
        if sampled[-1] != raw_coords[-1]:
            sampled.append(raw_coords[-1])
        logger.info(f"Downsampled perimeter from {total} to {len(sampled)} points")
        return sampled
    except Exception as e:
        logger.error(f"Error in polygon_to_exterior_coords: {e}")
        return []
def place_finger_cut_adjusted(
    tool_polygon: Polygon,
    points_inch: list,
    existing_centers: list,
    all_polygons: list,
    circle_diameter: float = 25.4,
    min_gap: float = 0.5,
    max_attempts: int = 100
) -> Tuple[Polygon, tuple]:
    """
    Place a finger-cut circle on ``tool_polygon`` with collision avoidance.

    Candidate centres are the shuffled (and, above 100 points, subsampled)
    exterior points of ``tool_polygon``.  Each point is also tried with four
    offsets of +/- ``min_gap / 2``.  A candidate circle of ``circle_diameter``
    is accepted when all of the following hold:

    * its centre is at least ``circle_diameter + min_gap`` from every centre
      in ``existing_centers``;
    * at least 30% of the circle's area overlaps ``tool_polygon``;
    * it does not intersect any other polygon in ``all_polygons`` grown by
      ``min_gap``;
    * its union with ``tool_polygon`` is a single polygon that differs from
      ``tool_polygon``.

    The candidate set is fixed, so one scan is exhaustive.  The old code
    repeated the identical scan up to ``max_attempts`` times.  A fallback
    circle, centred on the middle point of ``points_inch`` with no checks,
    is used in any of these cases:

    * nothing qualifies;
    * the scan exceeds ~5 s;
    * the perimeter is empty;
    * an unexpected error occurs.

    Args:
        tool_polygon: polygon to cut into.
        points_inch: outline points; the fallback centre is taken from here.
        existing_centers: centres placed so far.  Mutated: the chosen centre
            is appended.
        all_polygons: previously emitted polygons to keep clear of.
        circle_diameter: finger-cut diameter, in the polygon's units.
        min_gap: clearance to other cuts and polygons.
        max_attempts: kept for backward compatibility.  Only ``<= 0`` is
            significant: it skips the search and goes straight to the
            fallback.

    Returns:
        ``(polygon_with_cut, centre)``.
    """
    logger.info(f"Starting place_finger_cut_adjusted with {len(points_inch)} input points")
    r = circle_diameter / 2.0
    needed_center_dist = circle_diameter + min_gap
    timeout_secs = 5.0

    def fallback_solution():
        """Unconditionally union a circle centred on the middle outline point."""
        logger.warning("Using fallback approach for finger cut placement")
        fallback_center = points_inch[len(points_inch) // 2]
        fallback_circle = Point(fallback_center).buffer(r, resolution=32)
        try:
            union_poly = tool_polygon.union(fallback_circle)
        except Exception as e:
            logger.warning(f"Fallback union failed ({e}); trying buffer-union fallback")
            union_poly = tool_polygon.buffer(0).union(fallback_circle.buffer(0))
        existing_centers.append(fallback_center)
        logger.info(f"Fallback finger cut placed at {fallback_center}")
        return union_poly, fallback_center

    raw_perimeter = polygon_to_exterior_coords(tool_polygon)
    if not raw_perimeter:
        logger.warning("No valid exterior coords found; using fallback immediately")
        return fallback_solution()

    if len(raw_perimeter) > 100:
        step = len(raw_perimeter) // 100
        perimeter_coords = raw_perimeter[::step]
        logger.info(f"Subsampled perimeter from {len(raw_perimeter)} to {len(perimeter_coords)} points")
    else:
        perimeter_coords = raw_perimeter[:]

    indices = list(range(len(perimeter_coords)))
    np.random.shuffle(indices)
    logger.debug("Shuffled perimeter indices for candidate order")

    if max_attempts <= 0:
        logger.warning(f"No valid spot after {max_attempts} attempts, using fallback")
        return fallback_solution()

    half_gap = min_gap / 2
    offsets = [(0, 0), (-half_gap, 0), (half_gap, 0), (0, -half_gap), (0, half_gap)]
    start_time = time.time()
    try:
        # Grow the neighbouring polygons once instead of once per candidate.
        # intersects() also covers touching, so no separate touches() test.
        obstacles = [
            other_poly.buffer(min_gap)
            for other_poly in all_polygons
            if not other_poly.equals(tool_polygon)
        ]

        for idx in indices:
            if time.time() - start_time > timeout_secs - 0.05:
                logger.warning("Timeout during candidate-point loop")
                return fallback_solution()

            cx, cy = perimeter_coords[idx]
            for dx, dy in offsets:
                candidate_center = (cx + dx, cy + dy)

                # Keep clear of previously placed finger cuts.
                if any(
                    np.hypot(candidate_center[0] - ex, candidate_center[1] - ey)
                    < needed_center_dist
                    for (ex, ey) in existing_centers
                ):
                    continue

                candidate_circle = Point(candidate_center).buffer(r, resolution=32)

                # Must overlap >= 30% with this polygon.
                try:
                    inter_area = tool_polygon.intersection(candidate_circle).area
                except Exception:
                    continue
                if inter_area < 0.3 * candidate_circle.area:
                    continue

                # Must not reach into other polygons.
                if any(obstacle.intersects(candidate_circle) for obstacle in obstacles):
                    continue

                # The union must stay one piece and actually change the shape.
                try:
                    union_poly = tool_polygon.union(candidate_circle)
                    if union_poly.geom_type == "MultiPolygon" and len(union_poly.geoms) > 1:
                        continue
                    if union_poly.equals(tool_polygon):
                        continue
                except Exception:
                    continue

                existing_centers.append(candidate_center)
                logger.info(f"Finger cut placed successfully at {candidate_center}")
                return union_poly, candidate_center

        logger.warning("No valid spot found on the perimeter, using fallback")
        return fallback_solution()
    except Exception as e:
        logger.error(f"Error in place_finger_cut_adjusted: {e}")
        return fallback_solution()
def extract_outlines(binary_image: np.ndarray) -> Tuple[np.ndarray, list]:
    """
    Find the external contours of a binary mask and render them.

    Args:
        binary_image: single-channel mask; any non-zero pixel is foreground.

    Returns:
        ``(outline_image, contours)``:
          * ``outline_image``: white image with the same shape and dtype as
            the input, with the contours drawn in black (2 px wide).
          * ``contours``: the unsimplified external contours from
            ``cv2.findContours``.
    """
    contours, _ = cv2.findContours(
        binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    outline_image = np.full_like(binary_image, 255)
    # Draw the outlines so the returned image is not blank (same style as the
    # previews rendered in predict_with_paper).
    cv2.drawContours(outline_image, contours, -1, 0, thickness=2)
    return outline_image, contours
def round_edges(mask: np.ndarray, radius_mm: float, scaling_factor: float) -> np.ndarray:
    """
    Smooth the outline of ``mask`` to approximate rounded edges.

    The radius is converted to pixels and passed to ``resample_contour``,
    where it only sets the Gaussian smoothing strength.  It is not a
    geometric fillet.

    Args:
        mask: single-channel uint8 mask.
        radius_mm: requested edge radius in mm.
        scaling_factor: mm per pixel.

    Returns:
        * The input ``mask`` unchanged if ``radius_mm`` or ``scaling_factor``
          is <= 0.
        * A 3x3 morphological opening of it if it has fewer than 500
          foreground pixels.
        * Otherwise a new mask with the smoothed external contours (area
          > 100 px) filled with 255.  Interior holes are not preserved.
    """
    if radius_mm <= 0 or scaling_factor <= 0:
        return mask
    radius_px = max(1, int(radius_mm / scaling_factor))

    if np.count_nonzero(mask) < 500:
        # Too little foreground for a spline fit: just remove specks.
        kernel = np.ones((3, 3), np.uint8)
        return cv2.dilate(cv2.erode(mask, kernel), kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    smoothed_contours = []
    for contour in contours:
        if cv2.contourArea(contour) <= 100:
            continue  # drop tiny fragments
        try:
            resampled = resample_contour(contour, radius_px)
            # Round to the nearest pixel; a plain astype() truncates and would
            # shift the whole outline by up to one pixel toward the origin.
            smoothed_contours.append(
                np.rint(resampled).astype(np.int32).reshape((-1, 1, 2))
            )
        except Exception as e:
            logger.warning(f"Error smoothing contour: {e}")
            smoothed_contours.append(contour)

    rounded = np.zeros_like(mask)
    cv2.drawContours(rounded, smoothed_contours, -1, 255, thickness=cv2.FILLED)
    return rounded
def cleanup_memory():
    """Release cached CUDA memory when a GPU is present, then force a GC pass."""
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("Memory cleanup completed")
def cleanup_models():
    """Drop every cached model (the get_* loaders reload on demand), then free memory."""
    global paper_detector_global, u2net_global, birefnet
    # Rebinding to None releases this module's reference to each model;
    # models that were never loaded simply stay None.
    paper_detector_global = None
    u2net_global = None
    birefnet = None
    cleanup_memory()
def make_square(img: np.ndarray):
    """Pad ``img`` with replicated edge pixels until height equals width.

    The padding is split evenly between the two sides.  When the difference
    is odd, the extra row/column goes on the bottom/right.  For 3-D
    (H, W, C) images the channel axis is left untouched.
    """
    rows, cols = img.shape[:2]
    side = max(rows, cols)
    top = (side - rows) // 2
    left = (side - cols) // 2
    padding = [(top, side - rows - top), (left, side - cols - left)]
    if len(img.shape) == 3:
        padding.append((0, 0))
    return np.pad(img, padding, mode="edge")
def predict_with_paper(image, paper_size, offset, offset_unit, edge_radius, finger_clearance=False):
    """
    Main prediction function using the paper sheet as size reference.

    Pipeline:
      1. Detect the paper and the mm/px scale.
      2. Segment the object with BiRefNet and check there is exactly one.
      3. Round its edges and dilate it by ``offset``.
      4. Export the outline(s) to DXF and render preview images.

    Args:
        image: RGB image array.
        paper_size: key of ``PAPER_SIZES``.
        offset: outward offset in ``offset_unit``; must be >= 0.
        offset_unit: "inches" converts ``offset`` to mm; anything else is
            treated as mm.
        edge_radius: edge radius in mm.  ``None``/0 are replaced by a tiny
            positive radius, so ``round_edges`` always runs.
        finger_clearance: the string "On" enables finger cuts.

    Returns:
        ``(annotated_image, outlines_image, dxf_path, mask, status_text)``.
        On paper/object detection errors the first four are ``None`` and the
        status text carries the error.

    Raises:
        gr.Error: for a missing image, a negative offset, or unexpected
            failures.
    """
    if image is None:
        raise gr.Error("Please upload an image")
    if offset_unit == "inches":
        offset *= 25.4
    if edge_radius is None or edge_radius == 0:
        edge_radius = 0.0001
    if offset < 0:
        raise gr.Error("Offset Value Can't be negative")

    try:
        # Detect paper bounds and calculate scaling factor
        paper_contour, scaling_factor = detect_paper_bounds(image, paper_size)
        logger.info(f"Paper detected with scaling factor: {scaling_factor:.4f} mm/px")
    except PaperNotDetectedError as e:
        return (
            None, None, None, None,
            f"Error: {str(e)}"
        )
    except Exception as e:
        raise gr.Error(f"Error processing image: {str(e)}") from e

    try:
        objects_mask = remove_bg(image)
        # BiRefNet returns a mask 1024 px on its long side; bring it back to
        # the image size so it lines up with the paper contour.
        objects_mask = cv2.resize(objects_mask, (image.shape[1], image.shape[0]))
        objects_mask = exclude_paper_area(objects_mask, paper_contour)
        validate_single_object(objects_mask, paper_contour)
    except (MultipleObjectsError, NoObjectDetectedError) as e:
        return (
            None, None, None, None,
            f"Error: {str(e)}"
        )
    except Exception as e:
        raise gr.Error(f"Error in object detection: {str(e)}") from e

    # All contours below live in original-image pixel coordinates.
    image_height = image.shape[0]

    # Apply edge rounding if specified
    if edge_radius > 0:
        rounded_mask = round_edges(objects_mask, edge_radius, scaling_factor)
    else:
        rounded_mask = objects_mask.copy()

    # Dilate by the requested offset (kernel side = 2 * offset_px + 1)
    if offset > 0:
        offset_pixels = (float(offset) / scaling_factor) * 2 + 1 if scaling_factor else 1
        kernel = np.ones((int(offset_pixels), int(offset_pixels)), np.uint8)
        dilated_mask = cv2.dilate(rounded_mask, kernel)
    else:
        dilated_mask = rounded_mask.copy()

    output_dir = "./outputs"
    os.makedirs(output_dir, exist_ok=True)
    Image.fromarray(dilated_mask).save(os.path.join(output_dir, "scaled_mask_original.jpg"))
    dilated_mask_orig = dilated_mask.copy()

    _, contours = extract_outlines(dilated_mask)
    use_finger_cuts = finger_clearance == "On"

    try:
        # The y flip must use the height of the image the contours come from
        # (previously the BiRefNet output height was used).
        dxf, finger_polygons, _original_polygons = save_dxf_spline(
            contours,
            scaling_factor,
            image_height,
            finger_clearance=use_finger_cuts
        )
    except FingerCutOverlapError as e:
        raise gr.Error(str(e)) from e

    # Previews: green object outline on the photo, black outline on white.
    annotated = image.copy()
    outlines = np.full_like(dilated_mask, 255)
    if use_finger_cuts:
        for poly in finger_polygons:
            try:
                # Inverse of the DXF mapping in save_dxf_spline (without the
                # DXF-only scale correction).
                coords = np.array([
                    (int(x / scaling_factor), int(image_height - y / scaling_factor))
                    for x, y in poly.exterior.coords
                ], np.int32).reshape((-1, 1, 2))
                cv2.drawContours(annotated, [coords], -1, (0, 255, 0), thickness=2)
                cv2.drawContours(outlines, [coords], -1, 0, thickness=2)
            except Exception as e:
                logger.warning(f"Failed to draw finger cut: {e}")
                continue
    else:
        cv2.drawContours(annotated, contours, -1, (0, 255, 0), thickness=2)
        cv2.drawContours(outlines, contours, -1, 0, thickness=2)

    # Draw paper bounds and label on the annotated image
    cv2.drawContours(annotated, [paper_contour], -1, (255, 0, 0), thickness=3)
    paper_text = f"Paper: {paper_size}"
    cv2.putText(annotated, paper_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    cleanup_models()
    return (
        annotated,
        outlines,
        dxf,
        dilated_mask_orig,
        f"Scale: {scaling_factor:.4f} mm/px | Paper: {paper_size}"
    )
def predict_full_paper(image, paper_size, enable_fillet, fillet_value_mm, enable_finger_cut, selected_outputs):
    """
    Full prediction function with paper reference and flexible outputs.

    Maps the UI controls onto ``predict_with_paper`` (offset fixed at 0 mm)
    and blanks the preview images that were not requested.

    Args:
        image: input image from the UI.
        paper_size: key of ``PAPER_SIZES``.
        enable_fillet: "On" to use ``fillet_value_mm`` as the edge radius.
        fillet_value_mm: edge radius in mm.
        enable_finger_cut: "On" to add finger cuts.
        selected_outputs: collection of labels among "Annotated Image",
            "Outlines" and "Mask".  ``None`` is treated as empty.

    Returns:
        ``(dxf_path, annotated, outlines, mask, scale_info)``.  The DXF path
        and scale text are always returned; each preview is ``None`` unless
        selected.
    """
    radius = fillet_value_mm if enable_fillet == "On" else 0
    finger_flag = "On" if enable_finger_cut == "On" else "Off"
    # Guard against a missing selection instead of failing on `in None`.
    selected = selected_outputs or []

    ann, outlines, dxf_path, mask, scale_info = predict_with_paper(
        image,
        paper_size,
        offset=0,  # No offset for now, can be added as parameter later
        offset_unit="mm",
        edge_radius=radius,
        finger_clearance=finger_flag,
    )

    return (
        dxf_path,  # Always return DXF
        ann if "Annotated Image" in selected else None,
        outlines if "Outlines" in selected else None,
        mask if "Mask" in selected else None,
        scale_info,  # Always return scaling info
    )
# Gradio Interface
# Builds the web UI and starts the server when this file is run as a script.
if __name__ == "__main__":
    # Directory for the mask preview and the DXF written during prediction.
    os.makedirs("./outputs", exist_ok=True)
    with gr.Blocks(title="Paper-Based DXF Generator", theme=gr.themes.Soft()) as demo:
        # NOTE(review): the instructions say "Submit" but the button below is
        # labelled "Generate DXF".
        gr.Markdown("""
# Paper-Based DXF Generator
Upload an image with a single object placed on paper (A4, A3, or US Letter).
The paper serves as a size reference for accurate DXF generation.
**Instructions:**
1. Place a single object on paper
2. Select the correct paper size
3. Configure options as needed
4. Click Submit to generate DXF
""")
        # Two-column layout: controls on the left, results on the right.
        # Inside gr.Blocks the order of component creation is the display order.
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image (Object on Paper)",
                    type="numpy",
                    height=400
                )
                paper_size = gr.Radio(
                    choices=["A4", "A3", "US Letter"],
                    value="A4",
                    label="Paper Size",
                    info="Select the paper size used in your image"
                )
                with gr.Group():
                    gr.Markdown("### Edge Rounding")
                    enable_fillet = gr.Radio(
                        choices=["On", "Off"],
                        value="Off",
                        label="Enable Edge Rounding",
                        interactive=True
                    )
                    # Hidden until enable_fillet is switched to "On" (toggle_fillet).
                    fillet_value_mm = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                        label="Edge Radius (mm)",
                        visible=False,
                        interactive=True
                    )
                with gr.Group():
                    gr.Markdown("### Finger Cuts")
                    enable_finger_cut = gr.Radio(
                        choices=["On", "Off"],
                        value="Off",
                        label="Enable Finger Cuts",
                        info="Add circular cuts for easier handling"
                    )
                # Which preview images to return; the DXF is always produced.
                output_options = gr.CheckboxGroup(
                    choices=["Annotated Image", "Outlines", "Mask"],
                    value=[],
                    label="Additional Outputs",
                    info="DXF is always included"
                )
                submit_btn = gr.Button("Generate DXF", variant="primary", size="lg")
            with gr.Column():
                with gr.Group():
                    gr.Markdown("### Generated Files")
                    dxf_file = gr.File(label="DXF File", file_types=[".dxf"])
                    scale_info = gr.Textbox(label="Scaling Information", interactive=False)
                with gr.Group():
                    gr.Markdown("### Preview Images")
                    # Hidden by default; shown according to output_options.
                    output_image = gr.Image(label="Annotated Image", visible=False)
                    outlines_image = gr.Image(label="Outlines", visible=False)
                    mask_image = gr.Image(label="Mask", visible=False)

        # Dynamic visibility updates
        def toggle_fillet(choice):
            """Show the edge-radius slider only while edge rounding is "On"."""
            return gr.update(visible=(choice == "On"))

        def update_outputs_visibility(selected):
            """Return visibility updates for the annotated, outlines and mask previews (in that order)."""
            return [
                gr.update(visible="Annotated Image" in selected),
                gr.update(visible="Outlines" in selected),
                gr.update(visible="Mask" in selected)
            ]

        # Event handlers
        enable_fillet.change(
            fn=toggle_fillet,
            inputs=enable_fillet,
            outputs=fillet_value_mm
        )
        output_options.change(
            fn=update_outputs_visibility,
            inputs=output_options,
            outputs=[output_image, outlines_image, mask_image]
        )
        # outputs must follow predict_full_paper's return order:
        # (dxf path, annotated, outlines, mask, scale info).
        submit_btn.click(
            fn=predict_full_paper,
            inputs=[
                input_image,
                paper_size,
                enable_fillet,
                fillet_value_mm,
                enable_finger_cut,
                output_options
            ],
            outputs=[dxf_file, output_image, outlines_image, mask_image, scale_info]
        )
        # Usage tips shown below the main layout (no example gallery is defined).
        with gr.Row():
            gr.Markdown("""
### Tips for Best Results:
- Ensure good lighting and clear paper edges
- Place object completely on the paper
- Avoid shadows that might interfere with detection
- Use high contrast between object and paper
""")
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)