""" Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. """ import collections import contextlib import os import time from collections import OrderedDict import cv2 # Added for video processing import numpy as np import tensorrt as trt import torch import torchvision.transforms as T from PIL import Image, ImageDraw class TimeProfiler(contextlib.ContextDecorator): def __init__(self): self.total = 0 def __enter__(self): self.start = self.time() return self def __exit__(self, type, value, traceback): self.total += self.time() - self.start def reset(self): self.total = 0 def time(self): if torch.cuda.is_available(): torch.cuda.synchronize() return time.time() class TRTInference(object): def __init__( self, engine_path, device="cuda:0", backend="torch", max_batch_size=32, verbose=False ): self.engine_path = engine_path self.device = device self.backend = backend self.max_batch_size = max_batch_size self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO) self.engine = self.load_engine(engine_path) self.context = self.engine.create_execution_context() self.bindings = self.get_bindings( self.engine, self.context, self.max_batch_size, self.device ) self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self.bindings.items()) self.input_names = self.get_input_names() self.output_names = self.get_output_names() self.time_profile = TimeProfiler() def load_engine(self, path): trt.init_libnvinfer_plugins(self.logger, "") with open(path, "rb") as f, trt.Runtime(self.logger) as runtime: return runtime.deserialize_cuda_engine(f.read()) def get_input_names(self): names = [] for _, name in enumerate(self.engine): if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: names.append(name) return names def get_output_names(self): names = [] for _, name in enumerate(self.engine): if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: names.append(name) return names def get_bindings(self, engine, context, max_batch_size=32, device=None) -> OrderedDict: Binding = collections.namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) bindings = OrderedDict() for i, name in enumerate(engine): shape = engine.get_tensor_shape(name) dtype = trt.nptype(engine.get_tensor_dtype(name)) if shape[0] == -1: shape[0] = max_batch_size if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: context.set_input_shape(name, shape) data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) bindings[name] = Binding(name, dtype, shape, data, data.data_ptr()) return bindings def run_torch(self, blob): for n in self.input_names: if blob[n].dtype is not self.bindings[n].data.dtype: blob[n] = blob[n].to(dtype=self.bindings[n].data.dtype) if self.bindings[n].shape != blob[n].shape: self.context.set_input_shape(n, blob[n].shape) self.bindings[n] = self.bindings[n]._replace(shape=blob[n].shape) assert self.bindings[n].data.dtype == blob[n].dtype, "{} dtype mismatch".format(n) self.bindings_addr.update({n: blob[n].data_ptr() for n in self.input_names}) self.context.execute_v2(list(self.bindings_addr.values())) outputs = {n: self.bindings[n].data for n in self.output_names} return outputs def __call__(self, blob): if self.backend == "torch": return self.run_torch(blob) else: raise NotImplementedError("Only 'torch' backend is implemented.") def synchronize(self): if self.backend == "torch" and torch.cuda.is_available(): torch.cuda.synchronize() def draw(images, labels, boxes, scores, thrh=0.4): for i, im in enumerate(images): draw = ImageDraw.Draw(im) scr = scores[i] lab 

def draw(images, labels, boxes, scores, thrh=0.4):
    for i, im in enumerate(images):
        draw = ImageDraw.Draw(im)

        # Keep only detections above the confidence threshold
        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]
        scrs = scr[scr > thrh]

        for j, b in enumerate(box):
            draw.rectangle(list(b), outline="red")
            draw.text(
                (b[0], b[1]),
                text=f"{lab[j].item()} {round(scrs[j].item(), 2)}",
                fill="blue",
            )

    return images


def process_image(m, file_path, device):
    im_pil = Image.open(file_path).convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([w, h])[None].to(device)

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil)[None]

    blob = {
        "images": im_data.to(device),
        "orig_target_sizes": orig_size.to(device),
    }

    output = m(blob)

    result_images = draw([im_pil], output["labels"], output["boxes"], output["scores"])
    result_images[0].save("trt_result.jpg")
    print("Image processing complete. Result saved as 'trt_result.jpg'.")


def process_video(m, file_path, device):
    cap = cv2.VideoCapture(file_path)

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter("trt_result.mp4", fourcc, fps, (orig_w, orig_h))

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )

    frame_count = 0
    print("Processing video frames...")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Convert frame to PIL image
        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        w, h = frame_pil.size
        orig_size = torch.tensor([w, h])[None].to(device)

        im_data = transforms(frame_pil)[None]

        blob = {
            "images": im_data.to(device),
            "orig_target_sizes": orig_size.to(device),
        }

        output = m(blob)

        # Draw detections on the frame
        result_images = draw([frame_pil], output["labels"], output["boxes"], output["scores"])

        # Convert back to OpenCV image
        frame = cv2.cvtColor(np.array(result_images[0]), cv2.COLOR_RGB2BGR)

        # Write the frame
        out.write(frame)
        frame_count += 1

        if frame_count % 10 == 0:
            print(f"Processed {frame_count} frames...")

    cap.release()
    out.release()
    print("Video processing complete. Result saved as 'trt_result.mp4'.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-trt", "--trt", type=str, required=True)
    parser.add_argument("-i", "--input", type=str, required=True)
    parser.add_argument("-d", "--device", type=str, default="cuda:0")
    args = parser.parse_args()

    m = TRTInference(args.trt, device=args.device)

    file_path = args.input
    if os.path.splitext(file_path)[-1].lower() in [".jpg", ".jpeg", ".png", ".bmp"]:
        # Process as image
        process_image(m, file_path, args.device)
    else:
        # Process as video
        process_video(m, file_path, args.device)
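
# Example invocations (flags as defined in the argparse block above; the
# script name and file names are illustrative):
#   python trt_inf.py --trt model.engine --input example.jpg
#   python trt_inf.py --trt model.engine --input example.mp4 -d cuda:0
# Images are saved to 'trt_result.jpg'; videos to 'trt_result.mp4'.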