Spatial Detection Network

This example builds a pipeline that performs YOLOv6-Nano spatial object detection from the RGB stream and a depth source (classic stereo by default, or neural depth via the --depthSource flag), and uses a custom host-side visualization node to draw bounding boxes and spatial coordinates on both the RGB frame and a colorized depth frame. This example requires the DepthAI v3 API; see the installation instructions.
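If the library is not installed yet, DepthAI can typically be installed from PyPI (the exact command and any version pins may differ; follow the linked installation instructions):

python3 -m pip install depthai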

Pipeline

An RGB camera and two mono cameras feed the depth source (StereoDepth by default, or NeuralDepth), the RGB stream and the depth output go into a SpatialDetectionNetwork running YOLOv6-Nano, and the network's detections and passthrough frames are consumed by the host-side SpatialVisualizer node.

Source code

Python
#!/usr/bin/env python3

import argparse

import cv2
import depthai as dai
import numpy as np

NEURAL_FPS = 8
STEREO_DEFAULT_FPS = 30

parser = argparse.ArgumentParser()
parser.add_argument(
    "--depthSource", type=str, default="stereo", choices=["stereo", "neural"]
)
args = parser.parse_args()

modelDescription = dai.NNModelDescription("yolov6-nano")
size = (640, 400)

# Neural depth is heavier than classic stereo, so run the cameras slower
if args.depthSource == "stereo":
    fps = STEREO_DEFAULT_FPS
else:
    fps = NEURAL_FPS

class SpatialVisualizer(dai.node.HostNode):
    def __init__(self):
        dai.node.HostNode.__init__(self)
        # Run this node's processing as part of p.run(), where cv2 GUI calls are safe
        self.sendProcessingToPipeline(True)

    def build(self, depth: dai.Node.Output, detections: dai.Node.Output, rgb: dai.Node.Output):
        self.link_args(depth, detections, rgb)  # Must match the inputs to the process method

    def process(self, depthPreview, detections, rgbPreview):
        depthPreview = depthPreview.getCvFrame()
        rgbPreview = rgbPreview.getCvFrame()
        depthFrameColor = self.processDepthFrame(depthPreview)
        self.displayResults(rgbPreview, depthFrameColor, detections.detections)

    def processDepthFrame(self, depthFrame):
        # Use every 4th row to keep the percentile computation cheap
        depth_downscaled = depthFrame[::4]
        if np.all(depth_downscaled == 0):
            min_depth = 0  # Fall back to 0 if every pixel is zero
        else:
            min_depth = np.percentile(depth_downscaled[depth_downscaled != 0], 1)
        max_depth = np.percentile(depth_downscaled, 99)
        # Stretch the 1st-99th percentile range to 0-255, then colorize
        depthFrameColor = np.interp(depthFrame, (min_depth, max_depth), (0, 255)).astype(np.uint8)
        return cv2.applyColorMap(depthFrameColor, cv2.COLORMAP_HOT)

    def displayResults(self, rgbFrame, depthFrameColor, detections):
        height, width, _ = rgbFrame.shape
        for detection in detections:
            self.drawBoundingBoxes(depthFrameColor, detection)
            self.drawDetections(rgbFrame, detection, width, height)

        cv2.imshow("depth", depthFrameColor)
        cv2.imshow("rgb", rgbFrame)
        if cv2.waitKey(1) == ord('q'):
            self.stopPipeline()

    def drawBoundingBoxes(self, depthFrameColor, detection):
        # ROI on the depth frame that was used to compute the spatial coordinates
        roiData = detection.boundingBoxMapping
        roi = roiData.roi
        roi = roi.denormalize(depthFrameColor.shape[1], depthFrameColor.shape[0])
        topLeft = roi.topLeft()
        bottomRight = roi.bottomRight()
        cv2.rectangle(depthFrameColor, (int(topLeft.x), int(topLeft.y)), (int(bottomRight.x), int(bottomRight.y)), (255, 255, 255), 1)

    def drawDetections(self, frame, detection, frameWidth, frameHeight):
        # Detection coordinates are normalized to [0, 1]; scale them to pixels
        x1 = int(detection.xmin * frameWidth)
        x2 = int(detection.xmax * frameWidth)
        y1 = int(detection.ymin * frameHeight)
        y2 = int(detection.ymax * frameHeight)
        label = detection.labelName
        color = (255, 255, 255)
        cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, "{:.2f}".format(detection.confidence * 100), (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"X: {int(detection.spatialCoordinates.x)} mm", (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Y: {int(detection.spatialCoordinates.y)} mm", (x1 + 10, y1 + 65), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Z: {int(detection.spatialCoordinates.z)} mm", (x1 + 10, y1 + 80), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)

# Creates the pipeline and a default device implicitly
with dai.Pipeline() as p:
    # Define sources and outputs
    platform = p.getDefaultDevice().getPlatform()

    camRgb = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_A, sensorFps=fps)
    monoLeft = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_B, sensorFps=fps)
    monoRight = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_C, sensorFps=fps)
    if args.depthSource == "stereo":
        depthSource = p.create(dai.node.StereoDepth)
        depthSource.setExtendedDisparity(True)
        if platform == dai.Platform.RVC2:
            depthSource.setOutputSize(640, 400)
        monoLeft.requestOutput(size).link(depthSource.left)
        monoRight.requestOutput(size).link(depthSource.right)
    elif args.depthSource == "neural":
        depthSource = p.create(dai.node.NeuralDepth).build(
            monoLeft.requestFullResolutionOutput(),
            monoRight.requestFullResolutionOutput(),
            dai.DeviceModelZoo.NEURAL_DEPTH_LARGE,
        )
    else:
        raise ValueError(f"Invalid depth source: {args.depthSource}")

    spatialDetectionNetwork = p.create(dai.node.SpatialDetectionNetwork).build(
        camRgb, depthSource, modelDescription
    )
    visualizer = p.create(SpatialVisualizer)

    spatialDetectionNetwork.input.setBlocking(False)
    # Average depth over a scaled-down ROI (half the bounding box size)
    spatialDetectionNetwork.setBoundingBoxScaleFactor(0.5)
    # Ignore depth values outside the 100 mm - 5000 mm range
    spatialDetectionNetwork.setDepthLowerThreshold(100)
    spatialDetectionNetwork.setDepthUpperThreshold(5000)

    visualizer.build(
        spatialDetectionNetwork.passthroughDepth,
        spatialDetectionNetwork.out,
        spatialDetectionNetwork.passthrough,
    )

    print("Starting pipeline with depth source:", args.depthSource)

    p.run()
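To try both depth sources, run the script (saved here as spatial_detection_network.py, a hypothetical filename) with and without the --depthSource flag:

python3 spatial_detection_network.py                        # stereo depth, 30 FPS (default)
python3 spatial_detection_network.py --depthSource neural   # neural depth, 8 FPS

Press q in either preview window to call stopPipeline() and exit.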
