add tracking demo with gd 1.5

This commit is contained in:
rentainhe
2024-08-05 16:05:31 +08:00
parent 41640f4add
commit 2cdd3f2d92
3 changed files with 202 additions and 5 deletions

View File

@@ -1,5 +1,4 @@
import numpy as np
from scipy.ndimage import center_of_mass
def sample_points_from_masks(masks, num_points):
"""
@@ -16,7 +15,7 @@ def sample_points_from_masks(masks, num_points):
points = []
for i in range(n):
# 找到当前mask中值为1的位置的坐标
# find the valid mask points
indices = np.argwhere(masks[i] == 1)
# the output format of np.argwhere is (y, x) and the shape is (num_points, 2)
# we should convert it to (x, y)
@@ -24,11 +23,11 @@ def sample_points_from_masks(masks, num_points):
# import pdb; pdb.set_trace()
if len(indices) == 0:
# 如果没有有效点,返回一个空数组
# if there are no valid points, append an empty array
points.append(np.array([]))
continue
# 如果mask中的点少于需要的数量则重复采样
# resampling if there's not enough points
if len(indices) < num_points:
sampled_indices = np.random.choice(len(indices), num_points, replace=True)
else:
@@ -37,6 +36,6 @@ def sample_points_from_masks(masks, num_points):
sampled_points = indices[sampled_indices]
points.append(sampled_points)
# 将结果转换为numpy数组
# convert to np.array
points = np.array(points, dtype=np.float32)
return points