mower-ng/mower/solvers/rogue/detect_node.py
Elaina 8eee4520bf
All checks were successful
ci/woodpecker/push/check_format Pipeline was successful
节点相关方法补充优化
2024-12-19 17:42:11 +08:00

166 lines
6.1 KiB
Python

import cv2
import numpy as np
from mower.solvers.rogue import data
from mower.utils import config
from mower.utils import typealias as tp
from mower.utils.image import cropimg, loadres, thres2
from mower.utils.vector import in_scope, sa, vs
from .utils import template
class NodeDetector:
    """Detect rogue-map nodes and the links between them in ``data.scene_image``.

    Constructing an instance runs the whole pipeline. Results are written into
    the shared ``data`` module:
    - ``data.nodes``: node id -> ``data.Node``;
    - each node's ``next_nodes`` list;
    - ``data.next_step``.
    """

    def __init__(self):
        # (left_point, right_point) pairs: the two ends of each horizontal road.
        self.edges = []
        # Node candidate boxes as ((x1, y1), (x2, y2)), top-left corner first.
        self.rectangles = []
        self.detect_roads_and_rectangles()
        self.detect_nodes()
        self.detect_x_edges()
        self.detect_y_edges()
        self.check_next_step()

    def merge_rectangles(self, rects: list):
        """Merge overlapping rectangles into their bounding boxes.

        Rectangles are ``((x1, y1), (x2, y2))`` with the top-left corner first.
        Boxes that merely touch count as overlapping. A merged box can grow
        into boxes that overlapped neither of its parts, so each merged box is
        re-checked. The result therefore has no two overlapping rectangles.

        Args:
            rects: rectangles to merge. The list is not modified.

        Returns:
            A new list of pairwise non-overlapping rectangles.
        """
        # Reversed stack, so pop() hands out rectangles in input order.
        pending = list(reversed(rects))
        merged = []
        while pending:
            r1 = pending.pop()
            for i, r2 in enumerate(merged):
                if (
                    r1[0][0] <= r2[1][0]
                    and r1[1][0] >= r2[0][0]  # x ranges overlap
                    and r1[0][1] <= r2[1][1]
                    and r1[1][1] >= r2[0][1]  # y ranges overlap
                ):
                    del merged[i]
                    # Re-queue the union: it may now overlap other merged boxes.
                    # Each merge shrinks pending + merged by one, so this ends.
                    pending.append(
                        (
                            (min(r1[0][0], r2[0][0]), min(r1[0][1], r2[0][1])),
                            (max(r1[1][0], r2[1][0]), max(r1[1][1], r2[1][1])),
                        )
                    )
                    break
            else:
                merged.append(r1)
        return merged

    def detect_roads_and_rectangles(self):
        """Find horizontal roads and place a node candidate box at each end.

        Steps:
        1. Convert the scene to grayscale and pass it through
           ``thres2(img, 150)``.
        2. Treat every external contour wider than 250 px as a road. Store the
           leftmost and rightmost corners of its minimum-area rectangle in
           ``self.edges``.
        3. Place a 350x100 box just left of the left end and just right of the
           right end, each vertically centred on that end.
        4. Merge all boxes into ``self.rectangles``.
        """
        img = cv2.cvtColor(data.scene_image, cv2.COLOR_RGB2GRAY)
        img = thres2(img, 150)
        contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        rect_width = 350
        rect_height = 100
        half_height = rect_height // 2
        rectangles = []
        for c in contours:
            _, _, w, _ = cv2.boundingRect(c)
            if w <= 250:
                continue
            # Corners of the minimum-area (possibly rotated) bounding rectangle.
            # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased.
            box = cv2.boxPoints(cv2.minAreaRect(c)).astype(np.intp)
            # The two road ends: the corners with the smallest and largest x.
            left_point = tuple(box[np.argmin(box[:, 0])])
            right_point = tuple(box[np.argmax(box[:, 0])])
            self.edges.append((left_point, right_point))
            # Box ending at the left end of the road.
            rectangles.append(
                (
                    (left_point[0] - rect_width, left_point[1] - half_height),
                    (left_point[0], left_point[1] + half_height),
                )
            )
            # Box starting at the right end of the road.
            rectangles.append(
                (
                    (right_point[0], right_point[1] - half_height),
                    (right_point[0] + rect_width, right_point[1] + half_height),
                )
            )
        self.rectangles = self.merge_rectangles(rectangles)

    def detect_nodes(self):
        """Group ``self.rectangles`` into columns and register them as nodes.

        Rectangles are sorted in place by left x. A rectangle joins the current
        column when its left x is within 100 px of the previous rectangle's
        left x. Within a column, rectangles are ordered top to bottom. Each one
        is stored in ``data.nodes`` under ``f"{column}{row}"`` (both 1-based),
        e.g. ``"12"`` for column 1, row 2.
        """
        self.rectangles.sort(key=lambda r: r[0][0])  # by top-left x
        columns: list[list[tp.Scope]] = []
        current_column: list[tp.Scope] = []
        prev_x = None
        col_tolerance = 100  # left-x gap (px) still considered the same column
        # Step 1: split the x-sorted rectangles into columns.
        for rect in self.rectangles:
            top_left = rect[0]
            # Compared with the previous rectangle, not the column's first one.
            if prev_x is None or abs(top_left[0] - prev_x) <= col_tolerance:
                current_column.append(rect)
            else:
                columns.append(current_column)
                current_column = [rect]
            prev_x = top_left[0]
        if current_column:
            columns.append(current_column)
        # Step 2: order each column by top-left y and assign node ids.
        # NOTE(review): ids are plain concatenations, so 10+ columns or rows
        # would produce ambiguous ids (e.g. "111").
        for col_idx, column in enumerate(columns):
            column.sort(key=lambda r: r[0][1])  # by top-left y
            for row_idx, rect in enumerate(column):
                node_id = f"{col_idx+1}{row_idx+1}"
                data.nodes[node_id] = data.Node(rect)

    def detect_x_edges(self):
        """Link the nodes at the two ends of each horizontal road.

        For each road in ``self.edges``:
        - the source is the node whose ``scope`` contains the left end;
        - the target is the node whose ``scope`` contains the right end;
        - the target id is appended to the source's ``next_nodes``.

        Roads whose ends do not both fall inside a node are skipped.
        """
        for left_point, right_point in self.edges:
            st = ed = ""
            for node_id, node in data.nodes.items():
                if in_scope(node.scope, left_point):
                    st = node_id
                if in_scope(node.scope, right_point):
                    ed = node_id
                if st and ed:
                    break
            # An end can miss every node (depending on how in_scope treats
            # box borders). There is nothing to link then. Skip the road
            # instead of indexing data.nodes[""] or appending an empty id.
            if not (st and ed):
                continue
            data.nodes[st].next_nodes.append(ed)

    def detect_y_edges(self):
        """Link vertically adjacent nodes of the same column.

        Node ids are ``f"{column}{row}"``, so the node below ``node_id`` is
        ``str(int(node_id) + 1)``. This assumes single-digit row numbers.

        The region between the upper node's ``type_scope`` and the lower
        node's ``type_scope`` is matched against the ``"rogue/y_edge"``
        template. A score above 0.7 links the two nodes in both directions.
        """
        res = loadres("rogue/y_edge")  # loop-invariant: load once
        for node_id, node in data.nodes.items():
            next_id = str(int(node_id) + 1)
            if next_id not in data.nodes:
                continue
            next_node = data.nodes[next_id]
            # Region between the two type boxes:
            # - top-left: left x / bottom y of the upper node's type box;
            # - bottom-right: right x / top y of the lower node's type box.
            scope = (
                (node.type_scope[0][0], node.type_scope[1][1]),
                (next_node.type_scope[1][0], next_node.type_scope[0][1]),
            )
            score, _ = template(data.scene_image, res, scope)
            if score > 0.7:
                node.next_nodes.append(next_id)
                next_node.next_nodes.append(node_id)

    def check_next_step(self):
        """Set ``data.next_step[node_id] = 2`` for every node with a truthy ``status``."""
        # NOTE(review): the meaning of the value 2 is defined elsewhere — confirm.
        for node_id, node in data.nodes.items():
            if node.status:
                data.next_step[node_id] = 2

    def check_node_in_screen(self, node_id: str) -> tp.Scope | None:
        """Return the screenshot scope of ``node_id`` if it is fully visible.

        The band y=150..880 (x=0..1920) of ``config.recog.img`` is located
        inside ``data.scene_image`` with ``template``. The node's scene scope
        is then translated into screenshot coordinates.

        Args:
            node_id: key into ``data.nodes``.

        Returns:
            ``((x1, y1), (x2, y2))`` in screenshot coordinates when both
            corners of the node's scope lie inside the matched region;
            otherwise ``None``.
        """
        # NOTE(review): the match score is discarded, so even a poor match is
        # trusted as the visible region. Consider a threshold.
        _, scope = template(
            data.scene_image, cropimg(config.recog.img, ((0, 150), (1920, 880)))
        )
        top_left, bottom_right = data.nodes[node_id].scope
        if not (in_scope(scope, top_left) and in_scope(scope, bottom_right)):
            return None
        # Scene coords -> crop coords (relative to the matched origin), then
        # crop -> screenshot coords (the crop starts at y=150). This assumes
        # vs(a, b) subtracts vectors and sa(scope, v) offsets a scope by v.
        res_scope = vs(top_left, scope[0]), vs(bottom_right, scope[0])
        return sa(res_scope, (0, 150))