Skip to content
8 changes: 5 additions & 3 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,14 @@ def call(self, inputs):
x.append(self.m[i](inputs[i]))
# x(bs,20,20,255) to x(bs,3,20,20,85)
ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])
x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

if not self.training: # inference
y = tf.sigmoid(x[i])
xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
grid = tf.transpose(self.grid[i], [0, 2, 1, 3])
anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3])
xy = (y[..., 0:2] * 2 - 0.5 + grid) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * anchor_grid
# Normalize xywh to 0-1 to reduce calibration error
xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
Expand Down