# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
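
"""Base infrastructure for inference models served by `server.py`: an abstract
model interface plus request bookkeeping (asynchronous task tracking, result
caching, and cache eviction)."""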

from abc import ABC, abstractmethod
import asyncio
from os.path import realpath, dirname, join

from loguru import logger as log
import numpy as np

from api_types import InferenceRequest, InferenceResult, SeedingRequest


ROOT_DIR = realpath(dirname(dirname(dirname(__file__))))
DATA_DIR = join(ROOT_DIR, "data")


class InferenceModel(ABC):
	"""
	Base class for models that can be served by the inference server
	defined in `server.py`.
	"""

	def __init__(self, data_path: str | None = None, checkpoint_path: str | None = None,
				 fake_delay_ms: float = 0, inference_cache_size: int = 15,
				 compress_inference_results: bool = True) -> None:

		# These paths may be unused by certain inference server types.
		self.data_path = data_path
		self.checkpoint_path = checkpoint_path

		self.fake_delay_ms = fake_delay_ms
		self.inference_cache_size = inference_cache_size

		self.inference_tasks: dict[str, asyncio.Task] = {}
		self.inference_results: dict[str, InferenceResult] = {}
		self.request_history: set[str] = set()

		# If supported by the model and relevant, compress inference results,
		# e.g. as MP4 video, before returning from the server.
		self.compress_inference_results: bool = compress_inference_results

		# Can be acquired before starting inference
		# if the model can only handle one request at a time
		self.inference_lock = asyncio.Lock()

		# The generative model may need to be seeded with one or more initial frames.
		self.model_seeded = False


	# ----------- Inference model interface

	@abstractmethod
	async def make_test_image(self):
		"""Evaluate one default inference request, if possible.
		Helps ensuring that the model has been loaded correctly."""
		raise NotImplementedError("make_test_image")

	async def seed_model(self, req: SeedingRequest) -> None:
		"""By default, no seeding is required so the default implementation just returns."""
		self.model_seeded = True

	@abstractmethod
	async def run_inference(self, req: InferenceRequest) -> InferenceResult:
		"""Evaluate the actual inference model to produce an inference result."""
		raise NotImplementedError("run_inference")

	@abstractmethod
	def metadata(self) -> dict:
		"""Returns metadata about this inference server."""
		raise NotImplementedError("metadata")

	@abstractmethod
	def min_frames_per_request(self) -> int:
		"""Minimum number of frames that can be produced in one inference batch."""
		raise NotImplementedError("min_frames_per_request")

	@abstractmethod
	def max_frames_per_request(self) -> int:
		"""Maximum number of frames that can be produced in one inference batch."""
		raise NotImplementedError("max_frames_per_request")

	@abstractmethod
	def inference_time_per_frame(self) -> float:
		"""Estimated average inference time per frame (not per batch!) in seconds."""
		raise NotImplementedError("inference_time_per_frame")

	def inference_resolution(self) -> list[tuple[int, int]] | None:
		"""
		The supported inference resolutions (width, height) in pixels,
		or None if any resolution is supported.
		"""
		return None

	def default_framerate(self) -> float | None:
		"""
		The model's preferred framerate when generating video.
		Returns None when not applicable.
		"""
		return None

	def requires_seeding(self) -> bool:
		"""Whether this model must be seeded with images before inference. Defaults to False."""
		return False

	# ----------- Requests handling

	def request_inference(self, req: InferenceRequest) -> asyncio.Task:
		if not self.model_seeded:
			raise ValueError(f"Received request id '{req.request_id}', but the model was not seeded.")
		if (req.request_id in self.inference_tasks) or (req.request_id in self.inference_results):
			raise ValueError(f"Invalid request id '{req.request_id}': request already exists.")
		self.check_valid_request(req)

		task = asyncio.create_task(self.run_inference(req))
		self.inference_tasks[req.request_id] = task
		self.request_history.add(req.request_id)
		return task


	async def request_inference_sync(self, req: InferenceRequest) -> InferenceResult:
		await self.request_inference(req)
		result = self.inference_result_or_none(req.request_id)
		assert isinstance(result, InferenceResult)
		return result


	def inference_result_or_none(self, request_id: str) -> InferenceResult | None:
		if request_id in self.inference_tasks:
			task = self.inference_tasks[request_id]
			if task.done():
				try:
					# Inference result ready, cache it and return it
					result = task.result()
					self.inference_results[request_id] = result
					del self.inference_tasks[request_id]
					self.evict_results()
					return result
				except Exception as e:
					# Inference failed
					log.error(f"Task for request '{request_id}' failed with exception {e}")
					raise e
			else:
				# Inference result not ready yet
				return None

		elif request_id in self.inference_results:
			# Inference result was ready and cached, return it directly
			return self.inference_results[request_id]

		elif request_id in self.request_history:
			raise KeyError(f"Request with id '{request_id}' was known, but does not have any result. Perhaps it was evicted from the cache or failed.")

		else:
			raise KeyError(f"Invalid request id '{request_id}': request not known.")


	def evict_results(self, keep_max: int | None = None):
		"""
		Evict cached results, keeping only the `keep_max` most recently added entries.
		"""
		keep_max = keep_max if (keep_max is not None) else self.inference_cache_size

		to_evict = []
		for i, k in enumerate(reversed(self.inference_results)):
			if i < keep_max:
				continue
			to_evict.append(k)
		for k in to_evict:
			del self.inference_results[k]


	def get_latest_rgb(self) -> np.ndarray | None:
		"""Returns the latest generated RGB image, if any. Useful for debugging."""
		if not self.inference_results:
			return None
		last_key = next(reversed(self.inference_results.keys()))
		return self.inference_results[last_key].images[-1, ...]

	def check_valid_request(self, req: InferenceRequest):
		if len(req) not in range(self.min_frames_per_request(), self.max_frames_per_request() + 1):
			raise ValueError(f"This model can produce between {self.min_frames_per_request()} and"
							 f" {self.max_frames_per_request()} frames per request, but the request"
							 f" specified {len(req)} camera poses.")

		return True


	# ----------- Resource management
	def cleanup(self):
		"""Release any resources held by the model; the default implementation is a no-op."""
		pass
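

# ----------- Example (illustrative only)

# A minimal sketch of a concrete subclass and of how a caller might drive it.
# `SolidColorModel` and `_example_usage` are hypothetical names that exist
# nowhere else in this codebase. The sketch assumes that `InferenceResult`
# can be constructed with an `images` keyword argument holding a
# (frames, height, width, channels) uint8 array, inferred from how
# `get_latest_rgb` indexes `result.images` above; adapt both to the actual
# definitions in `api_types`.

class SolidColorModel(InferenceModel):
	"""Toy model that 'generates' solid mid-gray frames."""

	async def make_test_image(self):
		# Nothing to load or verify for this toy model.
		pass

	async def run_inference(self, req: InferenceRequest) -> InferenceResult:
		# Serialize requests, as a real single-stream model would.
		async with self.inference_lock:
			if self.fake_delay_ms > 0:
				await asyncio.sleep(self.fake_delay_ms / 1000.0)
			frames = np.full((len(req), 64, 64, 3), 128, dtype=np.uint8)
			# Assumption: `InferenceResult` accepts an `images` array.
			return InferenceResult(images=frames)

	def metadata(self) -> dict:
		return {"model": "solid-color-example"}

	def min_frames_per_request(self) -> int:
		return 1

	def max_frames_per_request(self) -> int:
		return 16

	def inference_time_per_frame(self) -> float:
		return 0.01


async def _example_usage(req: InferenceRequest) -> None:
	"""Drive the toy model through the fire-and-poll path that a server
	endpoint would typically use. Constructing `req` depends on the
	definitions in `api_types` and is out of scope for this sketch."""
	model = SolidColorModel()
	model.model_seeded = True  # the toy model needs no real seeding

	model.request_inference(req)
	while (result := model.inference_result_or_none(req.request_id)) is None:
		await asyncio.sleep(0.05)  # result not ready yet; poll again

	# Alternatively, `await model.request_inference_sync(req)` blocks until done.
	log.info(f"Generated {result.images.shape[0]} frames for '{req.request_id}'")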