Skip to content

Commit 0ae2f4b

Browse files
authored
Merge pull request #62 from Pbatch/pb_unit_yolov8
Yolov8 Unit Model
2 parents 19e43ed + 14c09e1 commit 0ae2f4b

3 files changed

Lines changed: 20 additions & 11 deletions

File tree

-68.4 MB
Binary file not shown.

clashroyalebuildabot/state/onnx_detector.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import onnxruntime
21
import numpy as np
2+
import onnxruntime
33

44

55
class OnnxDetector:
@@ -43,22 +43,31 @@ def _nms(boxes, scores, thresh):
4343

4444
return keep
4545

46-
def nms(self, prediction, conf_thres=0.35, iou_thres=0.45):
46+
def nms(self, prediction, conf_thres=0.35, iou_thres=0.45, yolov8=False):
4747
"""
4848
Runs Non-Maximum Suppression (NMS) on inference results
4949
"""
50+
if yolov8:
51+
prediction = prediction.transpose((0, 2, 1))
5052
output = [np.zeros((0, 6))] * len(prediction)
5153
for i in range(len(prediction)):
52-
# Mask out predictions below the confidence threshold
53-
mask = prediction[i, :, 4] > conf_thres
54-
x = prediction[i][mask]
54+
if yolov8:
55+
x = prediction[i]
56+
else:
57+
# Mask out predictions below the confidence threshold
58+
mask = prediction[i, :, 4] > conf_thres
59+
x = prediction[i][mask]
5560

56-
if not x.shape[0]:
57-
continue
61+
if not x.shape[0]:
62+
continue
5863

5964
# Calculate the best scores
60-
# score = object confidence * class confidence
61-
scores = x[:, 4:5] * x[:, 5:]
65+
if yolov8:
66+
# score = class confidence
67+
scores = x[:, 4:]
68+
else:
69+
# score = object confidence * class confidence
70+
scores = x[:, 4:5] * x[:, 5:]
6271
best_scores_idx = np.argmax(scores, axis=1).reshape(-1, 1)
6372
best_scores = np.take_along_axis(scores, best_scores_idx, axis=1)
6473

@@ -76,8 +85,8 @@ def nms(self, prediction, conf_thres=0.35, iou_thres=0.45):
7685

7786
# Keep only the best class
7887
best = np.hstack([boxes[keep], best_scores[keep], best_scores_idx[keep]])
79-
output[i] = best
8088

89+
output[i] = best
8190
return output
8291

8392
def _post_process(self, pred):

clashroyalebuildabot/state/unit_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def run(self, image):
9090
pred = self.sess.run([self.output_name], {self.input_name: np_image})[0]
9191

9292
# Forced post-processing
93-
pred = np.array(self.nms(pred)[0])
93+
pred = np.array(self.nms(pred, yolov8=True)[0])
9494
pred[:, [0, 2]] *= width / UNIT_SIZE
9595
pred[:, [1, 3]] *= height / UNIT_SIZE
9696

0 commit comments

Comments
 (0)