Compare commits
2 commits
482e8cc776
...
eb5b123cb1
Author | SHA1 | Date | |
---|---|---|---|
eb5b123cb1 | |||
b80a44cedd |
7 changed files with 149 additions and 68 deletions
|
@ -24,17 +24,6 @@ def 提取特征点(模板):
|
||||||
return hog_features
|
return hog_features
|
||||||
|
|
||||||
|
|
||||||
def 训练knn模型(images, labels):
|
|
||||||
knn_classifier = KNeighborsClassifier(weights="distance", n_neighbors=1, n_jobs=1)
|
|
||||||
knn_classifier.fit(images, labels)
|
|
||||||
return knn_classifier
|
|
||||||
|
|
||||||
|
|
||||||
def 保存knn模型(classifier, filename):
|
|
||||||
with lzma.open(filename, "wb") as f:
|
|
||||||
pickle.dump(classifier, f)
|
|
||||||
|
|
||||||
|
|
||||||
class DepotMatcher:
|
class DepotMatcher:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -154,6 +143,73 @@ class DepotMatcher:
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def 训练并保存knn模型(self, images, labels, filename):
|
||||||
|
"""训练并保存KNN模型"""
|
||||||
|
knn_classifier = KNeighborsClassifier(
|
||||||
|
weights="distance", n_neighbors=1, n_jobs=1
|
||||||
|
)
|
||||||
|
knn_classifier.fit(images, labels)
|
||||||
|
|
||||||
|
with lzma.open(filename, "wb") as f:
|
||||||
|
pickle.dump(knn_classifier, f)
|
||||||
|
|
||||||
|
return knn_classifier
|
||||||
|
|
||||||
|
def 训练knn模型(self, 模型保存路径="depot_knn_model.xz"):
|
||||||
|
"""训练并保存KNN模型"""
|
||||||
|
if not self.refs:
|
||||||
|
print("错误:请先加载参考图像!")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 准备训练数据
|
||||||
|
images = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
for name, img_array in self.refs.items():
|
||||||
|
# 提取HOG特征
|
||||||
|
features = 提取特征点(img_array)
|
||||||
|
images.append(features)
|
||||||
|
labels.append(name)
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
self.knn_model = self.训练并保存knn模型(images, labels, 模型保存路径)
|
||||||
|
print(f"KNN模型训练完成,已保存到: {模型保存路径}")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def 使用knn预测(self, 测试图像):
|
||||||
|
"""使用训练好的KNN模型进行预测"""
|
||||||
|
if not hasattr(self, "knn_model"):
|
||||||
|
print("错误:请先训练KNN模型!")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# 提取测试图像的特征
|
||||||
|
features = 提取特征点(测试图像)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
预测结果 = self.knn_model.predict([features])
|
||||||
|
置信度 = self.knn_model.predict_proba([features])
|
||||||
|
|
||||||
|
return 预测结果[0], 置信度[0]
|
||||||
|
|
||||||
|
def match_knn(self):
|
||||||
|
"""KNN匹配方法"""
|
||||||
|
if not hasattr(self, "knn_model"):
|
||||||
|
print("错误:请先训练KNN模型!")
|
||||||
|
return self
|
||||||
|
|
||||||
|
self.knn_results = []
|
||||||
|
|
||||||
|
for idx, gray_roi_np, _ in self.rois:
|
||||||
|
# 使用KNN进行预测
|
||||||
|
预测名称, 置信度 = self.使用knn预测(gray_roi_np)
|
||||||
|
|
||||||
|
# 获取最高置信度的值
|
||||||
|
max_conf = np.max(置信度) if 置信度 is not None else 0
|
||||||
|
|
||||||
|
self.knn_results.append((idx, 预测名称, max_conf))
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def match_template(self, threshold=0.2):
|
def match_template(self, threshold=0.2):
|
||||||
"""模板匹配方法"""
|
"""模板匹配方法"""
|
||||||
self.template_results = []
|
self.template_results = []
|
||||||
|
@ -257,7 +313,7 @@ class DepotMatcher:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def display_results(self):
|
def display_results(self):
|
||||||
"""可视化匹配结果,直接使用内存中的图像数据"""
|
"""可视化匹配结果,现在包含KNN结果"""
|
||||||
ROW_LIMIT = 9
|
ROW_LIMIT = 9
|
||||||
|
|
||||||
# 获取一个参考图像的尺寸作为空白图像的基础
|
# 获取一个参考图像的尺寸作为空白图像的基础
|
||||||
|
@ -271,14 +327,13 @@ class DepotMatcher:
|
||||||
current_row_width = 0
|
current_row_width = 0
|
||||||
max_row_height = 0
|
max_row_height = 0
|
||||||
|
|
||||||
# self.rois 现在包含 (idx, gray_roi_data, color_roi_data)
|
|
||||||
for idx, gray_roi_data, color_roi_data in self.rois:
|
for idx, gray_roi_data, color_roi_data in self.rois:
|
||||||
color_roi_img = Image.fromarray(
|
color_roi_img = Image.fromarray(
|
||||||
cv2.cvtColor(color_roi_data, cv2.COLOR_BGR2RGB)
|
cv2.cvtColor(color_roi_data, cv2.COLOR_BGR2RGB)
|
||||||
)
|
)
|
||||||
|
|
||||||
gray_roi_img = Image.fromarray(gray_roi_data).convert("L").convert("RGB")
|
gray_roi_img = Image.fromarray(gray_roi_data).convert("L").convert("RGB")
|
||||||
|
|
||||||
|
# 获取模板匹配结果
|
||||||
t_res_name = next(
|
t_res_name = next(
|
||||||
(name for i, name, val in self.template_results if i == idx), None
|
(name for i, name, val in self.template_results if i == idx), None
|
||||||
)
|
)
|
||||||
|
@ -291,6 +346,7 @@ class DepotMatcher:
|
||||||
else blank_img_pil.copy()
|
else blank_img_pil.copy()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 获取SSIM匹配结果
|
||||||
s_res_details = next((d for d in self.ssim_results if d[0] == idx), None)
|
s_res_details = next((d for d in self.ssim_results if d[0] == idx), None)
|
||||||
s_res_name = s_res_details[1] if s_res_details else None
|
s_res_name = s_res_details[1] if s_res_details else None
|
||||||
s_ref_img = (
|
s_ref_img = (
|
||||||
|
@ -299,18 +355,32 @@ class DepotMatcher:
|
||||||
else blank_img_pil.copy()
|
else blank_img_pil.copy()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 获取KNN匹配结果
|
||||||
|
k_res_details = next(
|
||||||
|
(d for d in getattr(self, "knn_results", []) if d[0] == idx), None
|
||||||
|
)
|
||||||
|
k_res_name = k_res_details[1] if k_res_details else None
|
||||||
|
k_ref_img = (
|
||||||
|
Image.fromarray(self.refs[k_res_name]).convert("RGB")
|
||||||
|
if k_res_name and k_res_name in self.refs
|
||||||
|
else blank_img_pil.copy()
|
||||||
|
)
|
||||||
|
k_res_val = k_res_details[2] if k_res_details else 0
|
||||||
|
|
||||||
# 计算组合尺寸
|
# 计算组合尺寸
|
||||||
combined_width = (
|
combined_width = (
|
||||||
color_roi_img.width
|
color_roi_img.width
|
||||||
+ gray_roi_img.width
|
+ gray_roi_img.width
|
||||||
+ t_ref_img.width
|
+ t_ref_img.width
|
||||||
+ s_ref_img.width
|
+ s_ref_img.width
|
||||||
|
+ k_ref_img.width
|
||||||
)
|
)
|
||||||
combined_height = max(
|
combined_height = max(
|
||||||
color_roi_img.height,
|
color_roi_img.height,
|
||||||
gray_roi_img.height,
|
gray_roi_img.height,
|
||||||
t_ref_img.height,
|
t_ref_img.height,
|
||||||
s_ref_img.height,
|
s_ref_img.height,
|
||||||
|
k_ref_img.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建组合图像
|
# 创建组合图像
|
||||||
|
@ -330,24 +400,26 @@ class DepotMatcher:
|
||||||
x_offset += t_ref_img.width
|
x_offset += t_ref_img.width
|
||||||
|
|
||||||
combined.paste(s_ref_img, (x_offset, 0))
|
combined.paste(s_ref_img, (x_offset, 0))
|
||||||
|
x_offset += s_ref_img.width
|
||||||
|
|
||||||
|
combined.paste(k_ref_img, (x_offset, 0))
|
||||||
|
x_offset += k_ref_img.width
|
||||||
|
|
||||||
# 添加标注
|
# 添加标注
|
||||||
draw = ImageDraw.Draw(combined)
|
draw = ImageDraw.Draw(combined)
|
||||||
|
|
||||||
font = ImageFont.truetype("msyh.ttc", 16)
|
font = ImageFont.truetype("msyh.ttc", 16)
|
||||||
|
|
||||||
label = (
|
label = (
|
||||||
f"ROI {idx} SSIM/H/E:{s_res_details[3]:.2f}/{s_res_details[4]:.2f}/{s_res_details[5]:.2f}\n"
|
f"ROI {idx}\n"
|
||||||
f"T: {t_res_name or 'None'} ({t_res_val:.2f})\n"
|
f"T: {t_res_name or 'None'} ({t_res_val:.2f})\n"
|
||||||
f"S: {s_res_details[1] or 'None'} ({s_res_details[2]:.2f})\n"
|
f"S: {s_res_name or 'None'} ({s_res_details[2]:.2f})\n"
|
||||||
)
|
f"K: {k_res_name or 'None'} ({k_res_val:.2f})"
|
||||||
# 文本颜色:如果模板匹配和SSIM匹配结果一致,则为红,否则为品红
|
|
||||||
text_color = (
|
|
||||||
(255, 0, 0) if t_res_name == s_res_details[1] else (255, 0, 255)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调整标签位置到模板匹配参考图的左上角
|
# 文本颜色:如果三种方法结果一致则为红色,否则为黑色
|
||||||
# (起始 X 坐标是彩色ROI和灰度ROI的宽度之和)
|
text_color = (
|
||||||
|
(255, 0, 0) if t_res_name == s_res_name == k_res_name else (0, 0, 0)
|
||||||
|
)
|
||||||
|
|
||||||
draw.text(
|
draw.text(
|
||||||
(color_roi_img.width, gray_roi_img.height),
|
(color_roi_img.width, gray_roi_img.height),
|
||||||
|
@ -372,12 +444,11 @@ class DepotMatcher:
|
||||||
row_img.paste(img, (x, 0))
|
row_img.paste(img, (x, 0))
|
||||||
x += img.width
|
x += img.width
|
||||||
combined_images.append(row_img)
|
combined_images.append(row_img)
|
||||||
# 重置行变量
|
|
||||||
current_row_images = []
|
current_row_images = []
|
||||||
current_row_width = 0
|
current_row_width = 0
|
||||||
max_row_height = 0
|
max_row_height = 0
|
||||||
|
|
||||||
# 处理最后一行(如果未满 ROW_LIMIT)
|
# 处理最后一行
|
||||||
if current_row_images:
|
if current_row_images:
|
||||||
row_img = Image.new(
|
row_img = Image.new(
|
||||||
"RGB", (current_row_width, max_row_height), (255, 255, 255)
|
"RGB", (current_row_width, max_row_height), (255, 255, 255)
|
||||||
|
@ -388,44 +459,35 @@ class DepotMatcher:
|
||||||
x += img.width
|
x += img.width
|
||||||
combined_images.append(row_img)
|
combined_images.append(row_img)
|
||||||
|
|
||||||
# 如果没有生成任何图像行,则退出
|
|
||||||
if not combined_images:
|
|
||||||
print("没有生成任何结果图像行。")
|
|
||||||
return self
|
|
||||||
|
|
||||||
# 生成最终图像
|
# 生成最终图像
|
||||||
|
if combined_images:
|
||||||
total_height = sum(img.height for img in combined_images)
|
total_height = sum(img.height for img in combined_images)
|
||||||
# 使用第一行的宽度作为所有行的最大宽度(或计算所有行的最大宽度)
|
max_width = max(img.width for img in combined_images)
|
||||||
max_width = combined_images[0].width if combined_images else 0
|
final_img = Image.new("RGB", (max_width, total_height), (255, 255, 255))
|
||||||
# 如果行宽不一致,可能需要调整逻辑或使用最大宽度
|
|
||||||
final_img = Image.new(
|
|
||||||
"RGB",
|
|
||||||
(max_width, total_height),
|
|
||||||
(255, 255, 255),
|
|
||||||
)
|
|
||||||
|
|
||||||
y = 0
|
y = 0
|
||||||
for img in combined_images:
|
for img in combined_images:
|
||||||
# 确保粘贴时宽度不超过 final_img 的宽度
|
|
||||||
final_img.paste(img, (0, y))
|
final_img.paste(img, (0, y))
|
||||||
y += img.height
|
y += img.height
|
||||||
|
|
||||||
# 保存最终结果图像
|
|
||||||
output_path = "depot_test/output/matches_all.png"
|
output_path = "depot_test/output/matches_all.png"
|
||||||
os.makedirs(os.path.dirname(output_path), exist_ok=True) # 确保目录存在
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
final_img.save(output_path)
|
final_img.save(output_path)
|
||||||
print(f"结果图像已保存至: {output_path}")
|
print(f"结果图像已保存至: {output_path}")
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 使用示例
|
# 使用示例
|
||||||
matcher = DepotMatcher()
|
matcher = DepotMatcher()
|
||||||
|
|
||||||
(
|
(
|
||||||
matcher.load_references()
|
matcher.load_references()
|
||||||
.detect_and_crop()
|
.detect_and_crop()
|
||||||
.match_template(threshold=0.2)
|
.match_template(threshold=0.2)
|
||||||
.match_ssim(threshold=0.05)
|
.match_ssim(threshold=0.05)
|
||||||
|
.match_knn()
|
||||||
.display_results()
|
.display_results()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -150,9 +150,13 @@ def match_ssim(rois, refs, thresh):
|
||||||
dynamic_thresh = thresh * (1 + 0.3 * roi_complexity)
|
dynamic_thresh = thresh * (1 + 0.3 * roi_complexity)
|
||||||
|
|
||||||
if max_combined_score >= dynamic_thresh:
|
if max_combined_score >= dynamic_thresh:
|
||||||
results.append((idx, best_match, max_combined_score, best_ssim, best_hist, best_edge))
|
results.append(
|
||||||
|
(idx, best_match, max_combined_score, best_ssim, best_hist, best_edge)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
results.append((idx, None, max_combined_score, best_ssim, best_hist, best_edge))
|
results.append(
|
||||||
|
(idx, None, max_combined_score, best_ssim, best_hist, best_edge)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -281,11 +285,21 @@ if __name__ == "__main__":
|
||||||
t_val = next((val for i, name, val in results_template if i == idx), 0)
|
t_val = next((val for i, name, val in results_template if i == idx), 0)
|
||||||
|
|
||||||
# SSIM results - now includes detailed metrics
|
# SSIM results - now includes detailed metrics
|
||||||
s_res = next((name for i, name, val, ssim, hist, edge in results_ssim if i == idx), None)
|
s_res = next(
|
||||||
s_val = next((val for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
(name for i, name, val, ssim, hist, edge in results_ssim if i == idx), None
|
||||||
s_ssim = next((ssim for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
)
|
||||||
s_hist = next((hist for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
s_val = next(
|
||||||
s_edge = next((edge for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
(val for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||||
|
)
|
||||||
|
s_ssim = next(
|
||||||
|
(ssim for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||||
|
)
|
||||||
|
s_hist = next(
|
||||||
|
(hist for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||||
|
)
|
||||||
|
s_edge = next(
|
||||||
|
(edge for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||||
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"ROI {idx}:\n"
|
f"ROI {idx}:\n"
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime
|
|
||||||
from skimage.metrics import structural_similarity
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import json
|
from skimage.metrics import structural_similarity
|
||||||
from multiprocessing import Pool
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
# 配置路径
|
# 配置路径
|
||||||
REF_DIR = r"depot_test\output/test/origin"
|
REF_DIR = r"depot_test\output/test/origin"
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import cv2
|
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
def stitch_panorama(input_dir, output_dir):
|
def stitch_panorama(input_dir, output_dir):
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
image_paths = sorted(glob.glob(os.path.join(input_dir, "*.*")))
|
image_paths = sorted(glob.glob(os.path.join(input_dir, "*.*")))
|
||||||
|
@ -26,6 +28,7 @@ def stitch_panorama(input_dir, output_dir):
|
||||||
else:
|
else:
|
||||||
print(f"拼接失败,错误码:{status}")
|
print(f"拼接失败,错误码:{status}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
start = datetime.now()
|
start = datetime.now()
|
||||||
stitch_panorama(r"depot_test\Test", r"depot_test\output")
|
stitch_panorama(r"depot_test\Test", r"depot_test\output")
|
||||||
|
|
|
@ -56,7 +56,7 @@ class GetOrderRemainingTimeSolver(BaseSolver, BaseMixin):
|
||||||
elif pos := self.find("bill_accelerate"):
|
elif pos := self.find("bill_accelerate"):
|
||||||
scope = (70, -203), (194, -170)
|
scope = (70, -203), (194, -170)
|
||||||
scope = sa(scope, pos[0])
|
scope = sa(scope, pos[0])
|
||||||
if res := self.read_remain_time(pos):
|
if res := self.read_remain_time(scope):
|
||||||
self.res = res
|
self.res = res
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue