# File: gen3c/gui/api/server_base.py
# Commit 28451f7 by elungky: "Initial commit for new Space - pre-built Docker image"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
import asyncio
from os.path import realpath, dirname, join
from loguru import logger as log
import numpy as np
from api_types import InferenceRequest, InferenceResult, SeedingRequest
# Presumably the repository root: `dirname` is applied three times to this
# file's path (file -> its directory -> parent -> grandparent), and `realpath`
# then resolves any symlinks in the resulting directory path.
ROOT_DIR = realpath(dirname(dirname(dirname(__file__))))
# Default "data" directory located directly under ROOT_DIR.
DATA_DIR = join(ROOT_DIR, "data")
class InferenceModel():
    """
    Base class for models that can be served by the inference server
    defined in `server.py`.

    Subclasses implement the methods decorated with `@abstractmethod`.
    NOTE: this class does not derive from `abc.ABC`, so `@abstractmethod` is
    not enforced at instantiation time; calling an unimplemented method raises
    `NotImplementedError` instead.

    Requests are tracked in three containers:
    - `inference_tasks`: request id -> asyncio task whose result has not been
      collected yet (still running, or finished but not yet polled).
    - `inference_results`: request id -> finished result, in insertion order,
      bounded to the `inference_cache_size` most recent entries.
    - `request_history`: every request id ever accepted, used to tell
      "known but evicted / failed" apart from "never seen" on lookup.
    """

    def __init__(self, data_path: str | None = None, checkpoint_path: str | None = None,
                 fake_delay_ms: float = 0, inference_cache_size: int = 15,
                 compress_inference_results: bool = True) -> None:
        """
        Args:
            data_path: Optional path to model data.
            checkpoint_path: Optional path to a model checkpoint.
            fake_delay_ms: Stored as-is for subclasses; presumably an artificial
                delay in milliseconds (TODO confirm against subclasses).
            inference_cache_size: Number of most recent inference results kept
                in memory (see `evict_results`).
            compress_inference_results: Whether results should be compressed
                before being returned by the server, if the model supports it.
        """
        # These paths may be unused by certain inference server types.
        self.data_path = data_path
        self.checkpoint_path = checkpoint_path
        self.fake_delay_ms = fake_delay_ms
        self.inference_cache_size = inference_cache_size
        self.inference_tasks: dict[str, asyncio.Task] = {}
        self.inference_results: dict[str, InferenceResult] = {}
        self.request_history: set[str] = set()
        # If supported by the model and relevant, compress inference results,
        # e.g. as MP4 video, before returning from the server.
        self.compress_inference_results: bool = compress_inference_results
        # Can be acquired before starting inference
        # if the model can only handle one request at a time.
        self.inference_lock = asyncio.Lock()
        # The generative model may need to be seeded with one or more initial frames.
        # `request_inference` refuses requests until this is True.
        self.model_seeded = False

    # ----------- Inference model interface

    @abstractmethod
    async def make_test_image(self):
        """Evaluate one default inference request, if possible.

        Helps ensure that the model has been loaded correctly."""
        raise NotImplementedError("make_test_image")

    async def seed_model(self, req: SeedingRequest) -> None:
        """Seed the model before inference.

        By default no seeding is required, so this only marks the model as seeded.
        """
        self.model_seeded = True

    @abstractmethod
    async def run_inference(self, req: InferenceRequest) -> InferenceResult:
        """Evaluate the actual inference model to produce an inference result."""
        raise NotImplementedError("run_inference")

    @abstractmethod
    def metadata(self) -> dict:
        """Returns metadata about this inference server."""
        raise NotImplementedError("metadata")

    @abstractmethod
    def min_frames_per_request(self) -> int:
        """Minimum number of frames that can be produced in one inference batch."""
        raise NotImplementedError("min_frames_per_request")

    @abstractmethod
    def max_frames_per_request(self) -> int:
        """Maximum number of frames that can be produced in one inference batch."""
        raise NotImplementedError("max_frames_per_request")

    @abstractmethod
    def inference_time_per_frame(self) -> float:
        """Estimated average inference time per frame (not per batch!) in seconds."""
        raise NotImplementedError("inference_time_per_frame")

    def inference_resolution(self) -> list[tuple[int, int]] | None:
        """
        The supported inference resolutions (width, height) in pixels,
        or None if any resolution is supported.
        """
        return None

    def default_framerate(self) -> float | None:
        """
        The model's preferred framerate when generating video.
        Returns None when not applicable.
        """
        return None

    @abstractmethod
    def requires_seeding(self) -> bool:
        """Whether or not this model must be seeded with images before inference.

        The base implementation returns False.
        """
        return False

    # ----------- Requests handling

    def request_inference(self, req: InferenceRequest) -> asyncio.Task:
        """
        Validate `req` and schedule `run_inference(req)` as an asyncio task.

        Must be called while an event loop is running (uses `asyncio.create_task`).

        Returns:
            The scheduled task, also tracked in `inference_tasks`.
        Raises:
            ValueError: if the model was not seeded, if the request id is already
                pending or cached, or if the number of frames requested is out of range.
        """
        if not self.model_seeded:
            raise ValueError(f"Received request id '{req.request_id}', but the model was not seeded.")
        if (req.request_id in self.inference_tasks) or (req.request_id in self.inference_results):
            raise ValueError(f"Invalid request id '{req.request_id}': request already exists.")
        self.check_valid_request(req)

        task = asyncio.create_task(self.run_inference(req))
        self.inference_tasks[req.request_id] = task
        self.request_history.add(req.request_id)
        return task

    async def request_inference_sync(self, req: InferenceRequest) -> InferenceResult:
        """
        Schedule `req` and wait for its result.

        Raises:
            ValueError: see `request_inference`.
            Exception: whatever the inference task raised; the failed task is no
                longer tracked in `inference_tasks` afterwards.
        """
        task = self.request_inference(req)
        try:
            await task
        except Exception:
            # Without this, a failed task would never be polled and would stay in
            # `inference_tasks` forever (keeping its exception and traceback alive).
            if self.inference_tasks.get(req.request_id) is task:
                del self.inference_tasks[req.request_id]
            raise
        result = self.inference_result_or_none(req.request_id)
        assert isinstance(result, InferenceResult)
        return result

    def inference_result_or_none(self, request_id: str) -> InferenceResult | None:
        """
        Return the result of request `request_id`, or None if it is still running.

        A freshly completed result is moved from `inference_tasks` into the
        bounded `inference_results` cache, which may evict older results.

        Raises:
            Exception: whatever the inference task raised, if it failed. The failed
                task is dropped, so later lookups of the same id raise `KeyError`.
            KeyError: if the id was accepted before but no result is available any
                more (evicted from the cache or failed), or if it was never seen.
        """
        task = self.inference_tasks.get(request_id)
        if task is not None:
            if not task.done():
                # Inference result not ready yet
                return None
            # The task is finished either way: stop tracking it, so that failed
            # tasks do not linger in `inference_tasks` indefinitely.
            del self.inference_tasks[request_id]
            try:
                result = task.result()
            except Exception as e:
                # Inference failed
                log.error(f"Task for request '{request_id}' failed with exception {e}")
                raise
            # Inference result ready, cache it and return it
            self.inference_results[request_id] = result
            self.evict_results()
            return result

        if request_id in self.inference_results:
            # Inference result was ready and cached, return it directly
            return self.inference_results[request_id]
        if request_id in self.request_history:
            raise KeyError(f"Request with id '{request_id}' was known, but does not have any result. Perhaps it was evicted from the cache or failed.")
        raise KeyError(f"Invalid request id '{request_id}': request not known.")

    def evict_results(self, keep_max: int | None = None) -> None:
        """
        Evict all results that were added before the last `keep_max` entries.

        Args:
            keep_max: Number of most recent results to keep; defaults to
                `self.inference_cache_size`. Values <= 0 evict everything.
        """
        if keep_max is None:
            keep_max = self.inference_cache_size
        # Dicts preserve insertion order, so the oldest results come first.
        n_evict = max(len(self.inference_results) - max(keep_max, 0), 0)
        for k in list(self.inference_results)[:n_evict]:
            del self.inference_results[k]

    def get_latest_rgb(self) -> np.ndarray | None:
        """
        Returns the latest generated RGB image, if any. Useful for debugging.

        This is the last entry along the first axis of the most recently cached
        result's `images` (presumably its last frame).
        """
        if not self.inference_results:
            return None
        # Dicts preserve insertion order: the last value is the most recent result.
        latest = next(reversed(self.inference_results.values()))
        return latest.images[-1, ...]

    def check_valid_request(self, req: InferenceRequest) -> bool:
        """
        Check that the number of frames requested (`len(req)`) is supported.

        Returns:
            True if the request is valid.
        Raises:
            ValueError: if `len(req)` is outside
                [min_frames_per_request(), max_frames_per_request()].
        """
        if len(req) not in range(self.min_frames_per_request(), self.max_frames_per_request() + 1):
            raise ValueError(f"This model can produce between {self.min_frames_per_request()} and"
                             f" {self.max_frames_per_request()} frames per request, but the request"
                             f" specified {len(req)} camera poses.")
        return True

    # ----------- Resource management

    def cleanup(self) -> None:
        """Release resources held by the model. No-op by default."""
        pass