[feat] use temp files on SamGeo prediction
Browse files- src/main.py +6 -3
- src/prediction_api/predictor.py +34 -34
src/main.py
CHANGED
|
@@ -92,9 +92,12 @@ def samgeo(request_input: Input):
|
|
| 92 |
bbox=request_body["bbox"],
|
| 93 |
point_coords=request_body["point"]
|
| 94 |
)
|
| 95 |
-
|
| 96 |
-
app_logger.info(f"
|
| 97 |
-
body
|
|
|
|
|
|
|
|
|
|
| 98 |
return JSONResponse(status_code=200, content={"body": json.dumps(body)})
|
| 99 |
except Exception as inference_exception:
|
| 100 |
home_content = subprocess.run("ls -l /home/user", shell=True, universal_newlines=True, stdout=subprocess.PIPE)
|
|
|
|
| 92 |
bbox=request_body["bbox"],
|
| 93 |
point_coords=request_body["point"]
|
| 94 |
)
|
| 95 |
+
duration_run = time.time() - time_start_run
|
| 96 |
+
app_logger.info(f"duration_run:{duration_run}.")
|
| 97 |
+
body = {
|
| 98 |
+
"duration_run": duration_run,
|
| 99 |
+
"output": output
|
| 100 |
+
}
|
| 101 |
return JSONResponse(status_code=200, content={"body": json.dumps(body)})
|
| 102 |
except Exception as inference_exception:
|
| 103 |
home_content = subprocess.run("ls -l /home/user", shell=True, universal_newlines=True, stdout=subprocess.PIPE)
|
src/prediction_api/predictor.py
CHANGED
|
@@ -9,39 +9,39 @@ from src.utilities.type_hints import input_floatlist, input_floatlist2
|
|
| 9 |
def base_predict(
|
| 10 |
bbox: input_floatlist, point_coords: input_floatlist2, point_crs: str = "EPSG:4326", zoom: float = 16, model_name: str = "vit_h", root_folder: str = ROOT
|
| 11 |
) -> str:
|
|
|
|
| 12 |
from samgeo import SamGeo, tms_to_geotiff
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
return out_gdf_json
|
|
|
|
| 9 |
def base_predict(
|
| 10 |
bbox: input_floatlist, point_coords: input_floatlist2, point_crs: str = "EPSG:4326", zoom: float = 16, model_name: str = "vit_h", root_folder: str = ROOT
|
| 11 |
) -> str:
|
| 12 |
+
import tempfile
|
| 13 |
from samgeo import SamGeo, tms_to_geotiff
|
| 14 |
|
| 15 |
+
with tempfile.NamedTemporaryFile(prefix="satellite_", suffix=".tif", dir=root_folder) as image_input_tmp:
|
| 16 |
+
app_logger.info(f"start tms_to_geotiff using bbox:{bbox}, type:{type(bbox)}, download image to {image_input_tmp.name} ...")
|
| 17 |
+
for coord in bbox:
|
| 18 |
+
app_logger.info(f"coord:{coord}, type:{type(coord)}.")
|
| 19 |
+
|
| 20 |
+
# bbox: image input coordinate
|
| 21 |
+
tms_to_geotiff(output=image_input_tmp.name, bbox=bbox, zoom=zoom, source="Satellite", overwrite=True)
|
| 22 |
+
app_logger.info(f"geotiff created, start to initialize samgeo instance (read model {model_name} from {root_folder})...")
|
| 23 |
+
|
| 24 |
+
predictor = SamGeo(
|
| 25 |
+
model_type=model_name,
|
| 26 |
+
checkpoint_dir=root_folder,
|
| 27 |
+
automatic=False,
|
| 28 |
+
sam_kwargs=None,
|
| 29 |
+
)
|
| 30 |
+
app_logger.info(f"initialized samgeo instance, start to use SamGeo.set_image({image_input_tmp.name})...")
|
| 31 |
+
predictor.set_image(image_input_tmp.name)
|
| 32 |
+
|
| 33 |
+
with tempfile.NamedTemporaryFile(prefix="output_", suffix=".tif", dir=root_folder) as image_output_tmp:
|
| 34 |
+
app_logger.info(f"done set_image, start prediction using {image_output_tmp.name} as output...")
|
| 35 |
+
predictor.predict(point_coords, point_labels=len(point_coords), point_crs=point_crs, output=image_output_tmp.name)
|
| 36 |
+
|
| 37 |
+
# geotiff to geojson
|
| 38 |
+
with tempfile.NamedTemporaryFile(prefix="feats_", suffix=".geojson", dir=root_folder) as vector_tmp:
|
| 39 |
+
app_logger.info(f"done prediction, start conversion SamGeo.tiff_to_geojson({image_output_tmp.name}) => {vector_tmp.name}.")
|
| 40 |
+
predictor.tiff_to_geojson(image_output_tmp.name, vector_tmp.name, bidx=1)
|
| 41 |
+
|
| 42 |
+
app_logger.info(f"start reading geojson {vector_tmp.name}...")
|
| 43 |
+
with open(vector_tmp.name) as out_gdf:
|
| 44 |
+
out_gdf_str = out_gdf.read()
|
| 45 |
+
out_gdf_json = json.loads(out_gdf_str)
|
| 46 |
+
app_logger.info(f"geojson {vector_tmp.name} string has length: {len(out_gdf_str)}.")
|
| 47 |
+
return out_gdf_json
|
|
|