# Provenance (recovered from a git patch whose line breaks had been collapsed):
#   From ea9ed60a33d69ef9169b9220299f29e2c993bd4d Mon Sep 17 00:00:00 2001
#   From: "Kin-Yiu, Wong"
#   Date: Thu, 7 Mar 2024 11:42:05 +0800
#   Subject: [PATCH] Update experimental.py
#   https://github.com/WongKinYiu/yolov9/pull/189
#   models/experimental.py | 103 insertions(+)   (index ae087c5..80ee2e0 100644)
#   Hunk @@ -66,6 +66,109 @@ -- inserted directly after the context lines
#       def forward(self, x, augment=False, profile=False, visualize=False):
#           ...
#           return y, None  # inference, train output


class TRT_NMS(torch.autograd.Function):
    """Export-time stand-in for the TensorRT ``EfficientNMS_TRT`` plugin.

    ``forward`` performs no real NMS: it only returns random tensors with the
    shapes and dtypes of the four NMS outputs so that tracing can proceed.
    ``symbolic`` is what the ONNX exporter uses; it emits a single
    ``TRT::EfficientNMS_TRT`` node carrying the given attributes.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
    ):
        """Return placeholder NMS outputs shaped after ``scores``.

        Args:
            boxes: Box tensor; only forwarded to the plugin, not read here.
            scores: Tensor of shape ``(batch, num_boxes, num_classes)``.
            max_output_boxes: Size of the detection dimension of the outputs.
            (Remaining arguments are plugin attributes, unused in eager mode.)

        Returns:
            ``(num_det, det_boxes, det_scores, det_classes)`` with shapes
            ``(batch, 1)``, ``(batch, max_output_boxes, 4)``,
            ``(batch, max_output_boxes)`` and ``(batch, max_output_boxes)``;
            ``num_det`` and ``det_classes`` are int32.
        """
        batch_size, _, num_classes = scores.shape
        # Allocate the placeholders next to the input so that tracing a model
        # living on an accelerator does not mix devices.
        device = scores.device
        num_det = torch.randint(
            0, max_output_boxes, (batch_size, 1), dtype=torch.int32, device=device
        )
        det_boxes = torch.randn(batch_size, max_output_boxes, 4, device=device)
        det_scores = torch.randn(batch_size, max_output_boxes, device=device)
        det_classes = torch.randint(
            0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32, device=device
        )
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(
        g,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
    ):
        """Emit the ``TRT::EfficientNMS_TRT`` node with four outputs."""
        # The ``_i`` / ``_f`` / ``_s`` suffixes select int / float / string
        # ONNX attribute types for the exporter.
        outputs = g.op(
            "TRT::EfficientNMS_TRT",
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4,
        )
        num_det, det_boxes, det_scores, det_classes = outputs
        return num_det, det_boxes, det_scores, det_classes


class ONNX_TRT(nn.Module):
    """ONNX post-processing module that appends the TensorRT NMS operation.

    Args:
        max_obj: Maximum number of detections kept per image.
        iou_thres: IoU threshold handed to the NMS plugin.
        score_thres: Score threshold handed to the NMS plugin.
        max_wh: Must be ``None``; any other value is rejected.
        device: Stored on the module; defaults to CPU when falsy.
        n_classes: Number of classes (stored, not used in ``forward``).
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
        super().__init__()
        assert max_wh is None, "ONNX_TRT does not support max_wh; pass max_wh=None"
        self.device = device if device else torch.device('cpu')
        # Plain ints: the original trailing commas turned these two into
        # one-element tuples, which were then exported as list attributes
        # instead of the scalar ints the plugin attributes are declared as.
        self.background_class = -1
        self.box_coding = 1
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres
        self.n_classes = n_classes

    def forward(self, x):
        """Split the raw prediction into boxes and scores and apply TRT NMS.

        Args:
            x: Prediction tensor laid out as ``(batch, 4 + classes, boxes)``,
                or a list whose element ``1`` is that tensor.

        Returns:
            ``(num_det, det_boxes, det_scores, det_classes)`` from ``TRT_NMS``.
        """
        ## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
        ## thanks https://github.com/thaitc-hust
        if isinstance(x, list):  ## yolov9-c.pt and yolov9-e.pt return list
            x = x[1]
        x = x.permute(0, 2, 1)  # -> [n_batch, n_bboxes, 4 + n_classes]
        # First four channels are the box (x, y, w, h); one slice is
        # equivalent to slicing each channel and concatenating them again.
        bboxes = x[..., :4]
        bboxes = bboxes.unsqueeze(2)  # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
        scores = x[..., 4:]  # remaining channels are used directly as scores
        num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(
            bboxes,
            scores,
            self.background_class,
            self.box_coding,
            self.iou_threshold,
            self.max_obj,
            self.plugin_version,
            self.score_activation,
            self.score_threshold,
        )
        return num_det, det_boxes, det_scores, det_classes


class End2End(nn.Module):
    """Wrap a detection model so it exports to ONNX/TensorRT with NMS included.

    Args:
        model: Model to wrap; it is moved to ``device`` and the last entry of
            its ``model`` attribute gets ``end2end = True``.
        max_obj, iou_thres, score_thres, max_wh, device, n_classes:
            Forwarded to ``ONNX_TRT``. ``max_wh`` must be an int or ``None``
            here, and ``ONNX_TRT`` additionally requires it to be ``None``.
    """

    def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
        super().__init__()
        device = device if device else torch.device('cpu')
        assert isinstance(max_wh, int) or max_wh is None, "max_wh must be an int or None"
        self.model = model.to(device)
        self.model.model[-1].end2end = True
        self.patch_model = ONNX_TRT
        self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
        self.end2end.eval()

    def forward(self, x):
        """Run the wrapped model, then the NMS post-processing module."""
        x = self.model(x)
        x = self.end2end(x)
        return x


def attempt_load(weights, device=None, inplace=True, fuse=True):
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    # NOTE(review): the excerpt ends after the import below; the rest of this
    # function is outside the visible text and is intentionally left untouched.
    from models.yolo import Detect, Model