Skip to content

Commit

Permalink
Update experimental.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WongKinYiu authored and Nimeshs54 committed Sep 7, 2024
1 parent 74d5951 commit ea9ed60
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,109 @@ def forward(self, x, augment=False, profile=False, visualize=False):
return y, None # inference, train output


class TRT_NMS(torch.autograd.Function):
'''TensorRT NMS operation'''
@staticmethod
def forward(
ctx,
boxes,
scores,
background_class=-1,
box_coding=1,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25,
):

batch_size, num_boxes, num_classes = scores.shape
num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
det_scores = torch.randn(batch_size, max_output_boxes)
det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
return num_det, det_boxes, det_scores, det_classes

@staticmethod
def symbolic(g,
boxes,
scores,
background_class=-1,
box_coding=1,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25):
out = g.op("TRT::EfficientNMS_TRT",
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
nums, boxes, scores, classes = out
return nums, boxes, scores, classes



class ONNX_TRT(nn.Module):
'''onnx module with TensorRT NMS operation.'''
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
super().__init__()
assert max_wh is None
self.device = device if device else torch.device('cpu')
self.background_class = -1,
self.box_coding = 1,
self.iou_threshold = iou_thres
self.max_obj = max_obj
self.plugin_version = '1'
self.score_activation = 0
self.score_threshold = score_thres
self.n_classes=n_classes

def forward(self, x):
## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
## thanks https://github.com/thaitc-hust
if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
x = x[1]
x = x.permute(0, 2, 1)
bboxes_x = x[..., 0:1]
bboxes_y = x[..., 1:2]
bboxes_w = x[..., 2:3]
bboxes_h = x[..., 3:4]
bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
obj_conf = x[..., 4:]
scores = obj_conf
num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(bboxes, scores, self.background_class, self.box_coding,
self.iou_threshold, self.max_obj,
self.plugin_version, self.score_activation,
self.score_threshold)
return num_det, det_boxes, det_scores, det_classes

class End2End(nn.Module):
'''export onnx or tensorrt model with NMS operation.'''
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
super().__init__()
device = device if device else torch.device('cpu')
assert isinstance(max_wh,(int)) or max_wh is None
self.model = model.to(device)
self.model.model[-1].end2end = True
self.patch_model = ONNX_TRT
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
self.end2end.eval()

def forward(self, x):
x = self.model(x)
x = self.end2end(x)
return x


def attempt_load(weights, device=None, inplace=True, fuse=True):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from models.yolo import Detect, Model
Expand Down

0 comments on commit ea9ed60

Please sign in to comment.