mlbench123 commited on
Commit
6091504
·
verified ·
1 Parent(s): 74e6395

Upload 6 files

Browse files
Files changed (6) hide show
  1. api_server.py +525 -0
  2. app.py.txt +1007 -0
  3. requirements.txt +13 -0
  4. scalingtestupdated.py +184 -0
  5. u2netp.pth +3 -0
  6. u2netp.py +525 -0
api_server.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
+ # from pydantic import BaseModel
3
+ # import numpy as np
4
+ # from PIL import Image
5
+ # import io, uuid, os, shutil, timeit
6
+ # from datetime import datetime
7
+ # from fastapi.staticfiles import StaticFiles
8
+ # from fastapi.middleware.cors import CORSMiddleware
9
+
10
+ # # import your three wrappers
11
+ # from app import predict_simple, predict_middle, predict_full
12
+
13
+ # app = FastAPI()
14
+
15
+ # # allow CORS if needed
16
+ # app.add_middleware(
17
+ # CORSMiddleware,
18
+ # allow_origins=["*"],
19
+ # allow_methods=["*"],
20
+ # allow_headers=["*"],
21
+ # )
22
+
23
+ # BASE_URL = "https://snapanddtraceapp-988917236820.us-central1.run.app"
24
+ # OUTPUT_DIR = os.path.abspath("./outputs")
25
+ # os.makedirs(OUTPUT_DIR, exist_ok=True)
26
+ # app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
27
+
28
+ # UPDATES_DIR = os.path.abspath("./updates")
29
+ # os.makedirs(UPDATES_DIR, exist_ok=True)
30
+ # app.mount("/updates", StaticFiles(directory=UPDATES_DIR), name="updates")
31
+
32
+
33
+ # def save_and_build_urls(
34
+ # session_id: str,
35
+ # output_image: np.ndarray,
36
+ # outlines: np.ndarray,
37
+ # dxf_path: str,
38
+ # mask: np.ndarray
39
+ # ):
40
+ # """Helper to save all four artifacts and return public URLs."""
41
+ # request_dir = os.path.join(OUTPUT_DIR, session_id)
42
+ # os.makedirs(request_dir, exist_ok=True)
43
+
44
+ # # filenames
45
+ # out_fn = "overlay.jpg"
46
+ # outlines_fn = "outlines.jpg"
47
+ # mask_fn = "mask.jpg"
48
+ # current_date = datetime.now().strftime("%d-%m-%Y")
49
+ # dxf_fn = f"out_{current_date}_{session_id}.dxf"
50
+
51
+ # # full paths
52
+ # out_path = os.path.join(request_dir, out_fn)
53
+ # outlines_path = os.path.join(request_dir, outlines_fn)
54
+ # mask_path = os.path.join(request_dir, mask_fn)
55
+ # new_dxf_path = os.path.join(request_dir, dxf_fn)
56
+
57
+ # # save images
58
+ # Image.fromarray(output_image).save(out_path)
59
+ # Image.fromarray(outlines).save(outlines_path)
60
+ # Image.fromarray(mask).save(mask_path)
61
+
62
+ # # copy dx file
63
+ # if os.path.exists(dxf_path):
64
+ # shutil.copy(dxf_path, new_dxf_path)
65
+ # else:
66
+ # # fallback if your DXF generator returns bytes or string
67
+ # with open(new_dxf_path, "wb") as f:
68
+ # if isinstance(dxf_path, (bytes, bytearray)):
69
+ # f.write(dxf_path)
70
+ # else:
71
+ # f.write(str(dxf_path).encode("utf-8"))
72
+
73
+ # # build URLs
74
+ # return {
75
+ # "output_image_url": f"{BASE_URL}/outputs/{session_id}/{out_fn}",
76
+ # "outlines_url": f"{BASE_URL}/outputs/{session_id}/{outlines_fn}",
77
+ # "mask_url": f"{BASE_URL}/outputs/{session_id}/{mask_fn}",
78
+ # "dxf_url": f"{BASE_URL}/outputs/{session_id}/{dxf_fn}",
79
+ # }
80
+
81
+
82
+ # @app.post("/predict1")
83
+ # async def predict1_api(
84
+ # file: UploadFile = File(...)
85
+ # ):
86
+ # """
87
+ # Simple predict: only image → overlay, outlines, mask, DXF
88
+ # """
89
+ # session_id = str(uuid.uuid4())
90
+ # try:
91
+ # img_bytes = await file.read()
92
+ # image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
93
+ # except Exception:
94
+ # raise HTTPException(400, "Invalid image upload")
95
+
96
+ # try:
97
+ # start = timeit.default_timer()
98
+ # out_img, outlines, dxf_path, mask = predict_simple(image)
99
+ # elapsed = timeit.default_timer() - start
100
+ # print(f"[{session_id}] predict1 in {elapsed:.2f}s")
101
+
102
+ # return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask)
103
+
104
+ # except Exception as e:
105
+ # raise HTTPException(500, f"predict1 failed: {e}")
106
+ # except ReferenceBoxNotDetectedError:
107
+ # raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
108
+ # except FingerCutOverlapError:
109
+ # raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.")
110
+
111
+
112
+ # @app.post("/predict2")
113
+ # async def predict2_api(
114
+ # file: UploadFile = File(...),
115
+ # enable_fillet: str = Form(..., regex="^(On|Off)$"),
116
+ # fillet_value_mm: float = Form(...)
117
+ # ):
118
+ # """
119
+ # Middle predict: image + fillet toggle + fillet value → overlay, outlines, mask, DXF
120
+ # """
121
+ # session_id = str(uuid.uuid4())
122
+ # try:
123
+ # img_bytes = await file.read()
124
+ # image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
125
+ # except Exception:
126
+ # raise HTTPException(400, "Invalid image upload")
127
+
128
+ # try:
129
+ # start = timeit.default_timer()
130
+ # out_img, outlines, dxf_path, mask = predict_middle(
131
+ # image, enable_fillet, fillet_value_mm
132
+ # )
133
+ # elapsed = timeit.default_timer() - start
134
+ # print(f"[{session_id}] predict2 in {elapsed:.2f}s")
135
+
136
+ # return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask)
137
+
138
+ # except Exception as e:
139
+ # raise HTTPException(500, f"predict2 failed: {e}")
140
+ # except ReferenceBoxNotDetectedError:
141
+ # raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
142
+ # except FingerCutOverlapError:
143
+ # raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.")
144
+
145
+ # @app.post("/predict3")
146
+ # async def predict3_api(
147
+ # file: UploadFile = File(...),
148
+ # enable_fillet: str = Form(..., regex="^(On|Off)$"),
149
+ # fillet_value_mm: float = Form(...),
150
+ # enable_finger_cut: str = Form(..., regex="^(On|Off)$")
151
+ # ):
152
+ # """
153
+ # Full predict: image + fillet toggle/value + finger-cut toggle → overlay, outlines, mask, DXF
154
+ # """
155
+ # session_id = str(uuid.uuid4())
156
+ # try:
157
+ # img_bytes = await file.read()
158
+ # image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
159
+ # except Exception:
160
+ # raise HTTPException(400, "Invalid image upload")
161
+
162
+ # try:
163
+ # start = timeit.default_timer()
164
+ # out_img, outlines, dxf_path, mask = predict_full(
165
+ # image, enable_fillet, fillet_value_mm, enable_finger_cut
166
+ # )
167
+ # elapsed = timeit.default_timer() - start
168
+ # print(f"[{session_id}] predict3 in {elapsed:.2f}s")
169
+
170
+ # return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask)
171
+
172
+ # except Exception as e:
173
+ # raise HTTPException(500, f"predict3 failed: {e}")
174
+ # except ReferenceBoxNotDetectedError:
175
+ # raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
176
+ # except FingerCutOverlapError:
177
+ # raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.")
178
+
179
+ # @app.post("/update")
180
+ # async def update_files(
181
+ # output_image: UploadFile = File(...),
182
+ # outlines_image: UploadFile = File(...),
183
+ # mask_image: UploadFile = File(...),
184
+ # dxf_file: UploadFile = File(...)
185
+ # ):
186
+ # session_id = str(uuid.uuid4())
187
+ # update_dir = os.path.join(UPDATES_DIR, session_id)
188
+ # os.makedirs(update_dir, exist_ok=True)
189
+
190
+ # try:
191
+ # upload_map = {
192
+ # "output_image": output_image,
193
+ # "outlines_image": outlines_image,
194
+ # "mask_image": mask_image,
195
+ # "dxf_file": dxf_file,
196
+ # }
197
+ # urls = {}
198
+ # for key, up in upload_map.items():
199
+ # fn = up.filename
200
+ # path = os.path.join(update_dir, fn)
201
+ # with open(path, "wb") as f:
202
+ # shutil.copyfileobj(up.file, f)
203
+ # urls[key] = f"{BASE_URL}/updates/{session_id}/{fn}"
204
+
205
+ # return {"session_id": session_id, "uploaded": urls}
206
+
207
+ # except Exception as e:
208
+ # raise HTTPException(500, f"Update failed: {e}")
209
+
210
+
211
+ # if __name__ == "__main__":
212
+ # import uvicorn
213
+ # port = int(os.environ.get("PORT", 8082))
214
+ # print(f"Starting FastAPI server on 0.0.0.0:{port}...")
215
+ # uvicorn.run(app, host="0.0.0.0", port=port)
216
+
217
+
218
+
219
+
220
+
221
+
222
+
223
+
224
+
225
+
226
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
227
+ from pydantic import BaseModel
228
+ import numpy as np
229
+ from PIL import Image
230
+ import io, uuid, os, shutil, timeit
231
+ from datetime import datetime
232
+ from fastapi.staticfiles import StaticFiles
233
+ from fastapi.middleware.cors import CORSMiddleware
234
+ from fastapi.responses import FileResponse
235
+
236
+ # import your three wrappers
237
+ from app import predict_simple, predict_middle, predict_full
238
+
239
+ from app import (
240
+ predict_simple, predict_middle, predict_full,
241
+ ReferenceBoxNotDetectedError,
242
+ FingerCutOverlapError
243
+ )
244
+
245
+
246
+ app = FastAPI()
247
+
248
+ # allow CORS if needed
249
+ app.add_middleware(
250
+ CORSMiddleware,
251
+ allow_origins=["*"],
252
+ allow_methods=["*"],
253
+ allow_headers=["*"],
254
+ )
255
+
256
+ BASE_URL = "https://snapanddtraceapp-988917236820.us-central1.run.app"
257
+
258
+ OUTPUT_DIR = os.path.abspath("./outputs")
259
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
260
+
261
+ UPDATES_DIR = os.path.abspath("./updates")
262
+ os.makedirs(UPDATES_DIR, exist_ok=True)
263
+
264
+ # Mount static directories with normal StaticFiles
265
+ app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
266
+ app.mount("/updates", StaticFiles(directory=UPDATES_DIR), name="updates")
267
+
268
+
269
+ def save_and_build_urls(
270
+ session_id: str,
271
+ output_image: np.ndarray,
272
+ outlines: np.ndarray,
273
+ dxf_path: str,
274
+ mask: np.ndarray,
275
+ endpoint_type: str,
276
+ fillet_value: float = None,
277
+ finger_cut: str = None
278
+ ):
279
+ """Helper to save all four artifacts and return public URLs."""
280
+ request_dir = os.path.join(OUTPUT_DIR, session_id)
281
+ os.makedirs(request_dir, exist_ok=True)
282
+
283
+ # filenames
284
+ out_fn = "overlay.jpg"
285
+ outlines_fn = "outlines.jpg"
286
+ mask_fn = "mask.jpg"
287
+
288
+ # Get current date
289
+ current_date = datetime.utcnow().strftime("%d-%m-%Y")
290
+
291
+
292
+ # Format fillet value with underscore instead of dot
293
+ fillet_str = f"{fillet_value:.2f}".replace(".", "_") if fillet_value is not None else None
294
+
295
+ # Determine DXF filename based on endpoint type
296
+ if endpoint_type == "predict1":
297
+ dxf_fn = f"DXF_{current_date}.dxf"
298
+ elif endpoint_type == "predict2":
299
+ dxf_fn = f"DXF_{current_date}.dxf"
300
+ elif endpoint_type == "predict3":
301
+ dxf_fn = f"DXF_{current_date}.dxf"
302
+
303
+ # full paths
304
+ out_path = os.path.join(request_dir, out_fn)
305
+ outlines_path = os.path.join(request_dir, outlines_fn)
306
+ mask_path = os.path.join(request_dir, mask_fn)
307
+ new_dxf_path = os.path.join(request_dir, dxf_fn)
308
+
309
+ # save images
310
+ Image.fromarray(output_image).save(out_path)
311
+ Image.fromarray(outlines).save(outlines_path)
312
+ Image.fromarray(mask).save(mask_path)
313
+
314
+ # copy dxf file
315
+ if os.path.exists(dxf_path):
316
+ shutil.copy(dxf_path, new_dxf_path)
317
+ else:
318
+ # fallback if your DXF generator returns bytes or string
319
+ with open(new_dxf_path, "wb") as f:
320
+ if isinstance(dxf_path, (bytes, bytearray)):
321
+ f.write(dxf_path)
322
+ else:
323
+ f.write(str(dxf_path).encode("utf-8"))
324
+
325
+ # build URLs with /download prefix for DXF
326
+ return {
327
+ "output_image_url": f"{BASE_URL}/outputs/{session_id}/{out_fn}",
328
+ "outlines_url": f"{BASE_URL}/outputs/{session_id}/{outlines_fn}",
329
+ "mask_url": f"{BASE_URL}/outputs/{session_id}/{mask_fn}",
330
+ "dxf_url": f"{BASE_URL}/download/{session_id}/{dxf_fn}", # Changed to use download endpoint
331
+ }
332
+
333
+ # Add new endpoint for downloading DXF files
334
+ @app.get("/download/{session_id}/{filename}")
335
+ async def download_file(session_id: str, filename: str):
336
+ file_path = os.path.join(OUTPUT_DIR, session_id, filename)
337
+ if not os.path.exists(file_path):
338
+ raise HTTPException(status_code=404, detail="File not found")
339
+
340
+ return FileResponse(
341
+ path=file_path,
342
+ filename=filename,
343
+ media_type="application/x-dxf",
344
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
345
+ )
346
+
347
+
348
+ @app.post("/predict1")
349
+ async def predict1_api(
350
+ file: UploadFile = File(...)
351
+ ):
352
+ """
353
+ Simple predict: only image → overlay, outlines, mask, DXF
354
+ DXF naming format: DXF_DD-MM-YYYY.dxf
355
+ """
356
+ session_id = str(uuid.uuid4())
357
+ try:
358
+ img_bytes = await file.read()
359
+ image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
360
+ except Exception:
361
+ raise HTTPException(400, "Invalid image upload")
362
+
363
+ try:
364
+ start = timeit.default_timer()
365
+ out_img, outlines, dxf_path, mask = predict_simple(image)
366
+ elapsed = timeit.default_timer() - start
367
+ print(f"[{session_id}] predict1 in {elapsed:.2f}s")
368
+
369
+ return save_and_build_urls(
370
+ session_id=session_id,
371
+ output_image=out_img,
372
+ outlines=outlines,
373
+ dxf_path=dxf_path,
374
+ mask=mask,
375
+ endpoint_type="predict1"
376
+ )
377
+
378
+ except ReferenceBoxNotDetectedError:
379
+ raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
380
+ except FingerCutOverlapError:
381
+ raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
382
+ except HTTPException as e:
383
+ raise e
384
+ except Exception as e:
385
+ raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.")
386
+
387
+ @app.post("/predict2")
388
+ async def predict2_api(
389
+ file: UploadFile = File(...),
390
+ enable_fillet: str = Form(..., regex="^(On|Off)$"),
391
+ fillet_value_mm: float = Form(...)
392
+ ):
393
+ """
394
+ Middle predict: image + fillet toggle + fillet value → overlay, outlines, mask, DXF
395
+ DXF naming format: DXF_DD-MM-YYYY_fillet-value_mm.dxf
396
+ """
397
+ session_id = str(uuid.uuid4())
398
+ try:
399
+ img_bytes = await file.read()
400
+ image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
401
+ except Exception:
402
+ raise HTTPException(400, "Invalid image upload")
403
+
404
+ try:
405
+ start = timeit.default_timer()
406
+ out_img, outlines, dxf_path, mask = predict_middle(
407
+ image, enable_fillet, fillet_value_mm
408
+ )
409
+ elapsed = timeit.default_timer() - start
410
+ print(f"[{session_id}] predict2 in {elapsed:.2f}s")
411
+
412
+ return save_and_build_urls(
413
+ session_id=session_id,
414
+ output_image=out_img,
415
+ outlines=outlines,
416
+ dxf_path=dxf_path,
417
+ mask=mask,
418
+ endpoint_type="predict2",
419
+ fillet_value=fillet_value_mm
420
+ )
421
+
422
+ except ReferenceBoxNotDetectedError:
423
+ raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
424
+ except FingerCutOverlapError:
425
+ raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
426
+ except HTTPException as e:
427
+ raise e
428
+ except Exception as e:
429
+ raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.")
430
+
431
+
432
+ @app.post("/predict3")
433
+ async def predict3_api(
434
+ file: UploadFile = File(...),
435
+ enable_fillet: str = Form(..., regex="^(On|Off)$"),
436
+ fillet_value_mm: float = Form(...),
437
+ enable_finger_cut: str = Form(..., regex="^(On|Off)$")
438
+ ):
439
+ """
440
+ Full predict: image + fillet toggle/value + finger-cut toggle → overlay, outlines, mask, DXF
441
+ DXF naming format: DXF_DD-MM-YYYY_fillet-value_mm_fingercut-On|Off.dxf
442
+ """
443
+ session_id = str(uuid.uuid4())
444
+ try:
445
+ img_bytes = await file.read()
446
+ image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
447
+ except Exception:
448
+ raise HTTPException(400, "Invalid image upload")
449
+
450
+ try:
451
+ start = timeit.default_timer()
452
+ out_img, outlines, dxf_path, mask = predict_full(
453
+ image, enable_fillet, fillet_value_mm, enable_finger_cut
454
+ )
455
+ elapsed = timeit.default_timer() - start
456
+ print(f"[{session_id}] predict3 in {elapsed:.2f}s")
457
+
458
+ return save_and_build_urls(
459
+ session_id=session_id,
460
+ output_image=out_img,
461
+ outlines=outlines,
462
+ dxf_path=dxf_path,
463
+ mask=mask,
464
+ endpoint_type="predict3",
465
+ fillet_value=fillet_value_mm,
466
+ finger_cut=enable_finger_cut
467
+ )
468
+
469
+ except ReferenceBoxNotDetectedError:
470
+ raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
471
+ except FingerCutOverlapError:
472
+ raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
473
+ except HTTPException as e:
474
+ raise e
475
+ except Exception as e:
476
+ raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.")
477
+
478
+
479
+ @app.post("/update")
480
+ async def update_files(
481
+ output_image: UploadFile = File(...),
482
+ outlines_image: UploadFile = File(...),
483
+ mask_image: UploadFile = File(...),
484
+ dxf_file: UploadFile = File(...)
485
+ ):
486
+ session_id = str(uuid.uuid4())
487
+ update_dir = os.path.join(UPDATES_DIR, session_id)
488
+ os.makedirs(update_dir, exist_ok=True)
489
+
490
+ try:
491
+ upload_map = {
492
+ "output_image": output_image,
493
+ "outlines_image": outlines_image,
494
+ "mask_image": mask_image,
495
+ "dxf_file": dxf_file,
496
+ }
497
+ urls = {}
498
+ for key, up in upload_map.items():
499
+ fn = up.filename
500
+ path = os.path.join(update_dir, fn)
501
+ with open(path, "wb") as f:
502
+ shutil.copyfileobj(up.file, f)
503
+ urls[key] = f"{BASE_URL}/updates/{session_id}/{fn}"
504
+
505
+ return {"session_id": session_id, "uploaded": urls}
506
+
507
+ except Exception as e:
508
+ raise HTTPException(500, f"Update failed: {e}")
509
+
510
+
511
+ from fastapi import Response
512
+
513
+ @app.get("/health")
514
+ def health():
515
+ return Response(content="OK", status_code=200)
516
+
517
+
518
+ if __name__ == "__main__":
519
+ import uvicorn
520
+ port = int(os.environ.get("PORT", 8080))
521
+ print(f"Starting FastAPI server on 0.0.0.0:{port}...")
522
+ uvicorn.run(app, host="0.0.0.0", port=port)
523
+
524
+
525
+
app.py.txt ADDED
@@ -0,0 +1,1007 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Union, Tuple
4
+ from PIL import Image
5
+ import ezdxf.units
6
+ import numpy as np
7
+ import torch
8
+ from torchvision import transforms
9
+ from ultralytics import YOLOWorld, YOLO
10
+ from ultralytics.engine.results import Results
11
+ from ultralytics.utils.plotting import save_one_box
12
+ from transformers import AutoModelForImageSegmentation
13
+ import cv2
14
+ import ezdxf
15
+ import gradio as gr
16
+ import gc
17
+ from scalingtestupdated import calculate_scaling_factor
18
+ from scipy.interpolate import splprep, splev
19
+ from scipy.ndimage import gaussian_filter1d
20
+ import json
21
+ import time
22
+ import signal
23
+ from shapely.ops import unary_union
24
+ from shapely.geometry import MultiPolygon, GeometryCollection, Polygon, Point
25
+ from u2netp import U2NETP
26
+ import logging
27
+ import shutil
28
+
29
+ # Initialize logging
30
+ logging.basicConfig(level=logging.INFO)
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Create cache directory for models
34
+ CACHE_DIR = os.path.join(os.path.dirname(__file__), ".cache")
35
+ os.makedirs(CACHE_DIR, exist_ok=True)
36
+
37
+ # Paper size configurations (in mm)
38
+ PAPER_SIZES = {
39
+ "A4": {"width": 210, "height": 297},
40
+ "A3": {"width": 297, "height": 420},
41
+ "US Letter": {"width": 215.9, "height": 279.4}
42
+ }
43
+
44
+ # Custom Exception Classes
45
+ class TimeoutReachedError(Exception):
46
+ pass
47
+
48
+ class BoundaryOverlapError(Exception):
49
+ pass
50
+
51
+ class TextOverlapError(Exception):
52
+ pass
53
+
54
+ class PaperNotDetectedError(Exception):
55
+ """Raised when the paper cannot be detected in the image"""
56
+ pass
57
+
58
+ class MultipleObjectsError(Exception):
59
+ """Raised when multiple objects are detected on the paper"""
60
+ def __init__(self, message="Multiple objects detected. Please place only a single object on the paper."):
61
+ super().__init__(message)
62
+
63
+ class NoObjectDetectedError(Exception):
64
+ """Raised when no object is detected on the paper"""
65
+ def __init__(self, message="No object detected on the paper. Please ensure an object is placed on the paper."):
66
+ super().__init__(message)
67
+
68
+ class FingerCutOverlapError(Exception):
69
+ """Raised when finger cuts overlap with existing geometry"""
70
+ def __init__(self, message="There was an overlap with fingercuts... Please try again to generate dxf."):
71
+ super().__init__(message)
72
+
73
+ # Global model variables for lazy loading
74
+ paper_detector_global = None
75
+ u2net_global = None
76
+ birefnet = None
77
+
78
+ # Model paths
79
+ paper_model_path = os.path.join(CACHE_DIR, "paper_detector.pt") # You'll need to train/provide this
80
+ u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
81
+
82
+ # Device configuration
83
+ device = "cpu"
84
+ torch.set_float32_matmul_precision(["high", "highest"][0])
85
+
86
+ def ensure_model_files():
87
+ """Ensure model files are available in cache directory"""
88
+ if not os.path.exists(paper_model_path):
89
+ if os.path.exists("paper_detector.pt"):
90
+ shutil.copy("paper_detector.pt", paper_model_path)
91
+ else:
92
+ logger.warning("paper_detector.pt model file not found - using fallback detection")
93
+
94
+ if not os.path.exists(u2net_model_path):
95
+ if os.path.exists("u2netp.pth"):
96
+ shutil.copy("u2netp.pth", u2net_model_path)
97
+ else:
98
+ raise FileNotFoundError("u2netp.pth model file not found")
99
+
100
+ ensure_model_files()
101
+
102
+ # Lazy loading functions
103
+ def get_paper_detector():
104
+ """Lazy load paper detector model"""
105
+ global paper_detector_global
106
+ if paper_detector_global is None:
107
+ logger.info("Loading paper detector model...")
108
+ if os.path.exists(paper_model_path):
109
+ paper_detector_global = YOLO(paper_model_path)
110
+ else:
111
+ # Fallback to generic object detection for paper-like rectangles
112
+ logger.warning("Using fallback paper detection")
113
+ paper_detector_global = None
114
+ logger.info("Paper detector loaded successfully")
115
+ return paper_detector_global
116
+
117
+ def get_u2net():
118
+ """Lazy load U2NETP model"""
119
+ global u2net_global
120
+ if u2net_global is None:
121
+ logger.info("Loading U2NETP model...")
122
+ u2net_global = U2NETP(3, 1)
123
+ u2net_global.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
124
+ u2net_global.to(device)
125
+ u2net_global.eval()
126
+ logger.info("U2NETP model loaded successfully")
127
+ return u2net_global
128
+
129
+ def load_birefnet_model():
130
+ """Load BiRefNet model from HuggingFace"""
131
+ return AutoModelForImageSegmentation.from_pretrained(
132
+ 'ZhengPeng7/BiRefNet',
133
+ trust_remote_code=True
134
+ )
135
+
136
+ def get_birefnet():
137
+ """Lazy load BiRefNet model"""
138
+ global birefnet
139
+ if birefnet is None:
140
+ logger.info("Loading BiRefNet model...")
141
+ birefnet = load_birefnet_model()
142
+ birefnet.to(device)
143
+ birefnet.eval()
144
+ logger.info("BiRefNet model loaded successfully")
145
+ return birefnet
146
+
147
+ def detect_paper_contour(image: np.ndarray) -> Tuple[np.ndarray, float]:
148
+ """
149
+ Detect paper in the image using contour detection as fallback
150
+ Returns the paper contour and estimated scaling factor
151
+ """
152
+ # Convert to grayscale
153
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
154
+
155
+ # Apply Gaussian blur
156
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
157
+
158
+ # Edge detection
159
+ edges = cv2.Canny(blurred, 50, 150)
160
+
161
+ # Find contours
162
+ contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
163
+
164
+ # Filter contours by area and aspect ratio to find paper-like rectangles
165
+ paper_contours = []
166
+ min_area = (image.shape[0] * image.shape[1]) * 0.1 # At least 10% of image
167
+
168
+ for contour in contours:
169
+ area = cv2.contourArea(contour)
170
+ if area > min_area:
171
+ # Approximate contour to polygon
172
+ epsilon = 0.02 * cv2.arcLength(contour, True)
173
+ approx = cv2.approxPolyDP(contour, epsilon, True)
174
+
175
+ # Check if it's roughly rectangular (4 corners)
176
+ if len(approx) >= 4:
177
+ # Calculate bounding rectangle
178
+ rect = cv2.boundingRect(approx)
179
+ aspect_ratio = rect[2] / rect[3] # width / height
180
+
181
+ # Check if aspect ratio matches common paper ratios
182
+ # A4: 1.414, A3: 1.414, US Letter: 1.294
183
+ if 0.7 < aspect_ratio < 1.8: # Allow some tolerance
184
+ paper_contours.append((contour, area, aspect_ratio))
185
+
186
+ if not paper_contours:
187
+ raise PaperNotDetectedError("Could not detect paper in the image")
188
+
189
+ # Select the largest paper-like contour
190
+ paper_contours.sort(key=lambda x: x[1], reverse=True)
191
+ best_contour = paper_contours[0][0]
192
+
193
+ return best_contour, 0.0 # Return 0.0 as placeholder scaling factor
194
+
195
+ def detect_paper_bounds(image: np.ndarray, paper_size: str) -> Tuple[np.ndarray, float]:
196
+ """
197
+ Detect paper bounds in the image and calculate scaling factor
198
+ """
199
+ try:
200
+ paper_detector = get_paper_detector()
201
+
202
+ if paper_detector is not None:
203
+ # Use trained model if available
204
+ results = paper_detector.predict(image, conf=0.5)
205
+ if not results or len(results) == 0 or len(results[0].boxes) == 0:
206
+ logger.warning("Model detection failed, using fallback contour detection")
207
+ return detect_paper_contour(image)
208
+
209
+ # Get the largest detected paper
210
+ boxes = results[0].cpu().boxes.xyxy
211
+ largest_box = None
212
+ max_area = 0
213
+
214
+ for box in boxes:
215
+ x_min, y_min, x_max, y_max = box
216
+ area = (x_max - x_min) * (y_max - y_min)
217
+ if area > max_area:
218
+ max_area = area
219
+ largest_box = box
220
+
221
+ if largest_box is None:
222
+ raise PaperNotDetectedError("No paper detected by model")
223
+
224
+ # Convert box to contour-like format
225
+ x_min, y_min, x_max, y_max = map(int, largest_box)
226
+ paper_contour = np.array([
227
+ [[x_min, y_min]],
228
+ [[x_max, y_min]],
229
+ [[x_max, y_max]],
230
+ [[x_min, y_max]]
231
+ ])
232
+
233
+ else:
234
+ # Use fallback contour detection
235
+ paper_contour, _ = detect_paper_contour(image)
236
+
237
+ # Calculate scaling factor based on paper size
238
+ scaling_factor = calculate_paper_scaling_factor(paper_contour, paper_size)
239
+
240
+ return paper_contour, scaling_factor
241
+
242
+ except Exception as e:
243
+ logger.error(f"Error in paper detection: {e}")
244
+ raise PaperNotDetectedError(f"Failed to detect paper: {str(e)}")
245
+
246
+ def calculate_paper_scaling_factor(paper_contour: np.ndarray, paper_size: str) -> float:
247
+ """
248
+ Calculate scaling factor based on detected paper dimensions
249
+ """
250
+ # Get paper dimensions
251
+ paper_dims = PAPER_SIZES[paper_size]
252
+ expected_width_mm = paper_dims["width"]
253
+ expected_height_mm = paper_dims["height"]
254
+
255
+ # Calculate bounding rectangle of paper contour
256
+ rect = cv2.boundingRect(paper_contour)
257
+ detected_width_px = rect[2]
258
+ detected_height_px = rect[3]
259
+
260
+ # Calculate scaling factors for both dimensions
261
+ scale_x = expected_width_mm / detected_width_px
262
+ scale_y = expected_height_mm / detected_height_px
263
+
264
+ # Use average of both scales
265
+ scaling_factor = (scale_x + scale_y) / 2
266
+
267
+ logger.info(f"Paper detection: {detected_width_px}x{detected_height_px} px -> {expected_width_mm}x{expected_height_mm} mm")
268
+ logger.info(f"Calculated scaling factor: {scaling_factor:.4f} mm/px")
269
+
270
+ return scaling_factor
271
+
272
+ def validate_single_object(mask: np.ndarray, paper_contour: np.ndarray) -> None:
273
+ """
274
+ Validate that only a single object is present on the paper
275
+ """
276
+ # Create a mask for the paper area
277
+ paper_mask = np.zeros(mask.shape[:2], dtype=np.uint8)
278
+ cv2.fillPoly(paper_mask, [paper_contour], 255)
279
+
280
+ # Apply paper mask to object mask
281
+ masked_objects = cv2.bitwise_and(mask, paper_mask)
282
+
283
+ # Find contours of objects within paper bounds
284
+ contours, _ = cv2.findContours(masked_objects, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
285
+
286
+ # Filter out very small contours (noise)
287
+ min_area = 1000 # Minimum area threshold
288
+ significant_contours = [c for c in contours if cv2.contourArea(c) > min_area]
289
+
290
+ if len(significant_contours) == 0:
291
+ raise NoObjectDetectedError()
292
+ elif len(significant_contours) > 1:
293
+ raise MultipleObjectsError()
294
+
295
+ logger.info(f"Single object validated: {len(significant_contours)} significant contour(s) found")
296
+
297
+ def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
298
+ """Remove background using U2NETP model"""
299
+ try:
300
+ u2net_model = get_u2net()
301
+
302
+ image_pil = Image.fromarray(image)
303
+ transform_u2netp = transforms.Compose([
304
+ transforms.Resize((320, 320)),
305
+ transforms.ToTensor(),
306
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
307
+ ])
308
+
309
+ input_tensor = transform_u2netp(image_pil).unsqueeze(0).to(device)
310
+
311
+ with torch.no_grad():
312
+ outputs = u2net_model(input_tensor)
313
+
314
+ pred = outputs[0]
315
+ pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
316
+ pred_np = pred.squeeze().cpu().numpy()
317
+ pred_np = cv2.resize(pred_np, (image_pil.width, image_pil.height))
318
+ pred_np = (pred_np * 255).astype(np.uint8)
319
+
320
+ return pred_np
321
+ except Exception as e:
322
+ logger.error(f"Error in U2NETP background removal: {e}")
323
+ raise
324
+
325
+ def remove_bg(image: np.ndarray) -> np.ndarray:
326
+ """Remove background using BiRefNet model for main objects"""
327
+ try:
328
+ birefnet_model = get_birefnet()
329
+
330
+ transform_image = transforms.Compose([
331
+ transforms.Resize((1024, 1024)),
332
+ transforms.ToTensor(),
333
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
334
+ ])
335
+
336
+ image_pil = Image.fromarray(image)
337
+ input_images = transform_image(image_pil).unsqueeze(0).to(device)
338
+
339
+ with torch.no_grad():
340
+ preds = birefnet_model(input_images)[-1].sigmoid().cpu()
341
+ pred = preds[0].squeeze()
342
+
343
+ pred_pil = transforms.ToPILImage()(pred)
344
+
345
+ scale_ratio = 1024 / max(image_pil.size)
346
+ scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
347
+
348
+ return np.array(pred_pil.resize(scaled_size))
349
+ except Exception as e:
350
+ logger.error(f"Error in BiRefNet background removal: {e}")
351
+ raise
352
+
353
+ def exclude_paper_area(mask: np.ndarray, paper_contour: np.ndarray, expansion_factor: float = 1.1) -> np.ndarray:
354
+ """
355
+ Remove paper area from the mask to focus only on objects
356
+ """
357
+ # Create paper mask with slight expansion to ensure complete removal
358
+ paper_mask = np.zeros(mask.shape[:2], dtype=np.uint8)
359
+
360
+ # Expand paper contour slightly
361
+ epsilon = expansion_factor * cv2.arcLength(paper_contour, True)
362
+ expanded_contour = cv2.approxPolyDP(paper_contour, epsilon, True)
363
+
364
+ cv2.fillPoly(paper_mask, [expanded_contour], 255)
365
+
366
+ # Invert paper mask and apply to object mask
367
+ paper_mask_inv = cv2.bitwise_not(paper_mask)
368
+ result_mask = cv2.bitwise_and(mask, paper_mask_inv)
369
+
370
+ return result_mask
371
+
372
+ def resample_contour(contour, edge_radius_px: int = 0):
373
+ """Resample contour with radius-aware smoothing and periodic handling."""
374
+ logger.info(f"Starting resample_contour with contour of shape {contour.shape}")
375
+
376
+ num_points = 1500
377
+ sigma = max(2, int(edge_radius_px) // 4)
378
+
379
+ if len(contour) < 4:
380
+ error_msg = f"Contour must have at least 4 points, but has {len(contour)} points."
381
+ logger.error(error_msg)
382
+ raise ValueError(error_msg)
383
+
384
+ try:
385
+ contour = contour[:, 0, :]
386
+ logger.debug(f"Reshaped contour to shape {contour.shape}")
387
+
388
+ if not np.array_equal(contour[0], contour[-1]):
389
+ contour = np.vstack([contour, contour[0]])
390
+
391
+ tck, u = splprep(contour.T, u=None, s=0, per=True)
392
+
393
+ u_new = np.linspace(u.min(), u.max(), num_points)
394
+ x_new, y_new = splev(u_new, tck, der=0)
395
+
396
+ if sigma > 0:
397
+ x_new = gaussian_filter1d(x_new, sigma=sigma, mode='wrap')
398
+ y_new = gaussian_filter1d(y_new, sigma=sigma, mode='wrap')
399
+
400
+ x_new[-1] = x_new[0]
401
+ y_new[-1] = y_new[0]
402
+
403
+ result = np.array([x_new, y_new]).T
404
+ logger.info(f"Completed resample_contour with result shape {result.shape}")
405
+ return result
406
+
407
+ except Exception as e:
408
+ logger.error(f"Error in resample_contour: {e}")
409
+ raise
410
+
411
+ def save_dxf_spline(inflated_contours, scaling_factor, height, finger_clearance=False):
412
+ """Save contours as DXF splines with optional finger cuts"""
413
+ doc = ezdxf.new(units=ezdxf.units.MM)
414
+ doc.header["$INSUNITS"] = ezdxf.units.MM
415
+ msp = doc.modelspace()
416
+ final_polygons_inch = []
417
+ finger_centers = []
418
+ original_polygons = []
419
+
420
+ # Scale correction factor
421
+ scale_correction = 1.079
422
+
423
+ for contour in inflated_contours:
424
+ try:
425
+ resampled_contour = resample_contour(contour)
426
+
427
+ points_inch = [(x * scaling_factor, (height - y) * scaling_factor)
428
+ for x, y in resampled_contour]
429
+
430
+ if len(points_inch) < 3:
431
+ continue
432
+
433
+ tool_polygon = build_tool_polygon(points_inch)
434
+ original_polygons.append(tool_polygon)
435
+
436
+ if finger_clearance:
437
+ try:
438
+ tool_polygon, center = place_finger_cut_adjusted(
439
+ tool_polygon, points_inch, finger_centers, final_polygons_inch
440
+ )
441
+ except FingerCutOverlapError:
442
+ tool_polygon = original_polygons[-1]
443
+
444
+ exterior_coords = polygon_to_exterior_coords(tool_polygon)
445
+ if len(exterior_coords) < 3:
446
+ continue
447
+
448
+ # Apply scale correction
449
+ corrected_coords = [(x * scale_correction, y * scale_correction) for x, y in exterior_coords]
450
+
451
+ msp.add_spline(corrected_coords, degree=3, dxfattribs={"layer": "TOOLS"})
452
+ final_polygons_inch.append(tool_polygon)
453
+
454
+ except ValueError as e:
455
+ logger.warning(f"Skipping contour: {e}")
456
+
457
+ dxf_filepath = os.path.join("./outputs", "out.dxf")
458
+ doc.saveas(dxf_filepath)
459
+ return dxf_filepath, final_polygons_inch, original_polygons
460
+
461
+ def build_tool_polygon(points_inch):
462
+ """Build a polygon from inch-converted points"""
463
+ return Polygon(points_inch)
464
+
465
+ def polygon_to_exterior_coords(poly):
466
+ """Extract exterior coordinates from polygon"""
467
+ logger.info(f"Starting polygon_to_exterior_coords with input geometry type: {poly.geom_type}")
468
+
469
+ try:
470
+ if poly.geom_type == "GeometryCollection" or poly.geom_type == "MultiPolygon":
471
+ logger.debug(f"Performing unary_union on {poly.geom_type}")
472
+ unified = unary_union(poly)
473
+ if unified.is_empty:
474
+ logger.warning("unary_union produced an empty geometry; returning empty list")
475
+ return []
476
+
477
+ if unified.geom_type == "GeometryCollection" or unified.geom_type == "MultiPolygon":
478
+ largest = None
479
+ max_area = 0.0
480
+ for g in getattr(unified, "geoms", []):
481
+ if hasattr(g, "area") and g.area > max_area and hasattr(g, "exterior"):
482
+ max_area = g.area
483
+ largest = g
484
+ if largest is None:
485
+ logger.warning("No valid Polygon found in unified geometry; returning empty list")
486
+ return []
487
+ poly = largest
488
+ else:
489
+ poly = unified
490
+
491
+ if not hasattr(poly, "exterior") or poly.exterior is None:
492
+ logger.warning("Input geometry has no exterior ring; returning empty list")
493
+ return []
494
+
495
+ raw_coords = list(poly.exterior.coords)
496
+ total = len(raw_coords)
497
+ logger.info(f"Extracted {total} raw exterior coordinates")
498
+
499
+ if total == 0:
500
+ return []
501
+
502
+ # Subsample coordinates to at most 100 points
503
+ max_pts = 100
504
+ if total > max_pts:
505
+ step = total // max_pts
506
+ sampled = [raw_coords[i] for i in range(0, total, step)]
507
+ if sampled[-1] != raw_coords[-1]:
508
+ sampled.append(raw_coords[-1])
509
+ logger.info(f"Downsampled perimeter from {total} to {len(sampled)} points")
510
+ return sampled
511
+ else:
512
+ return raw_coords
513
+
514
+ except Exception as e:
515
+ logger.error(f"Error in polygon_to_exterior_coords: {e}")
516
+ return []
517
+
518
+ def place_finger_cut_adjusted(
519
+ tool_polygon: Polygon,
520
+ points_inch: list,
521
+ existing_centers: list,
522
+ all_polygons: list,
523
+ circle_diameter: float = 25.4,
524
+ min_gap: float = 0.5,
525
+ max_attempts: int = 100
526
+ ) -> Tuple[Polygon, tuple]:
527
+ """Place finger cuts with collision avoidance"""
528
+ logger.info(f"Starting place_finger_cut_adjusted with {len(points_inch)} input points")
529
+
530
+ def fallback_solution():
531
+ logger.warning("Using fallback approach for finger cut placement")
532
+ fallback_center = points_inch[len(points_inch) // 2]
533
+ r = circle_diameter / 2.0
534
+ fallback_circle = Point(fallback_center).buffer(r, resolution=32)
535
+ try:
536
+ union_poly = tool_polygon.union(fallback_circle)
537
+ except Exception as e:
538
+ logger.warning(f"Fallback union failed ({e}); trying buffer-union fallback")
539
+ union_poly = tool_polygon.buffer(0).union(fallback_circle.buffer(0))
540
+
541
+ existing_centers.append(fallback_center)
542
+ logger.info(f"Fallback finger cut placed at {fallback_center}")
543
+ return union_poly, fallback_center
544
+
545
+ r = circle_diameter / 2.0
546
+ needed_center_dist = circle_diameter + min_gap
547
+
548
+ raw_perimeter = polygon_to_exterior_coords(tool_polygon)
549
+ if not raw_perimeter:
550
+ logger.warning("No valid exterior coords found; using fallback immediately")
551
+ return fallback_solution()
552
+
553
+ if len(raw_perimeter) > 100:
554
+ step = len(raw_perimeter) // 100
555
+ perimeter_coords = raw_perimeter[::step]
556
+ logger.info(f"Subsampled perimeter from {len(raw_perimeter)} to {len(perimeter_coords)} points")
557
+ else:
558
+ perimeter_coords = raw_perimeter[:]
559
+
560
+ indices = list(range(len(perimeter_coords)))
561
+ np.random.shuffle(indices)
562
+ logger.debug(f"Shuffled perimeter indices for candidate order")
563
+
564
+ start_time = time.time()
565
+ timeout_secs = 5.0
566
+
567
+ attempts = 0
568
+ try:
569
+ while attempts < max_attempts:
570
+ if time.time() - start_time > timeout_secs - 0.1:
571
+ logger.warning(f"Approaching timeout after {attempts} attempts")
572
+ return fallback_solution()
573
+
574
+ for idx in indices:
575
+ if time.time() - start_time > timeout_secs - 0.05:
576
+ logger.warning("Timeout during candidate-point loop")
577
+ return fallback_solution()
578
+
579
+ cx, cy = perimeter_coords[idx]
580
+ for dx, dy in [(0, 0), (-min_gap/2, 0), (min_gap/2, 0), (0, -min_gap/2), (0, min_gap/2)]:
581
+ candidate_center = (cx + dx, cy + dy)
582
+
583
+ # Check distance to existing finger centers
584
+ too_close_finger = any(
585
+ np.hypot(candidate_center[0] - ex, candidate_center[1] - ey)
586
+ < needed_center_dist
587
+ for (ex, ey) in existing_centers
588
+ )
589
+ if too_close_finger:
590
+ continue
591
+
592
+ # Build candidate circle
593
+ candidate_circle = Point(candidate_center).buffer(r, resolution=32)
594
+
595
+ # Must overlap ≥30% with this polygon
596
+ try:
597
+ inter_area = tool_polygon.intersection(candidate_circle).area
598
+ except Exception:
599
+ continue
600
+
601
+ if inter_area < 0.3 * candidate_circle.area:
602
+ continue
603
+
604
+ # Must not intersect other polygons
605
+ invalid = False
606
+ for other_poly in all_polygons:
607
+ if other_poly.equals(tool_polygon):
608
+ continue
609
+ if other_poly.buffer(min_gap).intersects(candidate_circle) or \
610
+ other_poly.buffer(min_gap).touches(candidate_circle):
611
+ invalid = True
612
+ break
613
+ if invalid:
614
+ continue
615
+
616
+ # Union and return
617
+ try:
618
+ union_poly = tool_polygon.union(candidate_circle)
619
+ if union_poly.geom_type == "MultiPolygon" and len(union_poly.geoms) > 1:
620
+ continue
621
+ if union_poly.equals(tool_polygon):
622
+ continue
623
+ except Exception:
624
+ continue
625
+
626
+ existing_centers.append(candidate_center)
627
+ logger.info(f"Finger cut placed successfully at {candidate_center} after {attempts} attempts")
628
+ return union_poly, candidate_center
629
+
630
+ attempts += 1
631
+ if attempts >= (max_attempts // 2) and (time.time() - start_time) > timeout_secs * 0.8:
632
+ logger.warning(f"Approaching timeout (attempt {attempts})")
633
+ return fallback_solution()
634
+
635
+ logger.warning(f"No valid spot after {max_attempts} attempts, using fallback")
636
+ return fallback_solution()
637
+
638
+ except Exception as e:
639
+ logger.error(f"Error in place_finger_cut_adjusted: {e}")
640
+ return fallback_solution()
641
+
642
+ def extract_outlines(binary_image: np.ndarray) -> Tuple[np.ndarray, list]:
643
+ """Extract outlines from binary image"""
644
+ contours, _ = cv2.findContours(
645
+ binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
646
+ )
647
+ outline_image = np.full_like(binary_image, 255)
648
+ return outline_image, contours
649
+
650
+ def round_edges(mask: np.ndarray, radius_mm: float, scaling_factor: float) -> np.ndarray:
651
+ """Round mask edges using contour smoothing"""
652
+ if radius_mm <= 0 or scaling_factor <= 0:
653
+ return mask
654
+
655
+ radius_px = max(1, int(radius_mm / scaling_factor))
656
+
657
+ if np.count_nonzero(mask) < 500:
658
+ return cv2.dilate(cv2.erode(mask, np.ones((3,3))), np.ones((3,3)))
659
+
660
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
661
+ contours = [c for c in contours if cv2.contourArea(c) > 100]
662
+ smoothed_contours = []
663
+
664
+ for contour in contours:
665
+ try:
666
+ resampled = resample_contour(contour, radius_px)
667
+ resampled = resampled.astype(np.int32).reshape((-1, 1, 2))
668
+ smoothed_contours.append(resampled)
669
+ except Exception as e:
670
+ logger.warning(f"Error smoothing contour: {e}")
671
+ smoothed_contours.append(contour)
672
+
673
+ rounded = np.zeros_like(mask)
674
+ cv2.drawContours(rounded, smoothed_contours, -1, 255, thickness=cv2.FILLED)
675
+
676
+ return rounded
677
+
678
+ def cleanup_memory():
679
+ """Clean up memory after processing"""
680
+ if torch.cuda.is_available():
681
+ torch.cuda.empty_cache()
682
+ gc.collect()
683
+ logger.info("Memory cleanup completed")
684
+
685
+ def cleanup_models():
686
+ """Unload models to free memory"""
687
+ global paper_detector_global, u2net_global, birefnet
688
+ if paper_detector_global is not None:
689
+ del paper_detector_global
690
+ paper_detector_global = None
691
+ if u2net_global is not None:
692
+ del u2net_global
693
+ u2net_global = None
694
+ if birefnet is not None:
695
+ del birefnet
696
+ birefnet = None
697
+ cleanup_memory()
698
+
699
+ def make_square(img: np.ndarray):
700
+ """Make the image square by padding"""
701
+ height, width = img.shape[:2]
702
+ max_dim = max(height, width)
703
+
704
+ pad_height = (max_dim - height) // 2
705
+ pad_width = (max_dim - width) // 2
706
+
707
+ pad_height_extra = max_dim - height - 2 * pad_height
708
+ pad_width_extra = max_dim - width - 2 * pad_width
709
+
710
+ if len(img.shape) == 3:
711
+ padded = np.pad(
712
+ img,
713
+ (
714
+ (pad_height, pad_height + pad_height_extra),
715
+ (pad_width, pad_width + pad_width_extra),
716
+ (0, 0),
717
+ ),
718
+ mode="edge",
719
+ )
720
+ else:
721
+ padded = np.pad(
722
+ img,
723
+ (
724
+ (pad_height, pad_height + pad_height_extra),
725
+ (pad_width, pad_width + pad_width_extra),
726
+ ),
727
+ mode="edge",
728
+ )
729
+
730
+ return padded
731
+
732
+ def predict_with_paper(image, paper_size, offset, offset_unit, edge_radius, finger_clearance=False):
733
+ """Main prediction function using paper as reference"""
734
+
735
+ if offset_unit == "inches":
736
+ offset *= 25.4
737
+
738
+ if edge_radius is None or edge_radius == 0:
739
+ edge_radius = 0.0001
740
+
741
+ if offset < 0:
742
+ raise gr.Error("Offset Value Can't be negative")
743
+
744
+ try:
745
+ # Detect paper bounds and calculate scaling factor
746
+ paper_contour, scaling_factor = detect_paper_bounds(image, paper_size)
747
+ logger.info(f"Paper detected with scaling factor: {scaling_factor:.4f} mm/px")
748
+
749
+ except PaperNotDetectedError as e:
750
+ return (
751
+ None, None, None, None,
752
+ f"Error: {str(e)}"
753
+ )
754
+ except Exception as e:
755
+ raise gr.Error(f"Error processing image: {str(e)}")
756
+
757
+ try:
758
+ # Remove background from main objects
759
+ orig_size = image.shape[:2]
760
+ objects_mask = remove_bg(image)
761
+ processed_size = objects_mask.shape[:2]
762
+
763
+ # Resize mask to match original image
764
+ objects_mask = cv2.resize(objects_mask, (image.shape[1], image.shape[0]))
765
+
766
+ # Remove paper area from mask to focus only on objects
767
+ objects_mask = exclude_paper_area(objects_mask, paper_contour)
768
+
769
+ # Validate single object
770
+ validate_single_object(objects_mask, paper_contour)
771
+
772
+ except (MultipleObjectsError, NoObjectDetectedError) as e:
773
+ return (
774
+ None, None, None, None,
775
+ f"Error: {str(e)}"
776
+ )
777
+ except Exception as e:
778
+ raise gr.Error(f"Error in object detection: {str(e)}")
779
+
780
+ # Apply edge rounding if specified
781
+ if edge_radius > 0:
782
+ rounded_mask = round_edges(objects_mask, edge_radius, scaling_factor)
783
+ else:
784
+ rounded_mask = objects_mask.copy()
785
+
786
+ # Apply dilation for offset
787
+ if offset > 0:
788
+ offset_pixels = (float(offset) / scaling_factor) * 2 + 1 if scaling_factor else 1
789
+ kernel = np.ones((int(offset_pixels), int(offset_pixels)), np.uint8)
790
+ dilated_mask = cv2.dilate(rounded_mask, kernel)
791
+ else:
792
+ dilated_mask = rounded_mask.copy()
793
+
794
+ # Save original dilated mask for output
795
+ Image.fromarray(dilated_mask).save("./outputs/scaled_mask_original.jpg")
796
+ dilated_mask_orig = dilated_mask.copy()
797
+
798
+ # Extract contours
799
+ outlines, contours = extract_outlines(dilated_mask)
800
+
801
+ try:
802
+ # Generate DXF
803
+ dxf, finger_polygons, original_polygons = save_dxf_spline(
804
+ contours,
805
+ scaling_factor,
806
+ processed_size[0],
807
+ finger_clearance=(finger_clearance == "On")
808
+ )
809
+ except FingerCutOverlapError as e:
810
+ raise gr.Error(str(e))
811
+
812
+ # Create annotated image
813
+ shrunked_img_contours = image.copy()
814
+
815
+ if finger_clearance == "On":
816
+ outlines = np.full_like(dilated_mask, 255)
817
+ for poly in finger_polygons:
818
+ try:
819
+ coords = np.array([
820
+ (int(x / scaling_factor), int(processed_size[0] - y / scaling_factor))
821
+ for x, y in poly.exterior.coords
822
+ ], np.int32).reshape((-1, 1, 2))
823
+
824
+ cv2.drawContours(shrunked_img_contours, [coords], -1, (0, 255, 0), thickness=2)
825
+ cv2.drawContours(outlines, [coords], -1, 0, thickness=2)
826
+ except Exception as e:
827
+ logger.warning(f"Failed to draw finger cut: {e}")
828
+ continue
829
+ else:
830
+ outlines = np.full_like(dilated_mask, 255)
831
+ cv2.drawContours(shrunked_img_contours, contours, -1, (0, 255, 0), thickness=2)
832
+ cv2.drawContours(outlines, contours, -1, 0, thickness=2)
833
+
834
+ # Draw paper bounds on annotated image
835
+ cv2.drawContours(shrunked_img_contours, [paper_contour], -1, (255, 0, 0), thickness=3)
836
+
837
+ # Add paper size text
838
+ paper_text = f"Paper: {paper_size}"
839
+ cv2.putText(shrunked_img_contours, paper_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
840
+
841
+ cleanup_models()
842
+
843
+ return (
844
+ shrunked_img_contours,
845
+ outlines,
846
+ dxf,
847
+ dilated_mask_orig,
848
+ f"Scale: {scaling_factor:.4f} mm/px | Paper: {paper_size}"
849
+ )
850
+
851
+ def predict_full_paper(image, paper_size, enable_fillet, fillet_value_mm, enable_finger_cut, selected_outputs):
852
+ """
853
+ Full prediction function with paper reference and flexible outputs
854
+ Returns DXF + conditionally selected additional outputs
855
+ """
856
+ radius = fillet_value_mm if enable_fillet == "On" else 0
857
+ finger_flag = "On" if enable_finger_cut == "On" else "Off"
858
+
859
+ # Always get all outputs from predict_with_paper
860
+ ann, outlines, dxf_path, mask, scale_info = predict_with_paper(
861
+ image,
862
+ paper_size,
863
+ offset=0, # No offset for now, can be added as parameter later
864
+ offset_unit="mm",
865
+ edge_radius=radius,
866
+ finger_clearance=finger_flag,
867
+ )
868
+
869
+ # Return based on selected outputs
870
+ return (
871
+ dxf_path, # Always return DXF
872
+ ann if "Annotated Image" in selected_outputs else None,
873
+ outlines if "Outlines" in selected_outputs else None,
874
+ mask if "Mask" in selected_outputs else None,
875
+ scale_info # Always return scaling info
876
+ )
877
+
878
+ # Gradio Interface
879
+ if __name__ == "__main__":
880
+ os.makedirs("./outputs", exist_ok=True)
881
+
882
+ with gr.Blocks(title="Paper-Based DXF Generator", theme=gr.themes.Soft()) as demo:
883
+ gr.Markdown("""
884
+ # Paper-Based DXF Generator
885
+
886
+ Upload an image with a single object placed on paper (A4, A3, or US Letter).
887
+ The paper serves as a size reference for accurate DXF generation.
888
+
889
+ **Instructions:**
890
+ 1. Place a single object on paper
891
+ 2. Select the correct paper size
892
+ 3. Configure options as needed
893
+ 4. Click Submit to generate DXF
894
+ """)
895
+
896
+ with gr.Row():
897
+ with gr.Column():
898
+ input_image = gr.Image(
899
+ label="Input Image (Object on Paper)",
900
+ type="numpy",
901
+ height=400
902
+ )
903
+
904
+ paper_size = gr.Radio(
905
+ choices=["A4", "A3", "US Letter"],
906
+ value="A4",
907
+ label="Paper Size",
908
+ info="Select the paper size used in your image"
909
+ )
910
+
911
+ with gr.Group():
912
+ gr.Markdown("### Edge Rounding")
913
+ enable_fillet = gr.Radio(
914
+ choices=["On", "Off"],
915
+ value="Off",
916
+ label="Enable Edge Rounding",
917
+ interactive=True
918
+ )
919
+
920
+ fillet_value_mm = gr.Slider(
921
+ minimum=0,
922
+ maximum=20,
923
+ step=1,
924
+ value=5,
925
+ label="Edge Radius (mm)",
926
+ visible=False,
927
+ interactive=True
928
+ )
929
+
930
+ with gr.Group():
931
+ gr.Markdown("### Finger Cuts")
932
+ enable_finger_cut = gr.Radio(
933
+ choices=["On", "Off"],
934
+ value="Off",
935
+ label="Enable Finger Cuts",
936
+ info="Add circular cuts for easier handling"
937
+ )
938
+
939
+ output_options = gr.CheckboxGroup(
940
+ choices=["Annotated Image", "Outlines", "Mask"],
941
+ value=[],
942
+ label="Additional Outputs",
943
+ info="DXF is always included"
944
+ )
945
+
946
+ submit_btn = gr.Button("Generate DXF", variant="primary", size="lg")
947
+
948
+ with gr.Column():
949
+ with gr.Group():
950
+ gr.Markdown("### Generated Files")
951
+ dxf_file = gr.File(label="DXF File", file_types=[".dxf"])
952
+ scale_info = gr.Textbox(label="Scaling Information", interactive=False)
953
+
954
+ with gr.Group():
955
+ gr.Markdown("### Preview Images")
956
+ output_image = gr.Image(label="Annotated Image", visible=False)
957
+ outlines_image = gr.Image(label="Outlines", visible=False)
958
+ mask_image = gr.Image(label="Mask", visible=False)
959
+
960
+ # Dynamic visibility updates
961
+ def toggle_fillet(choice):
962
+ return gr.update(visible=(choice == "On"))
963
+
964
+ def update_outputs_visibility(selected):
965
+ return [
966
+ gr.update(visible="Annotated Image" in selected),
967
+ gr.update(visible="Outlines" in selected),
968
+ gr.update(visible="Mask" in selected)
969
+ ]
970
+
971
+ # Event handlers
972
+ enable_fillet.change(
973
+ fn=toggle_fillet,
974
+ inputs=enable_fillet,
975
+ outputs=fillet_value_mm
976
+ )
977
+
978
+ output_options.change(
979
+ fn=update_outputs_visibility,
980
+ inputs=output_options,
981
+ outputs=[output_image, outlines_image, mask_image]
982
+ )
983
+
984
+ submit_btn.click(
985
+ fn=predict_full_paper,
986
+ inputs=[
987
+ input_image,
988
+ paper_size,
989
+ enable_fillet,
990
+ fillet_value_mm,
991
+ enable_finger_cut,
992
+ output_options
993
+ ],
994
+ outputs=[dxf_file, output_image, outlines_image, mask_image, scale_info]
995
+ )
996
+
997
+ # Example gallery
998
+ with gr.Row():
999
+ gr.Markdown("""
1000
+ ### Tips for Best Results:
1001
+ - Ensure good lighting and clear paper edges
1002
+ - Place object completely on the paper
1003
+ - Avoid shadows that might interfere with detection
1004
+ - Use high contrast between object and paper
1005
+ """)
1006
+
1007
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ transformers==4.48.3
3
+ ultralytics==8.3.9
4
+ pydantic==2.10.6
5
+ ezdxf==1.3.5
6
+ gradio==5.15.0
7
+ kornia==0.8.0
8
+ timm==1.0.14
9
+ einops==0.8.1
10
+ torchvision==0.20.1
11
+ torch==2.5.1
12
+ torchaudio==2.5.1
13
+ shapely
scalingtestupdated.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import argparse
5
+ from typing import Union
6
+ from matplotlib import pyplot as plt
7
+
8
+ class ScalingSquareDetector:
9
+ def __init__(self, feature_detector="ORB", debug=False):
10
+ """
11
+ Initialize the detector with the desired feature matching algorithm.
12
+ :param feature_detector: "ORB" or "SIFT" (default is "ORB").
13
+ :param debug: If True, saves intermediate images for debugging.
14
+ """
15
+ self.feature_detector = feature_detector
16
+ self.debug = debug
17
+ self.detector = self._initialize_detector()
18
+
19
+ def _initialize_detector(self):
20
+ """
21
+ Initialize the chosen feature detector.
22
+ :return: OpenCV detector object.
23
+ """
24
+ if self.feature_detector.upper() == "SIFT":
25
+ return cv2.SIFT_create()
26
+ elif self.feature_detector.upper() == "ORB":
27
+ return cv2.ORB_create()
28
+ else:
29
+ raise ValueError("Invalid feature detector. Choose 'ORB' or 'SIFT'.")
30
+
31
+ def find_scaling_square(
32
+ self, target_image, known_size_mm, roi_margin=30
33
+ ):
34
+ """
35
+ Detect the scaling square in the target image based on the reference image.
36
+ :param reference_image_path: Path to the reference image of the square.
37
+ :param target_image_path: Path to the target image containing the square.
38
+ :param known_size_mm: Physical size of the square in millimeters.
39
+ :param roi_margin: Margin to expand the ROI around the detected square (in pixels).
40
+ :return: Scaling factor (mm per pixel).
41
+ """
42
+ contours, _ = cv2.findContours(
43
+ target_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
44
+ )
45
+
46
+ if not contours:
47
+ raise ValueError("No contours found in the cropped ROI.")
48
+
49
+ # # Select the largest square-like contour
50
+ print(f"No of contours: {len(contours)}")
51
+ largest_square = None
52
+ # largest_square_area = 0
53
+ # for contour in contours:
54
+ # x_c, y_c, w_c, h_c = cv2.boundingRect(contour)
55
+ # aspect_ratio = w_c / float(h_c)
56
+ # if 0.9 <= aspect_ratio <= 1.1:
57
+ # peri = cv2.arcLength(contour, True)
58
+ # approx = cv2.approxPolyDP(contour, 0.02 * peri, True)
59
+ # if len(approx) == 4:
60
+ # area = cv2.contourArea(contour)
61
+ # if area > largest_square_area:
62
+ # largest_square = contour
63
+ # largest_square_area = area
64
+
65
+ for contour in contours:
66
+ largest_square = contour
67
+
68
+ # if largest_square is None:
69
+ # raise ValueError("No square-like contour found in the ROI.")
70
+
71
+ # Draw the largest contour on the original image
72
+ target_image_color = cv2.cvtColor(target_image, cv2.COLOR_GRAY2BGR)
73
+ cv2.drawContours(
74
+ target_image_color, largest_square, -1, (255, 0, 0), 3
75
+ )
76
+
77
+ # if self.debug:
78
+ cv2.imwrite("largest_contour.jpg", target_image_color)
79
+
80
+ # Calculate the bounding rectangle of the largest contour
81
+ x, y, w, h = cv2.boundingRect(largest_square)
82
+ square_width_px = w
83
+ square_height_px = h
84
+ print(f"Reference object size: {known_size_mm} mm")
85
+ print(f"width: {square_width_px} px")
86
+ print(f"height: {square_height_px} px")
87
+
88
+ # Calculate the scaling factor
89
+ avg_square_size_px = (square_width_px + square_height_px) / 2
90
+ print(f"avg square size: {avg_square_size_px} px")
91
+ scaling_factor = known_size_mm / avg_square_size_px # mm per pixel
92
+ print(f"scaling factor: {scaling_factor} mm per pixel")
93
+
94
+ return scaling_factor #, square_height_px, square_width_px, roi_binary
95
+
96
+ def draw_debug_images(self, output_folder):
97
+ """
98
+ Save debug images if enabled.
99
+ :param output_folder: Directory to save debug images.
100
+ """
101
+ if self.debug:
102
+ if not os.path.exists(output_folder):
103
+ os.makedirs(output_folder)
104
+ debug_images = ["largest_contour.jpg"]
105
+ for img_name in debug_images:
106
+ if os.path.exists(img_name):
107
+ os.rename(img_name, os.path.join(output_folder, img_name))
108
+
109
+
110
+ def calculate_scaling_factor(
111
+ target_image,
112
+ reference_obj_size_mm,
113
+ feature_detector="ORB",
114
+ debug=False,
115
+ roi_margin=30,
116
+ ):
117
+ # Initialize detector
118
+ detector = ScalingSquareDetector(feature_detector=feature_detector, debug=debug)
119
+
120
+ # Find scaling square and calculate scaling factor
121
+ scaling_factor = detector.find_scaling_square(
122
+ target_image=target_image,
123
+ known_size_mm=reference_obj_size_mm,
124
+ roi_margin=roi_margin,
125
+ )
126
+
127
+ # Save debug images
128
+ if debug:
129
+ detector.draw_debug_images("debug_outputs")
130
+
131
+ return scaling_factor
132
+
133
+
134
+ # Example usage:
135
+ if __name__ == "__main__":
136
+ import os
137
+ from PIL import Image
138
+ from ultralytics import YOLO
139
+ from app import yolo_detect, shrink_bbox
140
+ from ultralytics.utils.plotting import save_one_box
141
+
142
+ for idx, file in enumerate(os.listdir("./sample_images")):
143
+ img = np.array(Image.open(os.path.join("./sample_images", file)))
144
+ img = yolo_detect(img, ['box'])
145
+ model = YOLO("./best.pt")
146
+ res = model.predict(img, conf=0.6)
147
+
148
+ box_img = save_one_box(res[0].cpu().boxes.xyxy, im=res[0].orig_img, save=False)
149
+ # img = shrink_bbox(box_img, 1.20)
150
+ cv2.imwrite(f"./outputs/{idx}_{file}", box_img)
151
+
152
+ print("File: ",f"./outputs/{idx}_{file}")
153
+ try:
154
+
155
+ scaling_factor = calculate_scaling_factor(
156
+ target_image=box_img,
157
+ known_square_size_mm=20,
158
+ feature_detector="ORB",
159
+ debug=False,
160
+ roi_margin=90,
161
+ )
162
+ # cv2.imwrite(f"./outputs/{idx}_binary_{file}", roi_binary)
163
+
164
+ # Square size in mm
165
+ # square_size_mm = 12.7
166
+
167
+ # # Compute the calculated scaling factors and compare
168
+ # calculated_scaling_factor = square_size_mm / height_px
169
+ # discrepancy = abs(calculated_scaling_factor - scaling_factor)
170
+ # import pprint
171
+ # pprint.pprint({
172
+ # "height_px": height_px,
173
+ # "width_px": width_px,
174
+ # "given_scaling_factor": scaling_factor,
175
+ # "calculated_scaling_factor": calculated_scaling_factor,
176
+ # "discrepancy": discrepancy,
177
+ # })
178
+
179
+
180
+ print(f"Scaling Factor (mm per pixel): {scaling_factor:.6f}")
181
+ except Exception as e:
182
+ from traceback import print_exc
183
+ print(print_exc())
184
+ print(f"Error: {e}")
u2netp.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7567cde013fb64813973ce6e1ecc25a80c05c3ca7adbc5a54f3c3d90991b854
3
+ size 4683258
u2netp.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
+
36
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
+
39
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
+
42
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
+
45
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
+
48
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+
53
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
+
55
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
+
62
+ def forward(self,x):
63
+
64
+ hx = x
65
+ hxin = self.rebnconvin(hx)
66
+
67
+ hx1 = self.rebnconv1(hxin)
68
+ hx = self.pool1(hx1)
69
+
70
+ hx2 = self.rebnconv2(hx)
71
+ hx = self.pool2(hx2)
72
+
73
+ hx3 = self.rebnconv3(hx)
74
+ hx = self.pool3(hx3)
75
+
76
+ hx4 = self.rebnconv4(hx)
77
+ hx = self.pool4(hx4)
78
+
79
+ hx5 = self.rebnconv5(hx)
80
+ hx = self.pool5(hx5)
81
+
82
+ hx6 = self.rebnconv6(hx)
83
+
84
+ hx7 = self.rebnconv7(hx6)
85
+
86
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
+ hx6dup = _upsample_like(hx6d,hx5)
88
+
89
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
+ hx5dup = _upsample_like(hx5d,hx4)
91
+
92
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
+ hx4dup = _upsample_like(hx4d,hx3)
94
+
95
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
+ hx3dup = _upsample_like(hx3d,hx2)
97
+
98
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
+ hx2dup = _upsample_like(hx2d,hx1)
100
+
101
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
+
103
+ return hx1d + hxin
104
+
105
+ ### RSU-6 ###
106
+ class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
+
108
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
+ super(RSU6,self).__init__()
110
+
111
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
+
113
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
+
116
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
+
119
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
+
129
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
+
135
+ def forward(self,x):
136
+
137
+ hx = x
138
+
139
+ hxin = self.rebnconvin(hx)
140
+
141
+ hx1 = self.rebnconv1(hxin)
142
+ hx = self.pool1(hx1)
143
+
144
+ hx2 = self.rebnconv2(hx)
145
+ hx = self.pool2(hx2)
146
+
147
+ hx3 = self.rebnconv3(hx)
148
+ hx = self.pool3(hx3)
149
+
150
+ hx4 = self.rebnconv4(hx)
151
+ hx = self.pool4(hx4)
152
+
153
+ hx5 = self.rebnconv5(hx)
154
+
155
+ hx6 = self.rebnconv6(hx5)
156
+
157
+
158
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
+ hx5dup = _upsample_like(hx5d,hx4)
160
+
161
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
+ hx4dup = _upsample_like(hx4d,hx3)
163
+
164
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
+ hx3dup = _upsample_like(hx3d,hx2)
166
+
167
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
+ hx2dup = _upsample_like(hx2d,hx1)
169
+
170
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
+
172
+ return hx1d + hxin
173
+
174
+ ### RSU-5 ###
175
+ class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
+
177
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
+ super(RSU5,self).__init__()
179
+
180
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
+
182
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
+
188
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+
193
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
+
195
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
+
200
+ def forward(self,x):
201
+
202
+ hx = x
203
+
204
+ hxin = self.rebnconvin(hx)
205
+
206
+ hx1 = self.rebnconv1(hxin)
207
+ hx = self.pool1(hx1)
208
+
209
+ hx2 = self.rebnconv2(hx)
210
+ hx = self.pool2(hx2)
211
+
212
+ hx3 = self.rebnconv3(hx)
213
+ hx = self.pool3(hx3)
214
+
215
+ hx4 = self.rebnconv4(hx)
216
+
217
+ hx5 = self.rebnconv5(hx4)
218
+
219
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
+ hx4dup = _upsample_like(hx4d,hx3)
221
+
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
+ hx3dup = _upsample_like(hx3d,hx2)
224
+
225
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
+ hx2dup = _upsample_like(hx2d,hx1)
227
+
228
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
+
230
+ return hx1d + hxin
231
+
232
+ ### RSU-4 ###
233
+ class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
+
235
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
+ super(RSU4,self).__init__()
237
+
238
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
+
240
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
+
243
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
+
246
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
+
248
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
+
250
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
+
254
+ def forward(self,x):
255
+
256
+ hx = x
257
+
258
+ hxin = self.rebnconvin(hx)
259
+
260
+ hx1 = self.rebnconv1(hxin)
261
+ hx = self.pool1(hx1)
262
+
263
+ hx2 = self.rebnconv2(hx)
264
+ hx = self.pool2(hx2)
265
+
266
+ hx3 = self.rebnconv3(hx)
267
+
268
+ hx4 = self.rebnconv4(hx3)
269
+
270
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
+ hx3dup = _upsample_like(hx3d,hx2)
272
+
273
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
+ hx2dup = _upsample_like(hx2d,hx1)
275
+
276
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
+
278
+ return hx1d + hxin
279
+
280
+ ### RSU-4F ###
281
+ class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
+
283
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
+ super(RSU4F,self).__init__()
285
+
286
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
+
288
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
+
292
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
+
294
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
+
298
+ def forward(self,x):
299
+
300
+ hx = x
301
+
302
+ hxin = self.rebnconvin(hx)
303
+
304
+ hx1 = self.rebnconv1(hxin)
305
+ hx2 = self.rebnconv2(hx1)
306
+ hx3 = self.rebnconv3(hx2)
307
+
308
+ hx4 = self.rebnconv4(hx3)
309
+
310
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
+
314
+ return hx1d + hxin
315
+
316
+
317
+ ##### U^2-Net ####
318
+ class U2NET(nn.Module):
319
+
320
+ def __init__(self,in_ch=3,out_ch=1):
321
+ super(U2NET,self).__init__()
322
+
323
+ self.stage1 = RSU7(in_ch,32,64)
324
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
+
326
+ self.stage2 = RSU6(64,32,128)
327
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
+
329
+ self.stage3 = RSU5(128,64,256)
330
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
+
332
+ self.stage4 = RSU4(256,128,512)
333
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
+
335
+ self.stage5 = RSU4F(512,256,512)
336
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
+
338
+ self.stage6 = RSU4F(512,256,512)
339
+
340
+ # decoder
341
+ self.stage5d = RSU4F(1024,256,512)
342
+ self.stage4d = RSU4(1024,128,256)
343
+ self.stage3d = RSU5(512,64,128)
344
+ self.stage2d = RSU6(256,32,64)
345
+ self.stage1d = RSU7(128,16,64)
346
+
347
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
+
354
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
+
356
+ def forward(self,x):
357
+
358
+ hx = x
359
+
360
+ #stage 1
361
+ hx1 = self.stage1(hx)
362
+ hx = self.pool12(hx1)
363
+
364
+ #stage 2
365
+ hx2 = self.stage2(hx)
366
+ hx = self.pool23(hx2)
367
+
368
+ #stage 3
369
+ hx3 = self.stage3(hx)
370
+ hx = self.pool34(hx3)
371
+
372
+ #stage 4
373
+ hx4 = self.stage4(hx)
374
+ hx = self.pool45(hx4)
375
+
376
+ #stage 5
377
+ hx5 = self.stage5(hx)
378
+ hx = self.pool56(hx5)
379
+
380
+ #stage 6
381
+ hx6 = self.stage6(hx)
382
+ hx6up = _upsample_like(hx6,hx5)
383
+
384
+ #-------------------- decoder --------------------
385
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
+ hx5dup = _upsample_like(hx5d,hx4)
387
+
388
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
+ hx4dup = _upsample_like(hx4d,hx3)
390
+
391
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
+ hx3dup = _upsample_like(hx3d,hx2)
393
+
394
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
+ hx2dup = _upsample_like(hx2d,hx1)
396
+
397
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
+
399
+
400
+ #side output
401
+ d1 = self.side1(hx1d)
402
+
403
+ d2 = self.side2(hx2d)
404
+ d2 = _upsample_like(d2,d1)
405
+
406
+ d3 = self.side3(hx3d)
407
+ d3 = _upsample_like(d3,d1)
408
+
409
+ d4 = self.side4(hx4d)
410
+ d4 = _upsample_like(d4,d1)
411
+
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5,d1)
414
+
415
+ d6 = self.side6(hx6)
416
+ d6 = _upsample_like(d6,d1)
417
+
418
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
+
420
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
421
+
422
+ ### U^2-Net small ###
423
+ class U2NETP(nn.Module):
424
+
425
+ def __init__(self,in_ch=3,out_ch=1):
426
+ super(U2NETP,self).__init__()
427
+
428
+ self.stage1 = RSU7(in_ch,16,64)
429
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
+
431
+ self.stage2 = RSU6(64,16,64)
432
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
+
434
+ self.stage3 = RSU5(64,16,64)
435
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
+
437
+ self.stage4 = RSU4(64,16,64)
438
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
+
440
+ self.stage5 = RSU4F(64,16,64)
441
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
+
443
+ self.stage6 = RSU4F(64,16,64)
444
+
445
+ # decoder
446
+ self.stage5d = RSU4F(128,16,64)
447
+ self.stage4d = RSU4(128,16,64)
448
+ self.stage3d = RSU5(128,16,64)
449
+ self.stage2d = RSU6(128,16,64)
450
+ self.stage1d = RSU7(128,16,64)
451
+
452
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
+ self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
+ self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
+ self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
+ self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
+
459
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
+
461
+ def forward(self,x):
462
+
463
+ hx = x
464
+
465
+ #stage 1
466
+ hx1 = self.stage1(hx)
467
+ hx = self.pool12(hx1)
468
+
469
+ #stage 2
470
+ hx2 = self.stage2(hx)
471
+ hx = self.pool23(hx2)
472
+
473
+ #stage 3
474
+ hx3 = self.stage3(hx)
475
+ hx = self.pool34(hx3)
476
+
477
+ #stage 4
478
+ hx4 = self.stage4(hx)
479
+ hx = self.pool45(hx4)
480
+
481
+ #stage 5
482
+ hx5 = self.stage5(hx)
483
+ hx = self.pool56(hx5)
484
+
485
+ #stage 6
486
+ hx6 = self.stage6(hx)
487
+ hx6up = _upsample_like(hx6,hx5)
488
+
489
+ #decoder
490
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
+ hx5dup = _upsample_like(hx5d,hx4)
492
+
493
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
+ hx4dup = _upsample_like(hx4d,hx3)
495
+
496
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
+ hx3dup = _upsample_like(hx3d,hx2)
498
+
499
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
+ hx2dup = _upsample_like(hx2d,hx1)
501
+
502
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
+
504
+
505
+ #side output
506
+ d1 = self.side1(hx1d)
507
+
508
+ d2 = self.side2(hx2d)
509
+ d2 = _upsample_like(d2,d1)
510
+
511
+ d3 = self.side3(hx3d)
512
+ d3 = _upsample_like(d3,d1)
513
+
514
+ d4 = self.side4(hx4d)
515
+ d4 = _upsample_like(d4,d1)
516
+
517
+ d5 = self.side5(hx5d)
518
+ d5 = _upsample_like(d5,d1)
519
+
520
+ d6 = self.side6(hx6)
521
+ d6 = _upsample_like(d6,d1)
522
+
523
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
+
525
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)