# -*- coding: utf-8 -*-
"""Real-time localization inference loop.

RSSI readings arrive over MQTT (the last topic segment is the gateway MAC),
are buffered per (beacon, gateway) pair in a sliding time window, aggregated
into a feature row ordered like the model's training gateways, and fed to a
two-stage model (floor classifier, then a per-floor XY regressor). Results
for the beacons listed by the REST API are written atomically to a
``mac;z;x;y`` CSV file.
"""
import json
import math
import os
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import joblib
import numpy as np
import paho.mqtt.client as mqtt
import requests
import yaml

from .logger_utils import log_msg as log
# Training-side gateway loader. Not referenced in this module: the feature
# order used at inference comes from the model package ("gateways_order").
from .csv_config import load_gateway_features_csv  # noqa: F401

BUILD_TAG = "infer-debug-v19-hierarchical-final"


# -------------------------
# UTILITIES
# -------------------------

def _norm_mac(s: Optional[str]) -> str:
    """Normalize a MAC address to upper-case, colon-separated form.

    Separators ``-``, ``:`` and ``.`` are removed first. If the result is not
    exactly 12 characters it is returned as-is (upper-cased, separators
    removed) instead of being rejected. ``None`` or ``""`` yields ``""``.
    The function is idempotent on already-canonical MACs.
    """
    s = str(s or "").strip().replace("-", "").replace(":", "").replace(".", "").upper()
    if len(s) != 12:
        return s
    return ":".join(s[i:i + 2] for i in range(0, 12, 2))


def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
    """Predict ``(z, x, y)`` for one feature row with the two-stage model.

    Mirrors the training logic in train_mode.py:
      1. ``model_pkg["floor_clf"]`` predicts the floor ``z``;
      2. ``model_pkg["xy_by_floor"][z]`` predicts the ``(x, y)`` coordinates.

    Args:
        model_pkg: the loaded model package.
        X: feature matrix with one row per sample; only the first row's
            prediction is used.

    Returns:
        ``(z, x, y)``. ``x`` and ``y`` are ``-1.0`` when the package has no
        XY regressor for the predicted floor.

    Raises:
        ValueError: if the package has no ``floor_clf``.
    """
    # Stage 1: floor (Z).
    floor_clf = model_pkg.get("floor_clf")
    if floor_clf is None:
        raise ValueError("Il pacchetto modello non contiene 'floor_clf'")
    z = int(floor_clf.predict(X)[0])

    # Stage 2: coordinates (X, Y) with the regressor trained for that floor.
    xy_models = model_pkg.get("xy_by_floor") or {}
    regressor = xy_models.get(z)
    if regressor is None:
        # No XY regressor for the predicted floor: report only the floor.
        return z, -1.0, -1.0

    # The multi-output regressor returns one [x, y] pair per input row.
    x, y = regressor.predict(X)[0]
    return z, float(x), float(y)


def _format_row(mac: str, z: int, x: float, y: float) -> str:
    """Format one ``mac;z;x;y`` CSV row with coordinates rounded to integers.

    Non-finite coordinates (NaN/inf from a regressor) are written as -1
    instead of raising and aborting the whole cycle.
    """
    def _as_int(v: float) -> int:
        return int(round(v)) if math.isfinite(v) else -1

    return f"{mac};{int(z)};{_as_int(x)};{_as_int(y)}"


@dataclass
class _Point:
    """A single RSSI sample."""
    t: float  # time.time() at which the sample was added
    v: float  # RSSI value as received


class RollingRSSI:
    """Sliding-window buffer of RSSI samples per beacon and gateway.

    ``add`` is invoked from the MQTT network thread (started by
    ``loop_start``) while ``prune`` and ``aggregate_features`` run on the
    main loop, so every access to ``data`` is serialized by ``_lock``.
    """

    def __init__(self, window_s: float):
        self.window_s = window_s
        # beacon MAC -> gateway MAC -> samples, in arrival order
        self.data: Dict[str, Dict[str, List[_Point]]] = {}
        self._lock = threading.Lock()

    def add(self, bm: str, gm: str, rssi: float) -> None:
        """Record an RSSI sample of beacon ``bm`` seen by gateway ``gm`` now."""
        point = _Point(time.time(), rssi)
        with self._lock:
            self.data.setdefault(bm, {}).setdefault(gm, []).append(point)

    def prune(self) -> None:
        """Drop samples older than ``window_s`` and any emptied entries."""
        cutoff = time.time() - self.window_s
        with self._lock:
            for bm in list(self.data):
                per_gw = self.data[bm]
                for gm in list(per_gw):
                    kept = [p for p in per_gw[gm] if p.t >= cutoff]
                    if kept:
                        per_gw[gm] = kept
                    else:
                        del per_gw[gm]
                if not per_gw:
                    del self.data[bm]

    def aggregate_features(self, bm: str, gws: List[str], agg: str,
                           fill: float) -> Tuple[np.ndarray, int]:
        """Build the feature row of beacon ``bm`` in the order of ``gws``.

        Args:
            bm: normalized beacon MAC.
            gws: gateway MACs in training feature order.
            agg: ``"median"`` uses the median per gateway; any other value
                uses the mean.
            fill: value used for gateways with no sample in the window (the
                training ``nan_fill``).

        Returns:
            ``(X, found)``: ``X`` has shape ``(1, len(gws))``; ``found`` is
            the number of gateways that had at least one sample.
        """
        with self._lock:
            per_gw = self.data.get(bm, {})
            # Copy values out so the lock is not held while doing numpy work.
            samples = [[p.v for p in per_gw.get(gm, ())] for gm in gws]

        reducer = np.median if agg == "median" else np.mean
        feats = [float(reducer(vals)) if vals else np.nan for vals in samples]
        found = sum(1 for vals in samples if vals)
        X = np.array(feats, dtype=float).reshape(1, -1)
        # Replace missing gateways with the fill value used during training.
        return np.where(np.isnan(X), fill, X), found


