dangminh214's picture
Clean initial commit (no large files, no LFS pointers)
b26e93d
"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""
import collections
import contextlib
import os
import time
from collections import OrderedDict
import cv2 # Added for video processing
import numpy as np
import tensorrt as trt
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw
class TimeProfiler(contextlib.ContextDecorator):
    """Accumulate elapsed wall-clock seconds across ``with`` blocks or decorated calls.

    Each timestamp is taken after ``torch.cuda.synchronize()`` (when CUDA is
    available), so queued GPU work is charged to the interval that launched it.
    The running sum lives in ``self.total`` and is cleared by ``reset()``.
    """

    def __init__(self):
        self.total = 0

    def __enter__(self):
        # Remember when this interval began; the profiler itself is the bound value.
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        elapsed = self.time() - self.start
        self.total += elapsed

    def reset(self):
        """Zero the accumulated total."""
        self.total = 0

    def time(self):
        """Return ``time.time()`` after draining pending CUDA work, if any."""
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            torch.cuda.synchronize()
        now = time.time()
        return now
class TRTInference(object):
    """Run a serialized TensorRT engine using torch tensors as I/O buffers.

    One torch buffer per engine I/O tensor is pre-allocated on ``device``; a
    leading dimension reported as -1 is sized to ``max_batch_size``. Calling the
    instance with a dict of input tensors executes the engine synchronously via
    ``execute_v2`` and returns a dict of output tensors keyed by tensor name.

    Args:
        engine_path: Path to the serialized engine file.
        device: Torch device on which the I/O buffers are allocated.
        backend: Execution backend; only ``"torch"`` is implemented.
        max_batch_size: Size substituted for a dynamic (-1) leading dimension
            when allocating buffers.
        verbose: Use a VERBOSE TensorRT logger instead of INFO.
    """

    def __init__(
        self, engine_path, device="cuda:0", backend="torch", max_batch_size=32, verbose=False
    ):
        self.engine_path = engine_path
        self.device = device
        self.backend = backend
        self.max_batch_size = max_batch_size
        self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
        self.engine = self.load_engine(engine_path)
        self.context = self.engine.create_execution_context()
        self.bindings = self.get_bindings(
            self.engine, self.context, self.max_batch_size, self.device
        )
        # execute_v2 takes a flat list of device addresses in engine tensor order.
        self.bindings_addr = OrderedDict((n, v.ptr) for n, v in self.bindings.items())
        self.input_names = self.get_input_names()
        self.output_names = self.get_output_names()
        self.time_profile = TimeProfiler()

    def load_engine(self, path):
        """Deserialize the engine stored at ``path``.

        Raises:
            RuntimeError: if TensorRT could not deserialize the file.
        """
        trt.init_libnvinfer_plugins(self.logger, "")
        with open(path, "rb") as f, trt.Runtime(self.logger) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())
        if engine is None:
            # deserialize_cuda_engine reports failure by returning None rather than
            # raising; fail here instead of with an AttributeError later.
            raise RuntimeError(f"Failed to deserialize TensorRT engine from {path!r}")
        return engine

    def get_input_names(self):
        """Return the names of the engine's input tensors, in engine order."""
        return [
            name
            for name in self.engine
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
        ]

    def get_output_names(self):
        """Return the names of the engine's output tensors, in engine order."""
        return [
            name
            for name in self.engine
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
        ]

    def get_bindings(self, engine, context, max_batch_size=32, device=None) -> OrderedDict:
        """Allocate a torch buffer for every engine I/O tensor.

        A leading dimension of -1 is replaced by ``max_batch_size``, and for
        inputs that shape is also set on ``context``.

        Returns:
            OrderedDict name -> Binding(name, dtype, shape, data, ptr), in engine order.
        """
        Binding = collections.namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
        bindings = OrderedDict()
        for name in engine:
            shape = engine.get_tensor_shape(name)
            dtype = trt.nptype(engine.get_tensor_dtype(name))
            if shape[0] == -1:
                shape[0] = max_batch_size
                if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                    context.set_input_shape(name, shape)
            data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
            bindings[name] = Binding(name, dtype, shape, data, data.data_ptr())
        return bindings

    def run_torch(self, blob):
        """Execute the engine on ``blob`` and return its outputs.

        Args:
            blob: Mapping of input tensor name -> torch tensor. It is not modified;
                inputs are cast/moved/made contiguous on local copies.

        Returns:
            Dict of output name -> tensor view of the engine's output buffer,
            trimmed to the shape the execution context reports for this run.
            The views alias persistent buffers and are overwritten by the next call.

        Raises:
            RuntimeError: if ``execute_v2`` reports failure.
        """
        feed = {}
        for n in self.input_names:
            binding = self.bindings[n]
            # TensorRT reads raw memory at data_ptr(): the tensor must already have the
            # engine dtype, live on the binding's device, and be densely packed.
            x = blob[n].to(device=binding.data.device, dtype=binding.data.dtype).contiguous()
            if tuple(binding.shape) != tuple(x.shape):
                self.context.set_input_shape(n, tuple(x.shape))
                self.bindings[n] = binding._replace(shape=x.shape)
            feed[n] = x  # keeps the tensor alive until execution finishes
        self.bindings_addr.update({n: t.data_ptr() for n, t in feed.items()})
        if not self.context.execute_v2(list(self.bindings_addr.values())):
            raise RuntimeError("TensorRT execute_v2 failed")

        outputs = {}
        for n in self.output_names:
            buf = self.bindings[n].data
            actual = tuple(self.context.get_tensor_shape(n))
            # Buffers are sized for max_batch_size; only the leading
            # prod(actual) elements (contiguous layout) were written this run.
            if actual != tuple(buf.shape) and all(d >= 0 for d in actual):
                buf = buf.view(-1)[: int(np.prod(actual))].view(actual)
            outputs[n] = buf
        return outputs

    def __call__(self, blob):
        """Run inference on ``blob`` using the configured backend."""
        if self.backend == "torch":
            return self.run_torch(blob)
        else:
            raise NotImplementedError("Only 'torch' backend is implemented.")

    def synchronize(self):
        """Block until queued CUDA work finishes (torch backend only)."""
        if self.backend == "torch" and torch.cuda.is_available():
            torch.cuda.synchronize()
