# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
import asyncio
from os.path import realpath, dirname, join

from loguru import logger as log
import numpy as np

from api_types import InferenceRequest, InferenceResult, SeedingRequest

ROOT_DIR = realpath(dirname(dirname(dirname(__file__))))
DATA_DIR = join(ROOT_DIR, "data")


class InferenceModel:
    """
    Base class for models that can be served by the inference server
    defined in `server.py`.
    """

    def __init__(self, data_path: str | None = None, checkpoint_path: str | None = None,
                 fake_delay_ms: float = 0, inference_cache_size: int = 15,
                 compress_inference_results: bool = True) -> None:
        # These paths may be unused by certain inference server types.
        self.data_path = data_path
        self.checkpoint_path = checkpoint_path

        self.fake_delay_ms = fake_delay_ms
        self.inference_cache_size = inference_cache_size

        # In-flight inference tasks and completed results, both keyed by request id.
        self.inference_tasks: dict[str, asyncio.Task] = {}
        self.inference_results: dict[str, InferenceResult] = {}
        # Ids of all requests ever received, including failed and evicted ones.
        self.request_history: set[str] = set()

        # If supported by the model and relevant, compress inference results,
        # e.g. as MP4 video, before returning them from the server.
        self.compress_inference_results: bool = compress_inference_results

        # Can be acquired before starting inference
        # if the model can only handle one request at a time.
        self.inference_lock = asyncio.Lock()

        # The generative model may need to be seeded with one or more initial frames.
        self.model_seeded = False

    # ----------- Inference model interface

    async def make_test_image(self):
        """Evaluate one default inference request, if possible.
        Helps ensure that the model has been loaded correctly."""
        raise NotImplementedError("make_test_image")

    async def seed_model(self, req: SeedingRequest) -> None:
        """By default, no seeding is required, so this just marks the model as seeded."""
        self.model_seeded = True

    async def run_inference(self, req: InferenceRequest) -> InferenceResult:
        """Evaluate the actual inference model to produce an inference result."""
        raise NotImplementedError("run_inference")

    def metadata(self) -> dict:
        """Returns metadata about this inference server."""
        raise NotImplementedError("metadata")

    def min_frames_per_request(self) -> int:
        """Minimum number of frames that can be produced in one inference batch."""
        raise NotImplementedError("min_frames_per_request")

    def max_frames_per_request(self) -> int:
        """Maximum number of frames that can be produced in one inference batch."""
        raise NotImplementedError("max_frames_per_request")

    def inference_time_per_frame(self) -> float:
        """Estimated average inference time per frame (not per batch!) in seconds."""
        raise NotImplementedError("inference_time_per_frame")

    def inference_resolution(self) -> list[tuple[int, int]] | None:
        """
        The supported inference resolutions (width, height) in pixels,
        or None if any resolution is supported.
        """
        return None

    def default_framerate(self) -> float | None:
        """
        The model's preferred framerate when generating video.
        Returns None when not applicable.
        """
        return None

    def requires_seeding(self) -> bool:
        """Whether this model must be seeded with images before inference."""
        return False

    # ----------- Requests handling
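    # Note: seed_model() must be awaited once before request_inference(), even when
    # requires_seeding() returns False; the default seed_model() simply marks the
    # model as seeded.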

    def request_inference(self, req: InferenceRequest) -> asyncio.Task:
        if not self.model_seeded:
            raise ValueError(f"Received request id '{req.request_id}', but the model was not seeded.")
        if (req.request_id in self.inference_tasks) or (req.request_id in self.inference_results):
            raise ValueError(f"Invalid request id '{req.request_id}': request already exists.")
        self.check_valid_request(req)

        task = asyncio.create_task(self.run_inference(req))
        self.inference_tasks[req.request_id] = task
        self.request_history.add(req.request_id)
        return task

    async def request_inference_sync(self, req: InferenceRequest) -> InferenceResult:
        """Submit an inference request and wait for its result before returning."""
        await self.request_inference(req)
        result = self.inference_result_or_none(req.request_id)
        assert isinstance(result, InferenceResult)
        return result

    def inference_result_or_none(self, request_id: str) -> InferenceResult | None:
        if request_id in self.inference_tasks:
            task = self.inference_tasks[request_id]
            if task.done():
                try:
                    # Inference result ready, cache it and return it
                    result = task.result()
                    self.inference_results[request_id] = result
                    del self.inference_tasks[request_id]
                    self.evict_results()
                    return result
                except Exception as e:
                    # Inference failed
                    log.error(f"Task for request '{request_id}' failed with exception {e}")
                    raise e
            else:
                # Inference result not ready yet
                return None
        elif request_id in self.inference_results:
            # Inference result was ready and cached, return it directly
            return self.inference_results[request_id]
        elif request_id in self.request_history:
            raise KeyError(f"Request with id '{request_id}' was known, but does not have any result."
                           f" Perhaps it was evicted from the cache or failed.")
        else:
            raise KeyError(f"Invalid request id '{request_id}': request not known.")

    def evict_results(self, keep_max: int | None = None):
        """
        Evict all cached results except the most recent `keep_max` entries
        (defaults to `inference_cache_size`).
        """
        keep_max = keep_max if (keep_max is not None) else self.inference_cache_size
        to_evict = []
        for i, k in enumerate(reversed(self.inference_results)):
            if i < keep_max:
                continue
            to_evict.append(k)
        for k in to_evict:
            del self.inference_results[k]
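
    # For example, with the default inference_cache_size of 15, caching a 16th
    # result causes the oldest entry (in dict insertion order) to be dropped the
    # next time evict_results() runs.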

    def get_latest_rgb(self) -> np.ndarray | None:
        """Returns the latest generated RGB image, if any. Useful for debugging."""
        if not self.inference_results:
            return None
        last_key = next(reversed(self.inference_results.keys()))
        return self.inference_results[last_key].images[-1, ...]

    def check_valid_request(self, req: InferenceRequest):
        if len(req) not in range(self.min_frames_per_request(), self.max_frames_per_request() + 1):
            raise ValueError(f"This model can produce between {self.min_frames_per_request()} and"
                             f" {self.max_frames_per_request()} frames per request, but the request"
                             f" specified {len(req)} camera poses.")
        return True

    # ----------- Resource management

    def cleanup(self):
        """Release any resources held by the model; no-op by default."""
        pass
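

# ----------- Example
# A minimal sketch of a concrete model, for illustration only. `MyVideoModel` is
# hypothetical, and the `InferenceResult(...)` construction is elided because its
# fields depend on `api_types` and on the model actually being served:
#
#   class MyVideoModel(InferenceModel):
#       async def run_inference(self, req: InferenceRequest) -> InferenceResult:
#           async with self.inference_lock:
#               frames = ...  # evaluate the model for the requested camera poses
#               return InferenceResult(...)
#
#       def metadata(self) -> dict:
#           return {"model": "my-video-model"}
#
#       def min_frames_per_request(self) -> int:
#           return 1
#
#       def max_frames_per_request(self) -> int:
#           return 8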