# -------------------------
# MAIN RUNNER HELPERS
# -------------------------

def _load_model(model_path: str) -> Tuple[Dict[str, Any], List[str], float]:
    """Load the model package plus its gateway feature order and NaN fill.

    The gateway MACs are normalized with ``_norm_mac`` so they match the
    gateway MACs taken from MQTT topics, which are normalized the same way.

    Raises:
        ValueError: if the package has no (or an empty) ``gateways_order``.
        Exception: anything raised by ``joblib.load``.
    """
    model_pkg = joblib.load(model_path)
    gateways = model_pkg.get("gateways_order")
    if not gateways:
        raise ValueError("model package has no 'gateways_order'")
    gateways_ordered = [_norm_mac(g) for g in gateways]
    nan_fill = float(model_pkg.get("nan_fill", -110.0))
    return model_pkg, gateways_ordered, nan_fill


def _make_mqtt_client() -> "mqtt.Client":
    """Create an MQTT client with v1 callback signatures on paho-mqtt 1.x and 2.x."""
    api_version = getattr(mqtt, "CallbackAPIVersion", None)
    if api_version is not None:
        # paho-mqtt >= 2.0 requires the callback API version explicitly.
        return mqtt.Client(api_version.VERSION1)
    return mqtt.Client()


def _predict_cycle(beacons: List[str], rolling: RollingRSSI,
                   model_pkg: Dict[str, Any], gateways_ordered: List[str],
                   agg: str, nan_fill: float,
                   min_non_nan: int) -> Tuple[List[str], int]:
    """Predict every beacon once; return the CSV rows and how many got a floor.

    Beacons seen by fewer than ``min_non_nan`` gateways in the window, or
    whose prediction raises, are emitted as ``mac;-1;-1;-1``.
    """
    rows: List[str] = []
    count_ok = 0
    for bm in beacons:
        bm_n = _norm_mac(bm)
        X, n_found = rolling.aggregate_features(bm_n, gateways_ordered, agg, nan_fill)
        z, x, y = -1, -1.0, -1.0
        if n_found >= min_non_nan:
            try:
                z, x, y = _predict_xyz(model_pkg, X)
            except Exception as e:
                log(f"Infer error {bm_n}: {e}")
            else:
                if z != -1:
                    count_ok += 1
        rows.append(_format_row(bm_n, z, x, y))
    return rows, count_ok


def _write_csv_atomic(out_p: Path, rows: List[str]) -> None:
    """Write ``rows`` under a ``mac;z;x;y`` header, replacing ``out_p`` atomically.

    Data goes to ``<out_p>.tmp`` first and is then moved over ``out_p`` with
    ``os.replace``, so readers never see a half-written file.
    """
    out_p.parent.mkdir(parents=True, exist_ok=True)
    tmp_p = out_p.with_name(out_p.name + ".tmp")
    with open(tmp_p, "w") as f:
        f.write("mac;z;x;y\n")
        f.writelines(r + "\n" for r in rows)
    os.replace(tmp_p, out_p)


# -------------------------
# MAIN RUNNER
# -------------------------