def draw(images, labels, boxes, scores, thrh=0.4):
    """Annotate each PIL image in place with its detections scoring above ``thrh``.

    Args:
        images: List of PIL images; each is drawn on directly.
        labels, boxes, scores: Per-image tensors indexed by image position
            (``labels[i]``, ``boxes[i]``, ``scores[i]`` belong to ``images[i]``).
            Each row of ``boxes[i]`` is passed to ``ImageDraw.rectangle``, i.e.
            drawn as ``(x0, y0, x1, y1)``.
        thrh: Detections with ``score <= thrh`` are skipped.

    Returns:
        The same ``images`` list, now annotated.
    """
    for i, im in enumerate(images):
        canvas = ImageDraw.Draw(im)
        keep = scores[i] > thrh
        # One host transfer per tensor instead of one .item() (device sync) per detection.
        kept_labels = labels[i][keep].tolist()
        kept_boxes = boxes[i][keep].tolist()
        kept_scores = scores[i][keep].tolist()
        for lab, box, scr in zip(kept_labels, kept_boxes, kept_scores):
            canvas.rectangle(box, outline="red")
            canvas.text(
                (box[0], box[1]),
                text=f"{lab} {round(scr, 2)}",
                fill="blue",
            )
    return images
def process_image(m, file_path, device, output_path="trt_result.jpg"):
    """Run detector ``m`` on one image file and save the annotated result.

    Args:
        m: Callable taking a blob dict with ``"images"`` and
            ``"orig_target_sizes"`` and returning a dict containing
            ``"labels"``, ``"boxes"`` and ``"scores"`` (e.g. TRTInference).
        file_path: Path of the image to read.
        device: Torch device the input tensors are moved to.
        output_path: Where the annotated image is written.
    """
    # The context manager closes the file handle; convert() returns an independent image.
    with Image.open(file_path) as im:
        im_pil = im.convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([w, h])[None].to(device)
    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil)[None]
    blob = {
        "images": im_data.to(device),
        "orig_target_sizes": orig_size,
    }
    output = m(blob)
    result_images = draw([im_pil], output["labels"], output["boxes"], output["scores"])
    result_images[0].save(output_path)
    # Report the path actually written (previously the message named a different file).
    print(f"Image processing complete. Result saved as '{output_path}'.")
def process_video(m, file_path, device, output_path="trt_result.mp4"):
    """Run detector ``m`` on every frame of a video and write an annotated copy.

    Args:
        m: Callable taking a blob dict with ``"images"`` and
            ``"orig_target_sizes"`` and returning a dict containing
            ``"labels"``, ``"boxes"`` and ``"scores"`` (e.g. TRTInference).
        file_path: Video path (or any source string cv2.VideoCapture accepts).
        device: Torch device the input tensors are moved to.
        output_path: Where the annotated mp4v video is written.

    Raises:
        RuntimeError: if the input cannot be opened or the writer cannot be created.
    """
    cap = cv2.VideoCapture(file_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video source: {file_path}")

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        # Some containers/streams report 0 (or NaN) FPS, which VideoWriter rejects;
        # fall back to an arbitrary but common rate.
        fps = 30.0
    orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (orig_w, orig_h))
    if not out.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video writer for: {output_path}")

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    frame_count = 0
    print("Processing video frames...")
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes BGR; the model path works on RGB PIL images.
            frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            w, h = frame_pil.size
            orig_size = torch.tensor([w, h])[None].to(device)
            im_data = transforms(frame_pil)[None]
            blob = {
                "images": im_data.to(device),
                "orig_target_sizes": orig_size,
            }
            output = m(blob)
            # Draw detections on the frame, then convert back to a BGR array for the writer.
            result_images = draw([frame_pil], output["labels"], output["boxes"], output["scores"])
            frame = cv2.cvtColor(np.array(result_images[0]), cv2.COLOR_RGB2BGR)
            out.write(frame)
            frame_count += 1
            if frame_count % 10 == 0:
                print(f"Processed {frame_count} frames...")
    finally:
        # Release both handles even if inference fails mid-stream.
        cap.release()
        out.release()
    print(f"Video processing complete. Result saved as '{output_path}'.")
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("-trt", "--trt", type=str, required=True)
    cli.add_argument("-i", "--input", type=str, required=True)
    cli.add_argument("-d", "--device", type=str, default="cuda:0")
    opts = cli.parse_args()

    model = TRTInference(opts.trt, device=opts.device)
    _, ext = os.path.splitext(opts.input)
    # Known still-image extensions take the single-image path; anything else is
    # handed to the video pipeline.
    runner = (
        process_image
        if ext.lower() in (".jpg", ".jpeg", ".png", ".bmp")
        else process_video
    )
    runner(model, opts.input, opts.device)