Compare commits
2 commits
482e8cc776
...
eb5b123cb1
Author | SHA1 | Date | |
---|---|---|---|
eb5b123cb1 | |||
b80a44cedd |
7 changed files with 149 additions and 68 deletions
|
@ -24,17 +24,6 @@ def 提取特征点(模板):
|
|||
return hog_features
|
||||
|
||||
|
||||
def 训练knn模型(images, labels):
|
||||
knn_classifier = KNeighborsClassifier(weights="distance", n_neighbors=1, n_jobs=1)
|
||||
knn_classifier.fit(images, labels)
|
||||
return knn_classifier
|
||||
|
||||
|
||||
def 保存knn模型(classifier, filename):
|
||||
with lzma.open(filename, "wb") as f:
|
||||
pickle.dump(classifier, f)
|
||||
|
||||
|
||||
class DepotMatcher:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -154,6 +143,73 @@ class DepotMatcher:
|
|||
|
||||
return self
|
||||
|
||||
def 训练并保存knn模型(self, images, labels, filename):
|
||||
"""训练并保存KNN模型"""
|
||||
knn_classifier = KNeighborsClassifier(
|
||||
weights="distance", n_neighbors=1, n_jobs=1
|
||||
)
|
||||
knn_classifier.fit(images, labels)
|
||||
|
||||
with lzma.open(filename, "wb") as f:
|
||||
pickle.dump(knn_classifier, f)
|
||||
|
||||
return knn_classifier
|
||||
|
||||
def 训练knn模型(self, 模型保存路径="depot_knn_model.xz"):
|
||||
"""训练并保存KNN模型"""
|
||||
if not self.refs:
|
||||
print("错误:请先加载参考图像!")
|
||||
return None
|
||||
|
||||
# 准备训练数据
|
||||
images = []
|
||||
labels = []
|
||||
|
||||
for name, img_array in self.refs.items():
|
||||
# 提取HOG特征
|
||||
features = 提取特征点(img_array)
|
||||
images.append(features)
|
||||
labels.append(name)
|
||||
|
||||
# 训练模型
|
||||
self.knn_model = self.训练并保存knn模型(images, labels, 模型保存路径)
|
||||
print(f"KNN模型训练完成,已保存到: {模型保存路径}")
|
||||
return self
|
||||
|
||||
def 使用knn预测(self, 测试图像):
|
||||
"""使用训练好的KNN模型进行预测"""
|
||||
if not hasattr(self, "knn_model"):
|
||||
print("错误:请先训练KNN模型!")
|
||||
return None, None
|
||||
|
||||
# 提取测试图像的特征
|
||||
features = 提取特征点(测试图像)
|
||||
|
||||
# 预测
|
||||
预测结果 = self.knn_model.predict([features])
|
||||
置信度 = self.knn_model.predict_proba([features])
|
||||
|
||||
return 预测结果[0], 置信度[0]
|
||||
|
||||
def match_knn(self):
|
||||
"""KNN匹配方法"""
|
||||
if not hasattr(self, "knn_model"):
|
||||
print("错误:请先训练KNN模型!")
|
||||
return self
|
||||
|
||||
self.knn_results = []
|
||||
|
||||
for idx, gray_roi_np, _ in self.rois:
|
||||
# 使用KNN进行预测
|
||||
预测名称, 置信度 = self.使用knn预测(gray_roi_np)
|
||||
|
||||
# 获取最高置信度的值
|
||||
max_conf = np.max(置信度) if 置信度 is not None else 0
|
||||
|
||||
self.knn_results.append((idx, 预测名称, max_conf))
|
||||
|
||||
return self
|
||||
|
||||
def match_template(self, threshold=0.2):
|
||||
"""模板匹配方法"""
|
||||
self.template_results = []
|
||||
|
@ -257,7 +313,7 @@ class DepotMatcher:
|
|||
return self
|
||||
|
||||
def display_results(self):
|
||||
"""可视化匹配结果,直接使用内存中的图像数据"""
|
||||
"""可视化匹配结果,现在包含KNN结果"""
|
||||
ROW_LIMIT = 9
|
||||
|
||||
# 获取一个参考图像的尺寸作为空白图像的基础
|
||||
|
@ -271,14 +327,13 @@ class DepotMatcher:
|
|||
current_row_width = 0
|
||||
max_row_height = 0
|
||||
|
||||
# self.rois 现在包含 (idx, gray_roi_data, color_roi_data)
|
||||
for idx, gray_roi_data, color_roi_data in self.rois:
|
||||
color_roi_img = Image.fromarray(
|
||||
cv2.cvtColor(color_roi_data, cv2.COLOR_BGR2RGB)
|
||||
)
|
||||
|
||||
gray_roi_img = Image.fromarray(gray_roi_data).convert("L").convert("RGB")
|
||||
|
||||
# 获取模板匹配结果
|
||||
t_res_name = next(
|
||||
(name for i, name, val in self.template_results if i == idx), None
|
||||
)
|
||||
|
@ -291,6 +346,7 @@ class DepotMatcher:
|
|||
else blank_img_pil.copy()
|
||||
)
|
||||
|
||||
# 获取SSIM匹配结果
|
||||
s_res_details = next((d for d in self.ssim_results if d[0] == idx), None)
|
||||
s_res_name = s_res_details[1] if s_res_details else None
|
||||
s_ref_img = (
|
||||
|
@ -299,18 +355,32 @@ class DepotMatcher:
|
|||
else blank_img_pil.copy()
|
||||
)
|
||||
|
||||
# 获取KNN匹配结果
|
||||
k_res_details = next(
|
||||
(d for d in getattr(self, "knn_results", []) if d[0] == idx), None
|
||||
)
|
||||
k_res_name = k_res_details[1] if k_res_details else None
|
||||
k_ref_img = (
|
||||
Image.fromarray(self.refs[k_res_name]).convert("RGB")
|
||||
if k_res_name and k_res_name in self.refs
|
||||
else blank_img_pil.copy()
|
||||
)
|
||||
k_res_val = k_res_details[2] if k_res_details else 0
|
||||
|
||||
# 计算组合尺寸
|
||||
combined_width = (
|
||||
color_roi_img.width
|
||||
+ gray_roi_img.width
|
||||
+ t_ref_img.width
|
||||
+ s_ref_img.width
|
||||
+ k_ref_img.width
|
||||
)
|
||||
combined_height = max(
|
||||
color_roi_img.height,
|
||||
gray_roi_img.height,
|
||||
t_ref_img.height,
|
||||
s_ref_img.height,
|
||||
k_ref_img.height,
|
||||
)
|
||||
|
||||
# 创建组合图像
|
||||
|
@ -330,24 +400,26 @@ class DepotMatcher:
|
|||
x_offset += t_ref_img.width
|
||||
|
||||
combined.paste(s_ref_img, (x_offset, 0))
|
||||
x_offset += s_ref_img.width
|
||||
|
||||
combined.paste(k_ref_img, (x_offset, 0))
|
||||
x_offset += k_ref_img.width
|
||||
|
||||
# 添加标注
|
||||
draw = ImageDraw.Draw(combined)
|
||||
|
||||
font = ImageFont.truetype("msyh.ttc", 16)
|
||||
|
||||
label = (
|
||||
f"ROI {idx} SSIM/H/E:{s_res_details[3]:.2f}/{s_res_details[4]:.2f}/{s_res_details[5]:.2f}\n"
|
||||
f"ROI {idx}\n"
|
||||
f"T: {t_res_name or 'None'} ({t_res_val:.2f})\n"
|
||||
f"S: {s_res_details[1] or 'None'} ({s_res_details[2]:.2f})\n"
|
||||
)
|
||||
# 文本颜色:如果模板匹配和SSIM匹配结果一致,则为红,否则为品红
|
||||
text_color = (
|
||||
(255, 0, 0) if t_res_name == s_res_details[1] else (255, 0, 255)
|
||||
f"S: {s_res_name or 'None'} ({s_res_details[2]:.2f})\n"
|
||||
f"K: {k_res_name or 'None'} ({k_res_val:.2f})"
|
||||
)
|
||||
|
||||
# 调整标签位置到模板匹配参考图的左上角
|
||||
# (起始 X 坐标是彩色ROI和灰度ROI的宽度之和)
|
||||
# 文本颜色:如果三种方法结果一致则为红色,否则为黑色
|
||||
text_color = (
|
||||
(255, 0, 0) if t_res_name == s_res_name == k_res_name else (0, 0, 0)
|
||||
)
|
||||
|
||||
draw.text(
|
||||
(color_roi_img.width, gray_roi_img.height),
|
||||
|
@ -372,12 +444,11 @@ class DepotMatcher:
|
|||
row_img.paste(img, (x, 0))
|
||||
x += img.width
|
||||
combined_images.append(row_img)
|
||||
# 重置行变量
|
||||
current_row_images = []
|
||||
current_row_width = 0
|
||||
max_row_height = 0
|
||||
|
||||
# 处理最后一行(如果未满 ROW_LIMIT)
|
||||
# 处理最后一行
|
||||
if current_row_images:
|
||||
row_img = Image.new(
|
||||
"RGB", (current_row_width, max_row_height), (255, 255, 255)
|
||||
|
@ -388,44 +459,35 @@ class DepotMatcher:
|
|||
x += img.width
|
||||
combined_images.append(row_img)
|
||||
|
||||
# 如果没有生成任何图像行,则退出
|
||||
if not combined_images:
|
||||
print("没有生成任何结果图像行。")
|
||||
return self
|
||||
|
||||
# 生成最终图像
|
||||
total_height = sum(img.height for img in combined_images)
|
||||
# 使用第一行的宽度作为所有行的最大宽度(或计算所有行的最大宽度)
|
||||
max_width = combined_images[0].width if combined_images else 0
|
||||
# 如果行宽不一致,可能需要调整逻辑或使用最大宽度
|
||||
final_img = Image.new(
|
||||
"RGB",
|
||||
(max_width, total_height),
|
||||
(255, 255, 255),
|
||||
)
|
||||
if combined_images:
|
||||
total_height = sum(img.height for img in combined_images)
|
||||
max_width = max(img.width for img in combined_images)
|
||||
final_img = Image.new("RGB", (max_width, total_height), (255, 255, 255))
|
||||
|
||||
y = 0
|
||||
for img in combined_images:
|
||||
# 确保粘贴时宽度不超过 final_img 的宽度
|
||||
final_img.paste(img, (0, y))
|
||||
y += img.height
|
||||
y = 0
|
||||
for img in combined_images:
|
||||
final_img.paste(img, (0, y))
|
||||
y += img.height
|
||||
|
||||
output_path = "depot_test/output/matches_all.png"
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
final_img.save(output_path)
|
||||
print(f"结果图像已保存至: {output_path}")
|
||||
|
||||
# 保存最终结果图像
|
||||
output_path = "depot_test/output/matches_all.png"
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True) # 确保目录存在
|
||||
final_img.save(output_path)
|
||||
print(f"结果图像已保存至: {output_path}")
|
||||
return self
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用示例
|
||||
matcher = DepotMatcher()
|
||||
|
||||
(
|
||||
matcher.load_references()
|
||||
.detect_and_crop()
|
||||
.match_template(threshold=0.2)
|
||||
.match_ssim(threshold=0.05)
|
||||
.match_knn()
|
||||
.display_results()
|
||||
)
|
||||
|
||||
|
|
|
@ -150,9 +150,13 @@ def match_ssim(rois, refs, thresh):
|
|||
dynamic_thresh = thresh * (1 + 0.3 * roi_complexity)
|
||||
|
||||
if max_combined_score >= dynamic_thresh:
|
||||
results.append((idx, best_match, max_combined_score, best_ssim, best_hist, best_edge))
|
||||
results.append(
|
||||
(idx, best_match, max_combined_score, best_ssim, best_hist, best_edge)
|
||||
)
|
||||
else:
|
||||
results.append((idx, None, max_combined_score, best_ssim, best_hist, best_edge))
|
||||
results.append(
|
||||
(idx, None, max_combined_score, best_ssim, best_hist, best_edge)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
@ -281,11 +285,21 @@ if __name__ == "__main__":
|
|||
t_val = next((val for i, name, val in results_template if i == idx), 0)
|
||||
|
||||
# SSIM results - now includes detailed metrics
|
||||
s_res = next((name for i, name, val, ssim, hist, edge in results_ssim if i == idx), None)
|
||||
s_val = next((val for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
||||
s_ssim = next((ssim for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
||||
s_hist = next((hist for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
||||
s_edge = next((edge for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0)
|
||||
s_res = next(
|
||||
(name for i, name, val, ssim, hist, edge in results_ssim if i == idx), None
|
||||
)
|
||||
s_val = next(
|
||||
(val for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||
)
|
||||
s_ssim = next(
|
||||
(ssim for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||
)
|
||||
s_hist = next(
|
||||
(hist for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||
)
|
||||
s_edge = next(
|
||||
(edge for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
||||
)
|
||||
|
||||
print(
|
||||
f"ROI {idx}:\n"
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from multiprocessing import Pool
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from skimage.metrics import structural_similarity
|
||||
from PIL import Image
|
||||
import json
|
||||
from multiprocessing import Pool
|
||||
now=datetime.now()
|
||||
from skimage.metrics import structural_similarity
|
||||
|
||||
now = datetime.now()
|
||||
# 配置路径
|
||||
REF_DIR = r"depot_test\output/test/origin"
|
||||
ROI_DIR = r"depot_test\output/test/result"
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import cv2
|
||||
import glob
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
def stitch_panorama(input_dir, output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_paths = sorted(glob.glob(os.path.join(input_dir, "*.*")))
|
||||
|
@ -26,6 +28,7 @@ def stitch_panorama(input_dir, output_dir):
|
|||
else:
|
||||
print(f"拼接失败,错误码:{status}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start = datetime.now()
|
||||
stitch_panorama(r"depot_test\Test", r"depot_test\output")
|
||||
|
|
|
@ -56,7 +56,7 @@ class GetOrderRemainingTimeSolver(BaseSolver, BaseMixin):
|
|||
elif pos := self.find("bill_accelerate"):
|
||||
scope = (70, -203), (194, -170)
|
||||
scope = sa(scope, pos[0])
|
||||
if res := self.read_remain_time(pos):
|
||||
if res := self.read_remain_time(scope):
|
||||
self.res = res
|
||||
return True
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue