公招结果识别:mobilenet保留分类器前三层;先通过职业分类,再用knn分类
All checks were successful
ci/woodpecker/push/check_format Pipeline was successful
All checks were successful
ci/woodpecker/push/check_format Pipeline was successful
This commit is contained in:
parent
4e99518088
commit
8b69e90570
4 changed files with 34 additions and 25 deletions
|
@ -18,10 +18,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
from copy import copy
|
||||
from datetime import timedelta
|
||||
from functools import cache
|
||||
from itertools import combinations
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from mower.data import recruit_agent
|
||||
from mower.static import recruit, recruit_result_knn
|
||||
|
@ -32,7 +32,6 @@ from mower.utils.log import logger
|
|||
from mower.utils.path import get_path
|
||||
from mower.utils.scene import Scene
|
||||
from mower.utils.solver import BaseSolver
|
||||
from mower.utils.typealias import Image
|
||||
from mower.utils.vector import sa, va
|
||||
|
||||
|
||||
|
@ -91,8 +90,6 @@ class RecruitSolver(BaseSolver):
|
|||
solver_max_duration = timedelta(minutes=3)
|
||||
|
||||
def run(self):
|
||||
net_path = get_path("@install/mower/static/mobilenet_v3_small_features.onnx")
|
||||
self.mobilenet = cv2.dnn.readNetFromONNX(str(net_path))
|
||||
self.index_known: bool = False
|
||||
self.slot_index: int = 0
|
||||
self.info = {}
|
||||
|
@ -235,18 +232,30 @@ class RecruitSolver(BaseSolver):
|
|||
self.ctap((x, y), 2)
|
||||
return True
|
||||
|
||||
@property
|
||||
@cache
|
||||
def mobilenet(self):
|
||||
net_name = "mobilenet_v3_small_feature_extractor.onnx"
|
||||
net_path = get_path(f"@install/mower/static/{net_name}")
|
||||
return cv2.dnn.readNetFromONNX(str(net_path))
|
||||
|
||||
def recruit_result(self) -> str:
|
||||
img: Image = cropimg(config.recog.img, ((800, 100), (1400, 700)))
|
||||
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255.0
|
||||
std = np.array([0.229, 0.224, 0.225], dtype=np.float32) * 255.0
|
||||
img = (img.astype(np.float32) - mean) / std
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
img = np.expand_dims(img, axis=0).astype(np.float32)
|
||||
self.mobilenet.setInput(img, "input")
|
||||
features = self.mobilenet.forward("output")
|
||||
result = recruit_result_knn().predict(features)
|
||||
logger.debug(result)
|
||||
return result
|
||||
for profession, knn_classifier in recruit_result_knn().items():
|
||||
if config.recog.find(f"recruit/profession/{profession}"):
|
||||
img = cropimg(config.recog.img, ((800, 100), (1300, 600)))
|
||||
blob = cv2.dnn.blobFromImage(
|
||||
img,
|
||||
scalefactor=1 / 255,
|
||||
size=(224, 224),
|
||||
mean=(0, 0, 0),
|
||||
swapRB=False,
|
||||
crop=False,
|
||||
)
|
||||
self.mobilenet.setInput(blob)
|
||||
features = self.mobilenet.forward()
|
||||
result = str(knn_classifier.predict(features)[0])
|
||||
logger.debug(result)
|
||||
return result
|
||||
|
||||
def transition(self):
|
||||
if len(self.info) == 4:
|
||||
|
|
Binary file not shown.
BIN
mower/static/recruit_result_knn.pkl
(Stored with Git LFS)
BIN
mower/static/recruit_result_knn.pkl
(Stored with Git LFS)
Binary file not shown.
|
@ -408,14 +408,14 @@ template_matching = {
|
|||
"recruit/agent_token_first": ((1700, 760), (1920, 810)),
|
||||
"recruit/begin_recruit": None,
|
||||
"recruit/job_requirements": None,
|
||||
"recruit/profession/CASTER": None,
|
||||
"recruit/profession/MEDIC": None,
|
||||
"recruit/profession/PIONEER": None,
|
||||
"recruit/profession/SNIPER": None,
|
||||
"recruit/profession/SPECIAL": None,
|
||||
"recruit/profession/SUPPORT": None,
|
||||
"recruit/profession/TANK": None,
|
||||
"recruit/profession/WARRIOR": None,
|
||||
"recruit/profession/CASTER": ((700, 720), (1040, 890)),
|
||||
"recruit/profession/MEDIC": ((700, 720), (1040, 890)),
|
||||
"recruit/profession/PIONEER": ((700, 720), (1040, 890)),
|
||||
"recruit/profession/SNIPER": ((700, 720), (1040, 890)),
|
||||
"recruit/profession/SPECIAL": ((700, 720), (1040, 890)),
|
||||
"recruit/profession/SUPPORT": ((700, 720), (1040, 890)),
|
||||
"recruit/profession/TANK": ((700, 720), (1040, 890)),
|
||||
"recruit/profession/WARRIOR": ((700, 720), (1040, 890)),
|
||||
"recruit/recruit_done": None,
|
||||
"recruit/recruit_lock": None,
|
||||
"recruit/refresh_comfirm": (1237, 714),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue