Add streaming video support using VideoPipe (#17, #19)
laugh12321 committed Jun 27, 2024
1 parent 2cacb7d commit 01f8c64
Showing 9 changed files with 579 additions and 0 deletions.
Binary file added assets/videopipe.jpg
151 changes: 151 additions & 0 deletions demo/VideoPipe/CMakeLists.txt
@@ -0,0 +1,151 @@
# This is the build file for this project.
# It is autogenerated by the xmake build system.
# Do not edit by hand.
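#
# A sketch of driving this file with CMake directly (normally xmake runs the
# build; the absolute include/link paths below are assumed to exist locally):
#   cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
#   cmake --build build -j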

# project
cmake_minimum_required(VERSION 3.15.0)
cmake_policy(SET CMP0091 NEW)
project(PipeDemo LANGUAGES CXX CUDA)

# target
add_executable(PipeDemo "")
set_target_properties(PipeDemo PROPERTIES OUTPUT_NAME "PipeDemo")
set_target_properties(PipeDemo PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/build/linux/x86_64/release")
target_include_directories(PipeDemo PRIVATE
/usr/local/tensorrt/include
/home/laugh/Projects/TensorRT-YOLO/include
/home/laugh/Projects/VideoPipe
/usr/local/cuda/include
)
target_include_directories(PipeDemo SYSTEM PRIVATE
/usr/local/include/opencv4
/usr/include/libdrm
/usr/include/gstreamer-1.0
/usr/include/x86_64-linux-gnu
/usr/include/glib-2.0
/usr/lib/x86_64-linux-gnu/glib-2.0/include
)
target_compile_options(PipeDemo PRIVATE
$<$<COMPILE_LANGUAGE:C>:-m64>
$<$<COMPILE_LANGUAGE:CXX>:-m64>
$<$<COMPILE_LANGUAGE:C>:-DNDEBUG>
$<$<COMPILE_LANGUAGE:CXX>:-DNDEBUG>
$<$<COMPILE_LANGUAGE:C>:-pthread>
$<$<COMPILE_LANGUAGE:CXX>:-pthread>
$<$<COMPILE_LANGUAGE:CUDA>:-allow-unsupported-compiler>
$<$<COMPILE_LANGUAGE:CUDA>:-m64>
$<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>
$<$<COMPILE_LANGUAGE:CUDA>:-gencode arch=compute_75,code=sm_75>
)
set_target_properties(PipeDemo PROPERTIES CXX_EXTENSIONS OFF)
target_compile_features(PipeDemo PRIVATE cxx_std_17)
if(MSVC)
target_compile_options(PipeDemo PRIVATE $<$<CONFIG:Release>:-Ox -fp:fast>)
else()
target_compile_options(PipeDemo PRIVATE -O3)
endif()
if(NOT MSVC)
target_compile_options(PipeDemo PRIVATE -fvisibility=hidden)
endif()
if(MSVC)
set_property(TARGET PipeDemo PROPERTY
MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>")
endif()
target_link_libraries(PipeDemo PRIVATE
nvinfer
nvinfer_plugin
nvonnxparser
deploy
video_pipe
tinyexpr
opencv_gapi
opencv_stitching
opencv_aruco
opencv_bgsegm
opencv_bioinspired
opencv_ccalib
opencv_cudabgsegm
opencv_cudafeatures2d
opencv_cudaobjdetect
opencv_cudastereo
opencv_dnn_objdetect
opencv_dnn_superres
opencv_dpm
opencv_face
opencv_freetype
opencv_fuzzy
opencv_hdf
opencv_hfs
opencv_img_hash
opencv_intensity_transform
opencv_line_descriptor
opencv_mcc
opencv_quality
opencv_rapid
opencv_reg
opencv_rgbd
opencv_saliency
opencv_signal
opencv_stereo
opencv_structured_light
opencv_phase_unwrapping
opencv_superres
opencv_surface_matching
opencv_tracking
opencv_highgui
opencv_datasets
opencv_text
opencv_plot
opencv_videostab
opencv_cudaoptflow
opencv_optflow
opencv_cudalegacy
opencv_videoio
opencv_cudawarping
opencv_wechat_qrcode
opencv_xfeatures2d
opencv_shape
opencv_ml
opencv_ximgproc
opencv_video
opencv_xobjdetect
opencv_objdetect
opencv_calib3d
opencv_imgcodecs
opencv_features2d
opencv_dnn
opencv_flann
opencv_xphoto
opencv_photo
opencv_cudaimgproc
opencv_cudafilters
opencv_imgproc
opencv_cudaarithm
opencv_core
opencv_cudev
drm
gstreamer-1.0
gobject-2.0
glib-2.0
cudadevrt
cudart_static
rt
pthread
dl
)
target_link_directories(PipeDemo PRIVATE
/usr/local/tensorrt/lib
/home/laugh/Projects/TensorRT-YOLO/lib
/home/laugh/Projects/VideoPipe/build/libs
/usr/local/cuda/lib64
/usr/local/lib
)
target_link_options(PipeDemo PRIVATE
-m64
)
target_sources(PipeDemo PRIVATE
src/main.cpp
src/vp_trtyolo_detector.cpp
)

49 changes: 49 additions & 0 deletions demo/VideoPipe/README.en.md
@@ -0,0 +1,49 @@
English | [简体中文](README.md)

# Video Analysis Example

This example uses the YOLOv8s model to demonstrate how to integrate the TensorRT-YOLO Deploy module into [VideoPipe](https://github.com/sherlockchou86/VideoPipe) for video analysis.

## Model Export

First, download the YOLOv8s model from [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s.pt) and save it to the `workspace` folder.

Next, use the following command to export the model to ONNX format with the [EfficientNMS](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin) plugin:

```bash
cd workspace
trtyolo export -w yolov8s.pt -v yolov8 -o models -b 2
```

After executing the command above, a file named `yolov8s.onnx` will be generated in the `models` folder. Then, convert the ONNX file to a TensorRT engine using the `trtexec` tool:

```bash
cd workspace
trtexec --onnx=yolov8s.onnx --saveEngine=yolov8s.engine --fp16
```

## Project Execution

Before performing inference, make sure VideoPipe and TensorRT-YOLO have been compiled.
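If they are not yet built, the following is a rough sketch assuming each project's default build system (VideoPipe is CMake-based, TensorRT-YOLO is xmake-based); treat the exact flags as assumptions and consult each repository's README for the authoritative steps:

```bash
# Build VideoPipe (CMake-based; required options may vary by version)
cd /path/to/your/VideoPipe
mkdir -p build && cd build
cmake .. && make -j"$(nproc)"

# Build TensorRT-YOLO (xmake-based; --tensorrt mirrors the option used below)
cd /path/to/your/TensorRT-YOLO
xmake f --tensorrt=/path/to/your/TensorRT
xmake -r
```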

Next, use xmake to compile the project into an executable:

```bash
xmake f -P . --tensorrt=/path/to/your/TensorRT --deploy=/path/to/your/TensorRT-YOLO --videopipe=/path/to/your/VideoPipe

xmake -P . -r
```

After successful compilation, you can directly run the generated executable or use the `xmake run` command for inference:

```bash
xmake run -P . PipeDemo
```
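Alternatively, run the binary straight from the output directory set in the generated CMakeLists.txt (`build/linux/x86_64/release` for this Linux x86_64 build). Note that `src/main.cpp` opens `yolov8s.engine`, `labels.txt`, `demo0.mp4`, and `demo1.mp4` relative to the working directory, so launch it from a folder containing those files:

```bash
# Path follows the RUNTIME_OUTPUT_DIRECTORY from the generated CMakeLists.txt
./build/linux/x86_64/release/PipeDemo
```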

<div align="center">
<p>
<img width="100%" src="../../assets/videopipe.jpg">
</p>
</div>

The steps above demonstrate how to run model inference within a VideoPipe pipeline.
49 changes: 49 additions & 0 deletions demo/VideoPipe/README.md
@@ -0,0 +1,49 @@
[English](README.en.md) | 简体中文

# Video Analysis Example

This example uses the YOLOv8s model to demonstrate how to integrate the TensorRT-YOLO Deploy module into [VideoPipe](https://github.com/sherlockchou86/VideoPipe) for video analysis.

## Model Export

First, download the YOLOv8s model from [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s.pt) and save it to the `workspace` folder.

Next, use the following command to export the model to ONNX format with the [EfficientNMS](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin) plugin:

```bash
cd workspace
trtyolo export -w yolov8s.pt -v yolov8 -o models -b 2
```

After running the command above, a file named `yolov8s.onnx` will be generated in the `models` folder. Next, use the `trtexec` tool to convert the ONNX file into a TensorRT engine:

```bash
cd workspace
trtexec --onnx=yolov8s.onnx --saveEngine=yolov8s.engine --fp16
```

## Running the Project

Before running inference, make sure both VideoPipe and TensorRT-YOLO have been compiled.
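If they are not yet built, a rough sketch (VideoPipe is CMake-based, TensorRT-YOLO is xmake-based; the exact flags are assumptions, so check each repository's README):

```bash
# Build VideoPipe (CMake-based)
cd /path/to/your/VideoPipe
mkdir -p build && cd build
cmake .. && make -j"$(nproc)"

# Build TensorRT-YOLO (xmake-based)
cd /path/to/your/TensorRT-YOLO
xmake f --tensorrt=/path/to/your/TensorRT
xmake -r
```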

Next, use xmake to build the project into an executable:

```bash
xmake f -P . --tensorrt=/path/to/your/TensorRT --deploy=/path/to/your/TensorRT-YOLO --videopipe=/path/to/your/VideoPipe

xmake -P . -r
```

After a successful build, you can run the generated executable directly or use the `xmake run` command for inference:

```bash
xmake run -P . PipeDemo
```
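Alternatively, run the binary straight from the output directory set in the generated CMakeLists.txt; `src/main.cpp` expects `yolov8s.engine`, `labels.txt`, `demo0.mp4`, and `demo1.mp4` in the working directory:

```bash
# Path follows the RUNTIME_OUTPUT_DIRECTORY from the generated CMakeLists.txt
./build/linux/x86_64/release/PipeDemo
```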

<div align="center">
<p>
<img width="100%" src="../../assets/videopipe.jpg">
</p>
</div>

The above illustrates how to run model inference within a VideoPipe pipeline.
58 changes: 58 additions & 0 deletions demo/VideoPipe/src/main.cpp
@@ -0,0 +1,58 @@
#include "vp_trtyolo_detector.h"
#include "nodes/vp_split_node.h"
#include "nodes/osd/vp_osd_node.h"
#include "nodes/vp_file_src_node.h"
#include "nodes/vp_screen_des_node.h"
#include "nodes/track/vp_sort_track_node.h"
#include "utils/analysis_board/vp_analysis_board.h"

int main() {
    // Disable logging code location and thread ID
    VP_SET_LOG_INCLUDE_CODE_LOCATION(false);
    VP_SET_LOG_INCLUDE_THREAD_ID(false);
    VP_LOGGER_INIT();

    // Video sources
    auto file_src_0 = std::make_shared<vp_nodes::vp_file_src_node>("file_src_0", 0, "demo0.mp4");
    auto file_src_1 = std::make_shared<vp_nodes::vp_file_src_node>("file_src_1", 1, "demo1.mp4");

    // Inference node (TensorRT-YOLO detector)
    auto detector = std::make_shared<vp_nodes::vp_trtyolo_detector>("yolo_detector", "yolov8s.engine", "labels.txt", true, 2);

    // Tracking node (SORT tracker)
    auto track = std::make_shared<vp_nodes::vp_sort_track_node>("track");

    // OSD (On-Screen Display) node
    auto osd = std::make_shared<vp_nodes::vp_osd_node>("osd");

    // Channel splitting node
    auto split = std::make_shared<vp_nodes::vp_split_node>("split_by_channel", true);

    // Local display nodes
    auto screen_des_0 = std::make_shared<vp_nodes::vp_screen_des_node>("screen_des_0", 0);
    auto screen_des_1 = std::make_shared<vp_nodes::vp_screen_des_node>("screen_des_1", 1);

    // Construct the pipeline
    detector->attach_to({file_src_0, file_src_1});
    track->attach_to({detector});
    osd->attach_to({track});
    split->attach_to({osd});

    // Split by channel via vp_split_node for display
    screen_des_0->attach_to({split});
    screen_des_1->attach_to({split});

    // Start video sources
    file_src_0->start();
    file_src_1->start();

    // Debugging: display analysis board
    vp_utils::vp_analysis_board board({file_src_0, file_src_1});
    board.display(1, false);  // Refresh every 1 second, non-blocking

    // Wait for user input to stop and detach nodes recursively
    std::string wait;
    std::getline(std::cin, wait);
    file_src_0->detach_recursively();
    file_src_1->detach_recursively();
}
83 changes: 83 additions & 0 deletions demo/VideoPipe/src/vp_trtyolo_detector.cpp
@@ -0,0 +1,83 @@
#include "vp_trtyolo_detector.h"

namespace vp_nodes {

vp_trtyolo_detector::vp_trtyolo_detector(std::string node_name, std::string model_path, std::string labels_path, bool cudagraph, int batch, int device_id)
    : vp_primary_infer_node(node_name, "", "", labels_path, 0, 0, batch), use_cudagraph(cudagraph) {
    // Initialize detector based on CUDA graph usage
    if (use_cudagraph) {
        detector = std::make_shared<deploy::DeployCGDet>(model_path, device_id);
        if (detector->batch != batch) {
            throw std::runtime_error("Batch size mismatch: expected " + std::to_string(batch) + ", but got " + std::to_string(detector->batch));
        }
    } else {
        detector = std::make_shared<deploy::DeployDet>(model_path, device_id);
        if (detector->batch < batch) {
            throw std::runtime_error("Batch size too large: expected <= " + std::to_string(detector->batch) + ", but got " + std::to_string(batch));
        }
    }

    this->initialized();  // Mark node as initialized
}

vp_trtyolo_detector::~vp_trtyolo_detector() {
    // Destructor: clean up any resources
    deinitialized();  // Mark node as deinitialized
}

void vp_trtyolo_detector::run_infer_combinations(const std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>& frame_meta_with_batch) {
    if (use_cudagraph)
        assert(frame_meta_with_batch.size() == detector->batch);  // Assert batch size consistency if using CUDA graph

    std::vector<cv::Mat> mats_to_infer;
    std::vector<deploy::Image> images_to_infer;

    auto start_time = std::chrono::system_clock::now();  // Start time for performance measurement

    // Prepare data for inference (same as base class)
    vp_primary_infer_node::prepare(frame_meta_with_batch, mats_to_infer);
    std::transform(mats_to_infer.begin(), mats_to_infer.end(), std::back_inserter(images_to_infer), [](cv::Mat& mat) {
        return deploy::Image(mat.data, mat.cols, mat.rows);
    });

    auto prepare_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - start_time);

    start_time = std::chrono::system_clock::now();

    // Perform inference on prepared images
    std::vector<deploy::DetectionResult> detection_results = detector->predict(images_to_infer);

    // Process detection results and update frame metadata
    for (size_t i = 0; i < detection_results.size(); i++) {
        auto& frame_meta = frame_meta_with_batch[i];
        auto& detection_result = detection_results[i];

        for (int j = 0; j < detection_result.num; j++) {
            int x = static_cast<int>(detection_result.boxes[j].left);
            int y = static_cast<int>(detection_result.boxes[j].top);
            int width = static_cast<int>(detection_result.boxes[j].right - detection_result.boxes[j].left);
            int height = static_cast<int>(detection_result.boxes[j].bottom - detection_result.boxes[j].top);
            auto label = labels.size() == 0 ? "" : labels[detection_result.classes[j]];

            // Create target and write it back into the frame meta
            auto target = std::make_shared<vp_objects::vp_frame_target>(
                x, y, width, height, detection_result.classes[j], detection_result.scores[j],
                frame_meta->frame_index, frame_meta->channel_index, label);

            frame_meta->targets.push_back(target);
        }
    }

    auto infer_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now() - start_time);

    // Preprocess and postprocess times cannot be measured separately here, so report 0 for both.
    vp_infer_node::infer_combinations_time_cost(mats_to_infer.size(), prepare_time.count(), 0, infer_time.count(), 0);
}

void vp_trtyolo_detector::postprocess(const std::vector<cv::Mat>& raw_outputs, const std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>& frame_meta_with_batch) {
    // Placeholder for postprocessing logic if needed in future enhancements
    // Currently not implemented in this class
}

} // namespace vp_nodes
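
The companion header `vp_trtyolo_detector.h` is included by the sources above but not rendered in this view. The following is a hypothetical reconstruction inferred purely from how the .cpp uses it; the default arguments, include paths, and the `deploy::DeployBase` base type are assumptions, not the committed code:

```cpp
// Hypothetical sketch of vp_trtyolo_detector.h -- inferred from the .cpp
// above, NOT the header actually committed.
#pragma once

#include <memory>
#include <string>
#include <vector>

#include "deploy/vision/inference.hpp"    // assumed TensorRT-YOLO deploy header
#include "nodes/vp_primary_infer_node.h"  // base class; also brings in cv::Mat

namespace vp_nodes {

class vp_trtyolo_detector : public vp_primary_infer_node {
public:
    // Defaults for cudagraph/batch/device_id are guesses; main.cpp calls
    // ("yolo_detector", "yolov8s.engine", "labels.txt", true, 2).
    vp_trtyolo_detector(std::string node_name,
                        std::string model_path,
                        std::string labels_path,
                        bool cudagraph = false,
                        int batch = 1,
                        int device_id = 0);
    ~vp_trtyolo_detector();

protected:
    // Batched inference entry point overridden from vp_primary_infer_node.
    virtual void run_infer_combinations(
        const std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>& frame_meta_with_batch) override;
    // Unused here; results are written to frame meta in run_infer_combinations.
    virtual void postprocess(
        const std::vector<cv::Mat>& raw_outputs,
        const std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>& frame_meta_with_batch) override;

private:
    bool use_cudagraph;
    // Assumed common base type so either DeployDet or DeployCGDet fits;
    // the real header may declare a different type.
    std::shared_ptr<deploy::DeployBase> detector;
};

}  // namespace vp_nodes
```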