以下是您可以做的:
基础图像 = 货架的整体图片
模板图像 = 单个产品图像
- 从两个图像(基础图像和模板图像)中提取SIFT关键点和描述符。
- 进行特征匹配。
- 获取所有与模板图像匹配的基础图像中的点。(参见图例)
- 根据模板图像的大小创建聚类。(这里阈值为50px)
- 获取聚类的边界框。
- 裁剪每个边界框聚类并检查与模板图像的匹配情况。
- 接受匹配数达到最低比例的所有聚类。(这里要求至少匹配10%的关键点)
def plot_pts(img, pts):
    """Show ``img`` with every point in ``pts`` marked by a filled circle.

    The drawing happens on a copy, so ``img`` itself is left untouched.

    Args:
        img: Image array exposing ``.copy()``.
        pts: Sequence of points; the first two entries of each point are
            truncated to ints and used as the circle centre.
    """
    canvas = img.copy()
    for point in pts:
        center = (int(point[0]), int(point[1]))
        # thickness=-1 asks OpenCV for a filled circle.
        canvas = cv2.circle(canvas, center, radius=7, color=(255, 0, 0), thickness=-1)
    plt.figure(figsize=(20, 10))
    plt.imshow(canvas)
def plot_bbox(img, bbox_list):
    """Show ``img`` with one rectangle drawn per bounding box.

    The drawing happens on a copy, so ``img`` itself is left untouched.

    Args:
        img: Image array exposing ``.copy()``.
        bbox_list: Sequence of boxes. Each box is a sequence of corner
            points; index 0 is used as the first rectangle corner and
            index 2 as the opposite corner.
    """
    img_plot = img.copy()
    for corners in bbox_list:
        # Corners may hold NumPy scalars (e.g. when they come from np.min /
        # np.max); coerce to plain-int tuples before handing them to OpenCV.
        start_pt = tuple(int(v) for v in corners[0])
        end_pt = tuple(int(v) for v in corners[2])
        img_plot = cv2.rectangle(img_plot, pt1=start_pt, pt2=end_pt, color=(255, 0, 0), thickness=2)
    plt.figure(figsize=(20, 10))
    plt.imshow(img_plot)
def get_distance(pt1, pt2):
    """Return the Euclidean distance between two 2-D points.

    Args:
        pt1: First point as an ``(x, y)`` pair.
        pt2: Second point as an ``(x, y)`` pair.
    """
    ax, ay = pt1
    bx, by = pt2
    dx = ax - bx
    dy = ay - by
    return np.sqrt(np.square(dx) + np.square(dy))
def check_centroid(pt, centroid, threshold=None):
    """Tell whether ``pt`` lies strictly within a distance limit of ``centroid``.

    Args:
        pt: Point as an ``(x, y)`` pair.
        centroid: Cluster centre as an ``(x, y)`` pair.
        threshold: Optional distance limit. When omitted, the module-level
            ``max_distance`` is used, which must be defined before the call.

    Returns:
        ``True`` if the distance is smaller than the limit, else ``False``.
    """
    x, y = pt
    cx, cy = centroid
    limit = max_distance if threshold is None else threshold
    # bool() keeps the return a plain Python bool even for NumPy comparisons.
    return bool(get_distance(pt1=(x, y), pt2=(cx, cy)) < limit)
def update_centroid(pt, centroids_list):
    """Assign ``pt`` to the first close-enough cluster, or start a new one.

    Clusters are visited in order; ``pt`` joins the first cluster whose mean
    passes ``check_centroid``. If none does, ``pt`` becomes a new
    single-point cluster. The result is passed through ``recheck_centroid``
    before being returned.

    Args:
        pt: Point to insert.
        centroids_list: List of clusters, each a list of points. It is not
            modified; the function works on copies of the inner lists.

    Returns:
        A new list of clusters.
    """
    # Copy the inner lists as well: a shallow ``.copy()`` would let the
    # append below leak into the caller's clusters.
    clusters = [list(cluster) for cluster in centroids_list]
    for cluster in clusters:
        if check_centroid(pt, np.mean(cluster, axis=0)):
            cluster.append(pt)
            break
    else:
        # No existing cluster was close enough.
        clusters.append([pt])
    return recheck_centroid(clusters)
def recheck_centroid(centroids_list):
    """Return the clusters with duplicate points removed from each one.

    Every cluster is round-tripped through ``set``, so its points must be
    hashable and their order in the result is whatever the set yields.

    Args:
        centroids_list: Iterable of clusters, each an iterable of points.

    Returns:
        A new list holding one de-duplicated list per input cluster.
    """
    deduplicated = []
    for cluster in centroids_list:
        unique_points = set(cluster)
        deduplicated.append(list(unique_points))
    return deduplicated
