获取最优路径方法
All checks were successful
ci/woodpecker/push/check_format Pipeline was successful

This commit is contained in:
Elaina 2024-12-21 15:02:55 +08:00
parent 0b1043b849
commit c0b92b9cc9
3 changed files with 138 additions and 2 deletions

View file

@ -51,8 +51,13 @@ class Node:
self.type = "未知"
self.type_scope = match_scope
def __repr__(self):
return (
f"Node(type={self.type},status={self.status},next_nodes={self.next_nodes})"
)
nodes: dict[str, Node] = {} # 节点编号:节点对象
next_step: dict[str, int] = {} # 节点编号:剩余刷新次数
next_step: list[str] = [] # 下一步可前往的节点编号
current_layer = 0 # 当前层数
scene_image = None # 整层的图像

View file

@ -152,7 +152,7 @@ class NodeDetector:
def check_next_step(self):
for node_id in data.nodes:
if data.nodes[node_id].status:
data.next_step[node_id] = 2
data.next_step.append(node_id)
def check_node_in_screen(self, node_id: str) -> tp.Scope | None:
score, scope = template(

View file

@ -0,0 +1,131 @@
from functools import cached_property
from mower.solvers.rogue import data
class Path:
# 定义一个类属性,表示比较规则的优先级
_compare_rules = []
def __init__(self, path: list[str]):
self.path = path
@cached_property
def length(self):
"""路径的长度"""
return len(self.path)
@cached_property
def type(self):
"""
返回路径中每种类型节点的数量
示例:
{
"作战": 2,
"紧急作战": 1,
...
}
"""
type_num = {}
for node_id in self.path:
node_type_name = data.nodes[node_id].type
if node_type_name in type_num:
type_num[node_type_name] += 1
else:
type_num[node_type_name] = 1
return type_num
@cached_property
def idea(self):
"""路径需要的构想值"""
idea_num = 0
for i in range(1, len(self.path)):
if self.path[i][0] == self.path[i - 1][0]:
idea_num += 2
return idea_num
@classmethod
def set_compare_rules(cls, rules: list[tuple[str, bool]]):
"""
设置比较规则列表
Args:
rules: list[tuple[str, bool]]
- 每个元组的第一个元素是比较属性或类型名称 "length", "紧急作战"
- 第二个元素是布尔值True 表示升序False 表示降序
"""
cls._compare_rules = rules
def __lt__(self, other: "Path"):
"""
重载 < 运算符根据 _compare_rules 动态比较多个属性
"""
if not self._compare_rules:
raise ValueError("未定义比较规则,请先调用 set_compare_rules 设置规则。")
for rule, ascending in self._compare_rules:
# 获取当前规则对应的值
if rule == "length":
# 按路径长度比较
self_value = self.length
other_value = other.length
else:
# 按类型数量比较
self_value = self.type.get(rule, 0)
other_value = other.type.get(rule, 0)
# 如果值不相等,则进行比较
if self_value != other_value:
if ascending:
return self_value < other_value # 升序
else:
return self_value > other_value # 降序
# 如果所有规则都相等,则不比较
return False
def __repr__(self):
return (
f"Path({self.path},length={self.length},type={self.type},idea={self.idea})"
)
def dfs(node_id: str, path: list, all_paths: list["Path"]):
"""
Args:
node_id (str): 当前节点的编号
path (list): 当前路径存储经过的节点编号
all_paths (list[&quot;Path&quot;]): 保存所有路径的列表
"""
path.append(node_id)
if not data.nodes[node_id].next_nodes:
all_paths.append(Path(path[:]))
else:
for next_node_id in data.nodes[node_id].next_nodes:
if next_node_id not in path:
dfs(next_node_id, path, all_paths)
path.pop()
def get_optimal_path(
start_id: str,
idea_num: int = 2,
compare_rules=[("紧急作战", True), ("length", False)],
) -> "Path":
"""
获取最优路径
Args:
idea_num (int): 可使用的构想值
Returns:
Path: 最优路径
"""
all_paths: list["Path"] = []
dfs(start_id, [], all_paths)
Path.set_compare_rules(compare_rules)
all_paths.sort()
for path in all_paths:
if path.idea <= idea_num:
return path