# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
import asyncio
from os.path import realpath, dirname, join
from loguru import logger as log
import numpy as np
from api_types import InferenceRequest, InferenceResult, SeedingRequest
ROOT_DIR = realpath(dirname(dirname(dirname(__file__))))
DATA_DIR = join(ROOT_DIR, "data")
class InferenceModel:
"""
Base class for models that can be served by the inference server
defined in `server.py`.
"""
def __init__(self, data_path: str | None = None, checkpoint_path: str | None = None,
fake_delay_ms: float = 0, inference_cache_size: int = 15,
compress_inference_results: bool = True) -> None:
# These paths may be unused by certain inference server types.
self.data_path = data_path
self.checkpoint_path = checkpoint_path
self.fake_delay_ms = fake_delay_ms
self.inference_cache_size = inference_cache_size
self.inference_tasks: dict[str, asyncio.Task] = {}
self.inference_results: dict[str, InferenceResult] = {}
self.request_history: set[str] = set()
# If supported by the model and relevant, compress inference results,
# e.g. as MP4 video, before returning from the server.
self.compress_inference_results: bool = compress_inference_results
# Can be acquired before starting inference
# if the model can only handle one request at a time
self.inference_lock = asyncio.Lock()
# The generative model may need to be seeded with one or more initial frames.
self.model_seeded = False
# ----------- Inference model interface
@abstractmethod
async def make_test_image(self):
"""Evaluate one default inference request, if possible.
Helps ensure that the model has been loaded correctly."""
raise NotImplementedError("make_test_image")
async def seed_model(self, req: SeedingRequest) -> None:
"""By default, no seeding is required so the default implementation just returns."""
self.model_seeded = True
@abstractmethod
async def run_inference(self, req: InferenceRequest) -> InferenceResult:
"""Evaluate the actual inference model to produce an inference result."""
raise NotImplementedError("run_inference")
@abstractmethod
def metadata(self) -> dict:
"""Returns metadata about this inference server."""
raise NotImplementedError("metadata")
@abstractmethod
def min_frames_per_request(self) -> int:
"""Minimum number of frames that can be produced in one inference batch."""
raise NotImplementedError("min_frames_per_request")
@abstractmethod
def max_frames_per_request(self) -> int:
"""Maximum number of frames that can be produced in one inference batch."""
raise NotImplementedError("max_frames_per_request")
@abstractmethod
def inference_time_per_frame(self) -> int:
"""Estimated average inference time per frame (not per batch!) in seconds."""
raise NotImplementedError("inference_time_per_frame")
def inference_resolution(self) -> list[tuple[int, int]] | None:
"""
The supported inference resolutions (width, height) in pixels,
or None if any resolution is supported.
"""
return None
def default_framerate(self) -> float | None:
"""
The model's preferred framerate when generating video.
Returns None when not applicable.
"""
return None
@abstractmethod
def requires_seeding(self) -> bool:
"""Whether this model needs to be seeded with images before inference."""
return False
# ----------- Requests handling
def request_inference(self, req: InferenceRequest) -> asyncio.Task:
if not self.model_seeded:
raise ValueError(f"Received request id '{req.request_id}', but the model was not seeded.")
if (req.request_id in self.inference_tasks) or (req.request_id in self.inference_results):
raise ValueError(f"Invalid request id '{req.request_id}': request already exists.")
self.check_valid_request(req)
task = asyncio.create_task(self.run_inference(req))
self.inference_tasks[req.request_id] = task
self.request_history.add(req.request_id)
return task
async def request_inference_sync(self, req: InferenceRequest) -> InferenceResult:
# request_inference() returns an asyncio.Task; awaiting it runs the inference to completion.
await self.request_inference(req)
result = self.inference_result_or_none(req.request_id)
assert isinstance(result, InferenceResult)
return result
def inference_result_or_none(self, request_id: str) -> InferenceResult | None:
if request_id in self.inference_tasks:
task = self.inference_tasks[request_id]
if task.done():
try:
# Inference result ready, cache it and return it
result = task.result()
self.inference_results[request_id] = result
del self.inference_tasks[request_id]
self.evict_results()
return result
except Exception as e:
# Inference failed
log.error(f"Task for request '{request_id}' failed with exception {e}")
raise e
else:
# Inference result not ready yet
return None
elif request_id in self.inference_results:
# Inference result was ready and cached, return it directly
return self.inference_results[request_id]
elif request_id in self.request_history:
raise KeyError(f"Request with id '{request_id}' was known, but does not have any result. Perhaps it was evicted from the cache or failed.")
else:
raise KeyError(f"Invalid request id '{request_id}': request not known.")
def evict_results(self, keep_max: int | None = None):
"""
Evict all cached results except the most recent `keep_max` entries.
"""
keep_max = keep_max if (keep_max is not None) else self.inference_cache_size
to_evict = []
for i, k in enumerate(reversed(self.inference_results)):
if i < keep_max:
continue
to_evict.append(k)
for k in to_evict:
del self.inference_results[k]
def get_latest_rgb(self) -> np.ndarray | None:
"""Returns the latest generated RGB image, if any. Useful for debugging."""
if not self.inference_results:
return None
last_key = next(reversed(self.inference_results.keys()))
return self.inference_results[last_key].images[-1, ...]
def check_valid_request(self, req: InferenceRequest):
if len(req) not in range(self.min_frames_per_request(), self.max_frames_per_request() + 1):
raise ValueError(f"This model can produce between {self.min_frames_per_request()} and"
f" {self.max_frames_per_request()} frames per request, but the request"
f" specified {len(req)} camera poses.")
return True
# ----------- Resource management
def cleanup(self):
"""Release any resources held by the model. Subclasses may override."""
pass
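# ----------- Usage sketch (illustrative only)
# A minimal, hypothetical subclass showing how the interface above is typically
# implemented and driven. `_ExampleModel`, its frame limits, and the
# `InferenceResult(request_id=..., images=...)` construction are assumptions made
# for illustration; consult `api_types.py` and `server.py` for the real contracts.
class _ExampleModel(InferenceModel):
    async def make_test_image(self):
        # A single black 64x64 RGB frame, enough to confirm the model "loads".
        return np.zeros((64, 64, 3), dtype=np.uint8)

    async def run_inference(self, req: InferenceRequest) -> InferenceResult:
        # Serialize requests, since this fake "model" handles one at a time.
        async with self.inference_lock:
            await asyncio.sleep(self.fake_delay_ms / 1000.0)
            images = np.zeros((len(req), 64, 64, 3), dtype=np.uint8)
            # Field names are assumed; adapt to the actual InferenceResult definition.
            return InferenceResult(request_id=req.request_id, images=images)

    def metadata(self) -> dict:
        return {"name": "_ExampleModel", "resolutions": self.inference_resolution()}

    def min_frames_per_request(self) -> int:
        return 1

    def max_frames_per_request(self) -> int:
        return 8

    def inference_time_per_frame(self) -> int:
        return 1

    def requires_seeding(self) -> bool:
        return False

# A server would drive the model roughly as follows (sketch, not executed here):
#   1. await model.seed_model(seeding_request)      # marks the model as seeded; required before any request
#   2. task = model.request_inference(request)      # schedules run_inference() as an asyncio task
#   3. poll model.inference_result_or_none(request.request_id) until it is non-None,
#      or simply `await model.request_inference_sync(request)` to block for the result.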