Spaces:
Running
Running
| from io import BytesIO | |
| import litserve as ls | |
| import numpy as np | |
| from fastapi import Response, UploadFile | |
| from PIL import Image | |
| from lang_sam import LangSAM | |
| from lang_sam.utils import draw_image | |
# TCP port the LitServe HTTP server listens on when this file is run as a script.
PORT: int = 8000
class LangSAMAPI(ls.LitAPI):
    """LitServe API running LangSAM text-prompted segmentation on an uploaded image.

    Request flow: ``decode_request`` (multipart form -> dict) -> ``predict``
    (dict -> annotated PIL image) -> ``encode_response`` (PIL image -> PNG).
    """

    def setup(self, device: str) -> None:
        """Initialize or load the LangSAM model.

        Args:
            device: Device string supplied by LitServe. NOTE(review): it is not
                forwarded to ``LangSAM``, so device selection is left to
                ``lang_sam``'s own default — confirm this is intended.
        """
        self.model = LangSAM(sam_type="sam2.1_hiera_small")
        print("LangSAM model initialized.")

    @staticmethod
    def _read_float(request, key: str, default: float) -> float:
        """Return form field ``key`` as a float, or ``default`` if it is missing/blank."""
        value = request.get(key)
        # Form submissions can carry an empty string for an untouched field;
        # treat that like an absent field instead of failing in float("").
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return float(value)

    def decode_request(self, request) -> dict:
        """Decode the incoming request to extract parameters and image bytes.

        Assumes the request is sent as multipart/form-data with fields:
            - sam_type: str (optional; missing/empty keeps the loaded model)
            - box_threshold: float (optional, default 0.3)
            - text_threshold: float (optional, default 0.25)
            - text_prompt: str (optional, default "")
            - image: UploadFile (required)

        Returns:
            dict with keys ``sam_type``, ``box_threshold``, ``text_threshold``,
            ``image_bytes`` and ``text_prompt``.

        Raises:
            ValueError: If no image file is present, or a threshold field is
                not parseable as a float.
        """
        image_file: UploadFile = request.get("image")
        if image_file is None:
            raise ValueError("No image file provided in the request.")

        return {
            "sam_type": request.get("sam_type"),
            "box_threshold": self._read_float(request, "box_threshold", 0.3),
            "text_threshold": self._read_float(request, "text_threshold", 0.25),
            "image_bytes": image_file.file.read(),
            "text_prompt": request.get("text_prompt", ""),
        }

    def _ensure_sam_type(self, sam_type) -> None:
        """Rebuild the SAM backbone when ``sam_type`` differs from the loaded one.

        A missing or empty ``sam_type`` keeps the currently loaded model.
        """
        if not sam_type or sam_type == self.model.sam_type:
            return
        print(f"Updating SAM model type to {sam_type}")
        self.model.sam.build_model(sam_type)
        # Record the new type only after a successful build; without this every
        # later request for the same type would compare unequal and rebuild again.
        self.model.sam_type = sam_type

    def predict(self, inputs: dict) -> dict:
        """Perform prediction using the LangSAM model.

        Args:
            inputs: Dict produced by ``decode_request``.

        Returns:
            dict: ``{"output_image": PIL.Image.Image}``. This is the RGB input
            image with masks, boxes, scores and labels drawn on it. When no
            masks are detected, it is the unmodified RGB input.

        Raises:
            ValueError: If the image bytes cannot be decoded.
        """
        print("Starting prediction with parameters:")
        print(
            f"sam_type: {inputs['sam_type']}, "
            f"box_threshold: {inputs['box_threshold']}, "
            f"text_threshold: {inputs['text_threshold']}, "
            f"text_prompt: {inputs['text_prompt']}"
        )

        # Decode before any model rebuild so a bad upload fails fast and cheaply.
        try:
            image_pil = Image.open(BytesIO(inputs["image_bytes"])).convert("RGB")
        except Exception as e:
            raise ValueError(f"Invalid image data: {e}") from e

        self._ensure_sam_type(inputs["sam_type"])

        # The model is batched; we send a single image and take its result.
        results = self.model.predict(
            images_pil=[image_pil],
            texts_prompt=[inputs["text_prompt"]],
            box_threshold=inputs["box_threshold"],
            text_threshold=inputs["text_threshold"],
        )[0]

        # len() rather than truthiness so this works for lists and arrays alike.
        if not len(results["masks"]):
            print("No masks detected. Returning original image.")
            return {"output_image": image_pil}

        output_array = draw_image(
            np.asarray(image_pil),
            results["masks"],
            results["boxes"],
            results["scores"],
            results["labels"],
        )
        output_image = Image.fromarray(np.uint8(output_array)).convert("RGB")
        return {"output_image": output_image}

    def encode_response(self, output: dict) -> Response:
        """Encode the prediction result into an HTTP response.

        Args:
            output: Dict returned by ``predict``.

        Returns:
            Response: Contains the processed image in PNG format.

        Raises:
            ValueError: If ``output`` holds no ``"output_image"``.
        """
        image = output.get("output_image")
        if image is None:
            raise ValueError("No output generated by the prediction.")
        with BytesIO() as buffer:
            image.save(buffer, format="PNG")
            # getvalue() returns the whole buffer regardless of the cursor position.
            return Response(content=buffer.getvalue(), media_type="image/png")
# API and server objects are created when the module is imported; the server
# itself only starts listening in the __main__ branch below.
lit_api = LangSAMAPI()
server = ls.LitServer(lit_api)

if __name__ == "__main__":
    # Only LitServe is started here; this file does not import or mount Gradio.
    print(f"Starting LitServe server on port {PORT}...")
    server.run(port=PORT)