Spaces:
Running
Running
from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
from pydantic import BaseModel | |
import numpy as np | |
from PIL import Image | |
import io, uuid, os, shutil, timeit | |
from datetime import datetime | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import FileResponse | |
# Import your paper-based prediction function | |
from app import ( | |
predict_full_paper, | |
ReferenceBoxNotDetectedError, | |
FingerCutOverlapError, | |
MultipleObjectsError, | |
NoObjectDetectedError, | |
PaperNotDetectedError | |
) | |
app = FastAPI()

# Allow CORS if needed: any origin, method and header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Public base URL prefixed to every link returned to clients. The BASE_URL
# environment variable overrides it (unset or empty keeps the default below);
# a trailing slash is dropped so generated links never contain "//".
BASE_URL = (os.environ.get("BASE_URL") or "https://app.us-central1.run.app").rstrip("/")

# Directory for generated artifacts and directory for client-uploaded updates,
# both created relative to the current working directory at import time.
OUTPUT_DIR = os.path.abspath("./outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
UPDATES_DIR = os.path.abspath("./updates")
os.makedirs(UPDATES_DIR, exist_ok=True)

# Mount static directories with normal StaticFiles
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
app.mount("/updates", StaticFiles(directory=UPDATES_DIR), name="updates")
def save_and_build_urls(
    session_id: str,
    dxf_path: str,
    output_image: np.ndarray = None,
    outlines: np.ndarray = None,
    mask: np.ndarray = None,
    endpoint_type: str = "predict",
    paper_size: str = None,
    offset_value: float = None,
    offset_unit: str = "mm",
    finger_cut: str = "Off"
):
    """Save all artifacts of one request under OUTPUT_DIR/<session_id> and return public URLs.

    Args:
        session_id: Name of the per-request folder inside OUTPUT_DIR.
        dxf_path: Path to the generated DXF file. If nothing exists at that path,
            the value itself is written as the DXF content (bytes/bytearray
            verbatim, anything else via ``str()`` encoded as UTF-8).
        output_image, outlines, mask: Optional image arrays; each one provided is
            saved as a JPEG next to the DXF.
        endpoint_type: Accepted for caller compatibility; not used.
        paper_size, offset_value, offset_unit, finger_cut: Encoded into the DXF
            filename when both ``paper_size`` and ``offset_value`` are given.

    Returns:
        dict with "dxf_url" (served by the /download route) and, for each image
        saved, "output_image_url" / "outlines_url" / "mask_url" (served by the
        /outputs static mount).

    Raises:
        ValueError: if ``dxf_path`` is None (no DXF output to save).
    """
    from datetime import timezone
    from urllib.parse import quote

    if dxf_path is None:
        raise ValueError("No DXF output to save")

    request_dir = os.path.join(OUTPUT_DIR, session_id)
    os.makedirs(request_dir, exist_ok=True)

    # UTC date stamp; timezone-aware now() replaces the deprecated utcnow()
    # and produces the same string.
    current_date = datetime.now(timezone.utc).strftime("%d-%m-%Y")

    # Create descriptive DXF filename
    if paper_size and offset_value is not None:
        # Underscore instead of dot in the offset, e.g. 1.5 -> "1_500".
        offset_str = f"{offset_value:.3f}".replace(".", "_")
        dxf_fn = f"DXF_{current_date}_{paper_size}_{offset_str}{offset_unit}"
        if finger_cut == "On":
            dxf_fn += "_fingercut"
        dxf_fn += ".dxf"
    else:
        dxf_fn = f"DXF_{current_date}.dxf"
    new_dxf_path = os.path.join(request_dir, dxf_fn)

    # Only real path types are probed on disk; anything else (bytearray, other
    # objects) is treated as DXF content rather than raising TypeError.
    if isinstance(dxf_path, (str, bytes, os.PathLike)) and os.path.exists(dxf_path):
        shutil.copy(dxf_path, new_dxf_path)
    else:
        # Fallback if the DXF generator returns bytes or string content
        with open(new_dxf_path, "wb") as f:
            if isinstance(dxf_path, (bytes, bytearray)):
                f.write(dxf_path)
            else:
                f.write(str(dxf_path).encode("utf-8"))

    # Percent-encode the filename: paper sizes like "US Letter" contain a space,
    # which is not valid in a URL.
    urls = {
        "dxf_url": f"{BASE_URL}/download/{session_id}/{quote(dxf_fn)}",
    }

    # Save optional images if provided (order fixes the key order of the result).
    optional_images = (
        ("output_image_url", "annotated_image.jpg", output_image),
        ("outlines_url", "outlines.jpg", outlines),
        ("mask_url", "mask.jpg", mask),
    )
    for url_key, filename, array in optional_images:
        if array is None:
            continue
        image = Image.fromarray(array)
        # JPEG cannot store an alpha channel; Pillow refuses to save RGBA/LA as JPEG.
        if image.mode in ("RGBA", "LA"):
            image = image.convert("RGB")
        image.save(os.path.join(request_dir, filename))
        urls[url_key] = f"{BASE_URL}/outputs/{session_id}/{filename}"
    return urls
# Add new endpoint for downloading DXF files
async def download_file(session_id: str, filename: str):
    """Serve OUTPUT_DIR/<session_id>/<filename> as a DXF attachment.

    NOTE(review): no route decorator is visible in this source. save_and_build_urls
    builds links of the form /download/{session_id}/{filename}, so this is
    presumably meant to be registered there — confirm.

    Raises:
        HTTPException(404): if the resolved path is not a regular file inside
            OUTPUT_DIR.
    """
    base_dir = os.path.realpath(OUTPUT_DIR)
    file_path = os.path.realpath(os.path.join(base_dir, session_id, filename))
    # Both segments are client-controlled; after resolving, insist the target is
    # still inside OUTPUT_DIR so values like ".." cannot escape it.
    if os.path.commonpath([base_dir, file_path]) != base_dir or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    # filename= makes FileResponse emit a properly quoted
    # "Content-Disposition: attachment" header (names may contain spaces).
    return FileResponse(
        path=file_path,
        filename=filename,
        media_type="application/x-dxf",
    )
async def predict_paper_simple_api(
    file: UploadFile = File(...),
    paper_size: str = Form(..., regex="^(A4|A3|US Letter)$"),
):
    """
    Generate a DXF from a photo of an object lying on a reference sheet of paper.

    Fixed defaults: zero offset in mm, finger cuts off, and no preview images
    requested from the predictor. Returns the dict from ``save_and_build_urls``
    plus the predictor's ``scale_info``.

    Raises:
        HTTPException(400): undecodable image or a detection error.
        HTTPException(500): any other failure while predicting or saving.
    """
    request_id = str(uuid.uuid4())

    try:
        raw_upload = await file.read()
        rgb_image = Image.open(io.BytesIO(raw_upload)).convert("RGB")
        pixels = np.array(rgb_image)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image upload")

    try:
        t_begin = timeit.default_timer()
        prediction = predict_full_paper(
            image=pixels,
            paper_size=paper_size,
            offset_value_mm=0.0,
            offset_unit="mm",
            enable_finger_cut="Off",
            selected_outputs=[],  # DXF only
        )
        dxf_output, _annotated, _outlines, _mask, scale_info = prediction
        duration = timeit.default_timer() - t_begin
        print(f"[{request_id}] predict_paper_simple in {duration:.2f}s - {scale_info}")

        response = save_and_build_urls(
            session_id=request_id,
            dxf_path=dxf_output,
            endpoint_type="predict_paper_simple",
            paper_size=paper_size,
            offset_value=0.0,
            offset_unit="mm",
            finger_cut="Off",
        )
        response["scale_info"] = scale_info
        return response
    except (ReferenceBoxNotDetectedError, PaperNotDetectedError):
        raise HTTPException(
            status_code=400,
            detail="Error detecting paper! Please ensure the paper is clearly visible and try again.",
        )
    except MultipleObjectsError:
        raise HTTPException(
            status_code=400,
            detail="Multiple objects detected! Please place only a single object on the paper.",
        )
    except NoObjectDetectedError:
        raise HTTPException(
            status_code=400,
            detail="No object detected! Please ensure an object is placed on the paper.",
        )
    except FingerCutOverlapError:
        raise HTTPException(
            status_code=400,
            detail="There was an overlap with fingercuts! Please try again to generate dxf.",
        )
    except Exception as exc:
        print(f"Error in predict_paper_simple: {str(exc)}")
        raise HTTPException(
            status_code=500,
            detail="Error processing image! Please try again with a clearer image.",
        )
async def predict_paper_with_offset_api(
    file: UploadFile = File(...),
    paper_size: str = Form(..., regex="^(A4|A3|US Letter)$"),
    offset_value: float = Form(...),
    offset_unit: str = Form(..., regex="^(mm|inches)$"),
    include_images: bool = Form(False)  # Optional: include preview images
):
    """
    Paper-based predict with offset: image + paper size + offset → DXF + optional images.

    Finger cuts are always disabled. When ``include_images`` is true, the
    annotated image, outlines and mask are requested from the predictor and
    saved next to the DXF.

    NOTE(review): the 0..50 bound is applied to the raw number regardless of
    ``offset_unit`` — confirm that is intended for inches.

    Returns:
        The dict from ``save_and_build_urls`` plus a "scale_info" entry.

    Raises:
        HTTPException(400): undecodable image; offset that is NaN, negative or
            above 50; or a detection error reported by the predictor.
        HTTPException(500): any other failure while predicting or saving.
    """
    import math

    session_id = str(uuid.uuid4())
    try:
        img_bytes = await file.read()
        image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
    except Exception:
        raise HTTPException(400, "Invalid image upload")
    # Validate offset. NaN compares False against both bounds below, so it must
    # be rejected explicitly or it would reach the predictor and the filename.
    if math.isnan(offset_value):
        raise HTTPException(400, "Offset value must be a number")
    if offset_value < 0:
        raise HTTPException(400, "Offset value cannot be negative")
    if offset_value > 50:  # Reasonable upper limit
        raise HTTPException(400, "Offset value too large (max 50)")
    try:
        start = timeit.default_timer()
        # Determine which outputs to include
        selected_outputs = ["Annotated Image", "Outlines", "Mask"] if include_images else []
        dxf_path, ann_img, outlines_img, mask_img, scale_info = predict_full_paper(
            image=image,
            paper_size=paper_size,
            offset_value_mm=offset_value,
            offset_unit=offset_unit,
            enable_finger_cut="Off",  # No finger cuts
            selected_outputs=selected_outputs
        )
        elapsed = timeit.default_timer() - start
        print(f"[{session_id}] predict_paper_with_offset in {elapsed:.2f}s - {scale_info}")
        urls = save_and_build_urls(
            session_id=session_id,
            dxf_path=dxf_path,
            output_image=ann_img if include_images else None,
            outlines=outlines_img if include_images else None,
            mask=mask_img if include_images else None,
            endpoint_type="predict_paper_with_offset",
            paper_size=paper_size,
            offset_value=offset_value,
            offset_unit=offset_unit,
            finger_cut="Off"
        )
        urls["scale_info"] = scale_info
        return urls
    except (ReferenceBoxNotDetectedError, PaperNotDetectedError):
        raise HTTPException(status_code=400, detail="Error detecting paper! Please ensure the paper is clearly visible and try again.")
    except MultipleObjectsError:
        raise HTTPException(status_code=400, detail="Multiple objects detected! Please place only a single object on the paper.")
    except NoObjectDetectedError:
        raise HTTPException(status_code=400, detail="No object detected! Please ensure an object is placed on the paper.")
    except FingerCutOverlapError:
        raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
    except Exception as e:
        print(f"Error in predict_paper_with_offset: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing image! Please try again with a clearer image.")
async def predict_paper_full_api(
    file: UploadFile = File(...),
    paper_size: str = Form(..., regex="^(A4|A3|US Letter)$"),
    offset_value: float = Form(...),
    offset_unit: str = Form(..., regex="^(mm|inches)$"),
    enable_finger_cut: str = Form(..., regex="^(On|Off)$"),
    include_images: bool = Form(False)  # Optional: include preview images
):
    """
    Full paper-based predict: image + paper size + offset + finger cuts → DXF + optional images.

    When ``include_images`` is true, the annotated image, outlines and mask are
    requested from the predictor and saved next to the DXF.

    NOTE(review): the 0..50 bound is applied to the raw number regardless of
    ``offset_unit`` — confirm that is intended for inches.

    Returns:
        The dict from ``save_and_build_urls`` plus a "scale_info" entry.

    Raises:
        HTTPException(400): undecodable image; offset that is NaN, negative or
            above 50; a detection error; or a finger-cut overlap.
        HTTPException(500): any other failure while predicting or saving.
    """
    import math

    session_id = str(uuid.uuid4())
    try:
        img_bytes = await file.read()
        image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
    except Exception:
        raise HTTPException(400, "Invalid image upload")
    # Validate offset. NaN compares False against both bounds below, so it must
    # be rejected explicitly or it would reach the predictor and the filename.
    if math.isnan(offset_value):
        raise HTTPException(400, "Offset value must be a number")
    if offset_value < 0:
        raise HTTPException(400, "Offset value cannot be negative")
    if offset_value > 50:
        raise HTTPException(400, "Offset value too large (max 50)")
    try:
        start = timeit.default_timer()
        # Determine which outputs to include
        selected_outputs = ["Annotated Image", "Outlines", "Mask"] if include_images else []
        dxf_path, ann_img, outlines_img, mask_img, scale_info = predict_full_paper(
            image=image,
            paper_size=paper_size,
            offset_value_mm=offset_value,
            offset_unit=offset_unit,
            enable_finger_cut=enable_finger_cut,
            selected_outputs=selected_outputs
        )
        elapsed = timeit.default_timer() - start
        print(f"[{session_id}] predict_paper_full in {elapsed:.2f}s - {scale_info}")
        urls = save_and_build_urls(
            session_id=session_id,
            dxf_path=dxf_path,
            output_image=ann_img if include_images else None,
            outlines=outlines_img if include_images else None,
            mask=mask_img if include_images else None,
            endpoint_type="predict_paper_full",
            paper_size=paper_size,
            offset_value=offset_value,
            offset_unit=offset_unit,
            finger_cut=enable_finger_cut
        )
        urls["scale_info"] = scale_info
        return urls
    except (ReferenceBoxNotDetectedError, PaperNotDetectedError):
        raise HTTPException(status_code=400, detail="Error detecting paper! Please ensure the paper is clearly visible and try again.")
    except MultipleObjectsError:
        raise HTTPException(status_code=400, detail="Multiple objects detected! Please place only a single object on the paper.")
    except NoObjectDetectedError:
        raise HTTPException(status_code=400, detail="No object detected! Please ensure an object is placed on the paper.")
    except FingerCutOverlapError:
        raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
    except Exception as e:
        print(f"Error in predict_paper_full: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing image! Please try again with a clearer image.")
# Keep the legacy endpoints for backward compatibility (optional)
async def predict1_api(
    file: UploadFile = File(...)
):
    """
    Legacy endpoint: delegates directly to the simple paper-based prediction,
    always assuming A4 paper.
    """
    legacy_paper_size = "A4"
    result = await predict_paper_simple_api(file=file, paper_size=legacy_paper_size)
    return result
async def predict2_api(
    file: UploadFile = File(...),
    enable_fillet: str = Form(..., regex="^(On|Off)$"),
    fillet_value_mm: float = Form(...)
):
    """
    Legacy endpoint: delegates to the paper-based prediction with offset.

    For compatibility the fillet value is reused as the contour offset in mm
    (0.0 when fillets are off). Paper is assumed to be A4 and preview images
    are always included.
    """
    # Map fillet to offset (you might want to adjust this logic)
    if enable_fillet == "On":
        offset_mm = fillet_value_mm
    else:
        offset_mm = 0.0

    delegate_kwargs = dict(
        file=file,
        paper_size="A4",  # Default to A4
        offset_value=offset_mm,
        offset_unit="mm",
        include_images=True,
    )
    return await predict_paper_with_offset_api(**delegate_kwargs)
async def predict3_api(
    file: UploadFile = File(...),
    enable_fillet: str = Form(..., regex="^(On|Off)$"),
    fillet_value_mm: float = Form(...),
    enable_finger_cut: str = Form(..., regex="^(On|Off)$")
):
    """
    Legacy endpoint: delegates to the full paper-based prediction.

    The fillet value is reused as the contour offset in mm (0.0 when fillets
    are off); paper is assumed to be A4 and preview images are always included.
    """
    if enable_fillet == "On":
        offset_mm = fillet_value_mm
    else:
        offset_mm = 0.0

    delegate_kwargs = dict(
        file=file,
        paper_size="A4",  # Default to A4
        offset_value=offset_mm,
        offset_unit="mm",
        enable_finger_cut=enable_finger_cut,
        include_images=True,
    )
    return await predict_paper_full_api(**delegate_kwargs)
async def update_files(
    output_image: UploadFile = File(...),
    outlines_image: UploadFile = File(...),
    mask_image: UploadFile = File(...),
    dxf_file: UploadFile = File(...)
):
    """
    Store four client-supplied files under UPDATES_DIR/<new session id>.

    Each file is saved under the final path component of its client-provided
    filename and is reachable through the /updates static mount.

    NOTE(review): no route decorator is visible in this source — confirm how
    this handler is registered.

    Returns:
        {"session_id": <uuid4 string>, "uploaded": {<form field>: <public URL>}}

    Raises:
        HTTPException(400): an upload has an empty or path-only filename.
        HTTPException(500): writing one of the files failed.
    """
    from urllib.parse import quote

    upload_map = {
        "output_image": output_image,
        "outlines_image": outlines_image,
        "mask_image": mask_image,
        "dxf_file": dxf_file,
    }

    # Filenames are client-controlled. Keep only the last path component (for
    # either separator) so names like "../../app.py" cannot escape the session
    # folder. Validate everything before writing anything.
    safe_names = {}
    for key, up in upload_map.items():
        name = os.path.basename((up.filename or "").replace("\\", "/"))
        if name in ("", ".", ".."):
            raise HTTPException(400, f"Invalid filename for {key}")
        safe_names[key] = name

    session_id = str(uuid.uuid4())
    update_dir = os.path.join(UPDATES_DIR, session_id)
    os.makedirs(update_dir, exist_ok=True)
    try:
        urls = {}
        for key, up in upload_map.items():
            fn = safe_names[key]
            # NOTE: two uploads with the same filename overwrite each other here.
            with open(os.path.join(update_dir, fn), "wb") as f:
                shutil.copyfileobj(up.file, f)
            # Percent-encode so spaces/reserved characters still give a valid URL.
            urls[key] = f"{BASE_URL}/updates/{session_id}/{quote(fn)}"
        return {"session_id": session_id, "uploaded": urls}
    except Exception as e:
        raise HTTPException(500, f"Update failed: {e}")
from fastapi import Response


def health():
    """Health check: reply with a plain "OK" body and HTTP status 200."""
    ok_status = 200
    return Response(status_code=ok_status, content="OK")
def root():
    """Describe the API: a title, endpoint summaries, supported paper sizes and offset units."""
    endpoint_summaries = [
        "/predict_paper_simple - Simple DXF generation with paper reference",
        "/predict_paper_with_offset - DXF generation with contour offset",
        "/predict_paper_full - Full DXF generation with all features",
        "/predict1, /predict2, /predict3 - Legacy endpoints (backward compatibility)",
    ]
    supported_papers = ["A4", "A3", "US Letter"]
    supported_units = ["mm", "inches"]
    return dict(
        message="Paper-based DXF Generator API",
        endpoints=endpoint_summaries,
        paper_sizes=supported_papers,
        units=supported_units,
    )
if __name__ == "__main__":
    # Direct-run entry point; the PORT environment variable overrides 8080.
    import uvicorn

    listen_host = "0.0.0.0"
    listen_port = int(os.environ.get("PORT", 8080))
    print(f"Starting FastAPI server on {listen_host}:{listen_port}...")
    uvicorn.run(app, host=listen_host, port=listen_port)