def run_infer(settings: Dict[str, Any]) -> None:
    """Run the inference service loop.

    Args:
        settings: configuration with optional sections ``infer``, ``api`` and
            ``mqtt``. Keys read (defaults in brackets):
            infer.model_path [/data/model/model.joblib],
            infer.window_seconds [5.0], infer.refresh_seconds [10.0],
            infer.aggregate ["median"], infer.min_non_nan [1],
            infer.output_csv [/data/infer/infer.csv];
            api.token_url, api.get_beacons_url, api.client_id ["Fastapi"],
            api.refresh_seconds [30], api.verify_ssl [False],
            api.secrets_path [/config/secrets.yaml];
            mqtt.host, mqtt.port, mqtt.topic.

    Returns immediately if the model cannot be loaded. Otherwise it loops
    until interrupted, then disconnects the MQTT client.
    """
    inf_c = settings.get("infer", {})
    api_c = settings.get("api", {})
    mqtt_c = settings.get("mqtt", {})

    log(f"INFER_MODE build tag={BUILD_TAG}")

    # Load the model; gateway order and NaN fill come from the saved package.
    try:
        model_pkg, gateways_ordered, nan_fill = _load_model(
            inf_c.get("model_path", "/data/model/model.joblib"))
        log(f"Model loaded. Features: {len(gateways_ordered)}, Fill: {nan_fill}")
    except Exception as e:
        log(f"CRITICAL: Failed to load model: {e}")
        return

    gateways_set = set(gateways_ordered)
    rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))

    # Loop-invariant configuration, read once.
    agg = inf_c.get("aggregate", "median")
    min_non_nan = int(inf_c.get("min_non_nan", 1))
    predict_every = float(inf_c.get("refresh_seconds", 10.0))
    api_every = float(api_c.get("refresh_seconds", 30))
    out_p = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
    secrets_path = api_c.get("secrets_path", "/config/secrets.yaml")
    # SECURITY: TLS certificate verification defaults to OFF to preserve the
    # previous behavior; set api.verify_ssl to True (or a CA bundle path).
    verify_tls = api_c.get("verify_ssl", False)

    # --- Token and beacon list handling ---
    token_cache: Dict[str, Any] = {"token": None, "expires_at": 0.0}

    def get_token() -> Optional[str]:
        """Return a cached access token, requesting a new one when expired."""
        if time.time() < token_cache["expires_at"]:
            return token_cache["token"]
        try:
            # Credentials are read from the secrets file.
            with open(secrets_path, "r") as f:
                sec = (yaml.safe_load(f) or {}).get("oidc") or {}
            payload = {
                "grant_type": "password",
                "client_id": api_c.get("client_id", "Fastapi"),
                "client_secret": sec.get("client_secret", ""),
                "username": sec.get("username", "core"),
                "password": sec.get("password", ""),
            }
            resp = requests.post(api_c["token_url"], data=payload,
                                 verify=verify_tls, timeout=10)
            if resp.status_code != 200:
                log(f"Token request failed: HTTP {resp.status_code}")
                return None
            d = resp.json()
            token_cache["token"] = d["access_token"]
            # Renew 30 s before the advertised expiry.
            token_cache["expires_at"] = time.time() + float(d.get("expires_in", 300)) - 30
            return token_cache["token"]
        except Exception as e:  # best-effort: retried on the next refresh
            log(f"Token request failed: {e}")
            return None

    def fetch_beacons() -> Optional[List[str]]:
        """Return beacon MACs from the API, or ``None`` if they could not be fetched."""
        token = get_token()
        if not token:
            return None
        try:
            headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
            resp = requests.get(api_c["get_beacons_url"], headers=headers,
                                verify=verify_tls, timeout=10)
            if resp.status_code != 200:
                if resp.status_code == 401:
                    # Token rejected before its advertised expiry: force renewal.
                    token_cache["expires_at"] = 0.0
                log(f"Beacon request failed: HTTP {resp.status_code}")
                return None
            return [it["mac"] for it in resp.json() if isinstance(it, dict) and "mac" in it]
        except Exception as e:  # best-effort: retried on the next refresh
            log(f"Beacon request failed: {e}")
            return None

    # --- MQTT callbacks (run on paho's network thread) ---
    def on_connect(client, userdata, flags, rc, *extra):
        # Subscribe on every (re)connect: a one-off subscribe is lost when
        # paho reconnects automatically with a clean session.
        if rc == 0:
            client.subscribe(mqtt_c["topic"])
        else:
            log(f"MQTT connect failed: rc={rc}")

    def on_message(client, userdata, msg):
        gw = _norm_mac(msg.topic.split("/")[-1])
        if gw not in gateways_set:
            return
        try:
            items = json.loads(msg.payload.decode())
        except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
            return  # malformed payloads are dropped deliberately
        if not isinstance(items, list):
            return
        for it in items:
            if not isinstance(it, dict):
                continue
            bm, rssi = _norm_mac(it.get("mac")), it.get("rssi")
            if not bm or rssi is None:
                continue
            try:
                rolling.add(bm, gw, float(rssi))
            except (TypeError, ValueError):
                continue  # skip only this item, keep the rest of the batch

    mqtt_client = _make_mqtt_client()
    mqtt_client.on_connect = on_connect
    mqtt_client.on_message = on_message
    mqtt_client.connect(mqtt_c["host"], int(mqtt_c["port"]))
    mqtt_client.loop_start()

    last_predict, last_api_refresh = 0.0, 0.0
    beacons: List[str] = []
    try:
        while True:
            now = time.time()
            rolling.prune()

            if now - last_api_refresh >= api_every:
                fetched = fetch_beacons()
                # Keep the last known list on failure instead of blanking the output.
                if fetched is not None:
                    beacons = fetched
                last_api_refresh = now

            if now - last_predict >= predict_every:
                rows, count_ok = _predict_cycle(beacons, rolling, model_pkg,
                                                gateways_ordered, agg, nan_fill,
                                                min_non_nan)
                try:
                    _write_csv_atomic(out_p, rows)
                    log(f"CYCLE: {count_ok}/{len(rows)} localized")
                except Exception as e:
                    log(f"File Error: {e}")
                last_predict = now

            time.sleep(0.5)
    finally:
        mqtt_client.disconnect()
        mqtt_client.loop_stop()