def get_bbox(pts):
    """Return the axis-aligned bounding box of a set of 2-D points.

    Args:
        pts: Non-empty sequence of ``(x, y)`` points.

    Returns:
        Four corners in the order ``[min_x, min_y]``, ``[max_x, min_y]``,
        ``[max_x, max_y]``, ``[min_x, max_y]``.
    """
    lower = np.min(pts, axis=0)
    upper = np.max(pts, axis=0)
    left, top = lower
    right, bottom = upper
    return [[left, top], [right, top], [right, bottom], [left, bottom]]
class RotateAndTransform:
    """SIFT feature extraction plus brute-force and FLANN matching helpers.

    The instance loads a reference image on construction and keeps a SIFT
    detector, a brute-force matcher and a FLANN (KD-tree) matcher around for
    reuse.
    """

    def __init__(self, path_img_ref):
        """Load the reference image and create the detector and matchers.

        Args:
            path_img_ref: Path of the reference image.

        Raises:
            OSError: If the reference image cannot be read.
        """
        self.path_img_ref = path_img_ref
        self.ref_img = self._read_ref_image()
        self.sift = cv2.SIFT_create()
        self.bf = cv2.BFMatcher()
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

    @staticmethod
    def _imread_rgb(path):
        """Read ``path`` as a colour image and convert it from BGR to RGB."""
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # Fail here with the offending path instead of letting cvtColor
            # raise an opaque error on a missing image.
            raise OSError(f"Could not read image: {path!r}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _ratio_filter(knn_matches, threshold):
        """Keep best matches whose distance beats ``threshold`` x the runner-up."""
        kept = []
        for pair in knn_matches:
            # A k=2 query can come back with fewer than two neighbours; such
            # entries cannot be ratio-tested, so they are skipped.
            if len(pair) < 2:
                continue
            best, second = pair[0], pair[1]
            if best.distance < threshold * second.distance:
                kept.append([best])
        return kept

    def _read_ref_image(self):
        """Return the image at ``self.path_img_ref`` in RGB channel order.

        Raises:
            OSError: If the image cannot be read.
        """
        return self._imread_rgb(self.path_img_ref)

    def read_src_image(self, path_img_src):
        """Remember ``path_img_src`` and return that image in RGB order.

        Raises:
            OSError: If the image cannot be read.
        """
        self.path_img_src = path_img_src
        return self._imread_rgb(path_img_src)

    def convert_bw(self, img):
        """Return the grayscale version of an RGB image."""
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    def get_keypoints_descriptors(self, img_bw):
        """Run SIFT on ``img_bw`` and return ``(keypoints, descriptors)``."""
        keypoints, descriptors = self.sift.detectAndCompute(img_bw, None)
        return keypoints, descriptors

    def get_matches(self, src_descriptors, ref_descriptors, threshold=0.6):
        """Match reference descriptors against source descriptors.

        Both matchers are queried with the reference descriptors as the
        query set and the source descriptors as the train set (``k=2``),
        then filtered with a ratio test. The raw and filtered counts are
        printed for each matcher.

        Args:
            src_descriptors: Descriptors of the source image (train set).
            ref_descriptors: Descriptors of the reference image (query set).
            threshold: Ratio-test factor; a best match is kept when its
                distance is below ``threshold`` times the second-best one.

        Returns:
            ``(good_matches, good_flann_matches)``: two lists of one-element
            lists ``[match]``, for the brute-force and FLANN matcher.
        """
        if src_descriptors is None or ref_descriptors is None:
            # Presumably one image produced no descriptors at all; treat
            # that as "nothing matched" rather than calling the matchers.
            matches, flann_matches = [], []
        else:
            matches = self.bf.knnMatch(ref_descriptors, src_descriptors, k=2)
            flann_matches = self.flann.knnMatch(ref_descriptors, src_descriptors, k=2)

        good_matches = self._ratio_filter(matches, threshold)
        print(f'Numner of BF Match: {len(matches)}, Number of good BF Match: {len(good_matches)}')
        good_flann_matches = self._ratio_filter(flann_matches, threshold)
        print(f'Numner of FLANN Match: {len(flann_matches)}, Number of good Flann Match: {len(good_flann_matches)}')
        return good_matches, good_flann_matches

    def get_src_dst_pts(self, good_flann_matches, ref_keypoints, src_keypoints):
        """Collect the matched point coordinates from both images.

        Args:
            good_flann_matches: List of one-element lists ``[match]`` as
                returned by ``get_matches``.
            ref_keypoints: Keypoints indexed by ``match.queryIdx``.
            src_keypoints: Keypoints indexed by ``match.trainIdx``.

        Returns:
            ``(pts_src, pts_ref)`` as NumPy arrays with one row per match.
        """
        pts_src = []
        pts_ref = []
        for wrapped in good_flann_matches:
            match = wrapped[0]
            pts_src.append(src_keypoints[match.trainIdx].pt)
            pts_ref.append(ref_keypoints[match.queryIdx].pt)
        return np.array(pts_src), np.array(pts_ref)
def extend_bbox(bbox, increment=0.1):
    """Return a copy of ``bbox`` with every corner pushed outwards.

    Each coordinate is moved by ``int(coordinate * increment)``, i.e. the
    margin is proportional to the coordinate value itself, not to the box
    size. The input box is not modified.

    Args:
        bbox: Four corners ordered top-left, top-right, bottom-right,
            bottom-left, each an ``[x, y]`` pair.
        increment: Fraction of each coordinate used as its margin.

    Returns:
        A new box with the four corners replaced by the grown ones.
    """
    # Direction in which each corner moves along x and y.
    directions = ((-1, -1), (1, -1), (1, 1), (-1, 1))
    grown = bbox.copy()
    for idx, (step_x, step_y) in enumerate(directions):
        corner_x = grown[idx][0]
        corner_y = grown[idx][1]
        grown[idx] = [
            corner_x + step_x * int(corner_x * increment),
            corner_y + step_y * int(corner_y * increment),
        ]
    return grown
def crop_bbox(img, bbox):
    """Cut the region described by ``bbox`` out of a 3-D image array.

    Args:
        img: Image indexed as ``img[row, column, channel]``.
        bbox: Corners ordered top-left, top-right, bottom-right, ...; each
            corner is an ``[x, y]`` pair (x = column, y = row).

    Returns:
        The slice of ``img`` covering the box, all channels included.
    """
    left, top = bbox[0]
    width = bbox[1][0] - bbox[0][0]
    height = bbox[2][1] - bbox[0][1]
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    return img[rows, cols, :]
# --- Pipeline: match features, cluster matched points, verify each cluster ---
# NOTE(review): the clustering below uses reference-image points (queryIdx)
# and the crops are taken from the reference image and matched back against
# the base descriptors. That implies path_img_ref is the full scene and
# path_img_base the single product -- confirm against the description above.

rnt = RotateAndTransform(path_img_ref)
# Reuse the RGB image loaded by the helper so that cropping, grayscale
# conversion (RGB2GRAY) and plotting all see the same channel order; a raw
# cv2.imread result would be BGR.
ref_img = rnt.ref_img
ref_img_bw = rnt.convert_bw(img=ref_img)
ref_keypoints, ref_descriptors = rnt.get_keypoints_descriptors(ref_img_bw)

base_img = rnt.read_src_image(path_img_src=path_img_base)
base_img_bw = rnt.convert_bw(img=base_img)
base_keypoints, base_descriptors = rnt.get_keypoints_descriptors(base_img_bw)

good_matches, good_flann_matches = rnt.get_matches(src_descriptors=base_descriptors, ref_descriptors=ref_descriptors, threshold=0.6)

# Integer pixel coordinates of every matched keypoint in the reference image.
# Tuples are required because the clustering de-duplicates points with set().
ref_points = []
for gm in good_flann_matches:
    x, y = ref_keypoints[gm[0].queryIdx].pt
    ref_points.append((int(x), int(y)))

# Distance limit (pixels) read by check_centroid while clustering.
max_distance = 50

# Start from no clusters: the first point then opens the first cluster, which
# gives the same result as pre-seeding with ref_points[0] but does not raise
# when there are no matches at all.
centroids = []
for pt in tqdm(ref_points):
    centroids = update_centroid(pt, centroids)

bbox = [get_bbox(c) for c in centroids]
centroids = [np.mean(c, axis=0) for c in centroids]
print(f'Number of Points: {len(good_flann_matches)}, centroids: {len(centroids)}')

# Re-match every cluster crop against the base descriptors.
data = []
for cluster_bbox in bbox:
    temp_crop_img = crop_bbox(ref_img, extend_bbox(cluster_bbox, 0.01))
    if temp_crop_img.size == 0:
        # Degenerate cluster (zero width or height): nothing to analyse.
        continue
    temp_crop_img_bw = rnt.convert_bw(img=temp_crop_img)
    temp_crop_keypoints, temp_crop_descriptors = rnt.get_keypoints_descriptors(temp_crop_img_bw)
    if temp_crop_descriptors is None:
        # Presumably no keypoints were found in this crop; it cannot match.
        continue
    good_matches, good_flann_matches = rnt.get_matches(src_descriptors=base_descriptors, ref_descriptors=temp_crop_descriptors, threshold=0.6)
    temp_data = {'image': temp_crop_img,
                 'num_matched': len(good_flann_matches),
                 'total_keypoints': len(base_keypoints),
                 }
    data.append(temp_data)

# Keep only crops with more than this many good FLANN matches.
# NOTE(review): the description mentions "10% of keypoints"; the code uses a
# fixed count -- confirm which one is intended.
min_num_matched = 25
filter_data = [{'num_matched': d['num_matched'], 'image': d['image']} for d in data if d['num_matched'] > min_num_matched]

for item in filter_data:
    temp_num_match = item['num_matched']
    plt.figure()
    plt.title(f'num matched: {temp_num_match}')
    plt.imshow(item['image'])
![Results of the method](https://istack.dev59.com/zc2d4.webp)