|
- # -*- coding: utf-8 -*-
- import json
- import os
- import time
- import joblib
- import numpy as np
- import requests
- import yaml
- import paho.mqtt.client as mqtt
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Tuple
- from .logger_utils import log_msg as log
-
- # Importiamo la funzione per caricare i gateway per mantenere l'ordine delle feature
- from .csv_config import load_gateway_features_csv
-
- BUILD_TAG = "infer-debug-v19-hierarchical-final"
-
- # -------------------------
- # UTILITIES
- # -------------------------
-
- def _norm_mac(s: str) -> str:
- s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").upper()
- if len(s) != 12: return s
- return ":".join([s[i:i+2] for i in range(0, 12, 2)])
-
- def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
- """
- Logica di predizione basata su train_mode.py:
- 1. Predice Z (piano) tramite KNeighborsClassifier
- 2. Predice X, Y tramite KNeighborsRegressor specifico per quel piano
- """
- # 1. Predizione del Piano (Z)
- floor_clf = model_pkg.get("floor_clf")
- if floor_clf is None:
- raise ValueError("Il pacchetto modello non contiene 'floor_clf'")
-
- z_pred = floor_clf.predict(X)
- z = int(z_pred[0])
-
- # 2. Predizione X, Y (Coordinate)
- xy_models = model_pkg.get("xy_by_floor", {})
- regressor = xy_models.get(z)
-
- if regressor is None:
- # Se il piano predetto non ha un regressore XY, restituiamo solo il piano
- return z, -1.0, -1.0
-
- xy_pred = regressor.predict(X)
- x, y = xy_pred[0] # L'output del regressore multi-output è [x, y]
-
- return z, float(x), float(y)
-
- @dataclass
- class _Point:
- t: float
- v: float
-
class RollingRSSI:
    """Sliding time window of RSSI samples keyed by (beacon, gateway)."""

    def __init__(self, window_s: float):
        # Window length in seconds; samples older than this are pruned.
        self.window_s = window_s
        # beacon MAC -> gateway MAC -> list of timestamped samples.
        self.data: Dict[str, Dict[str, List[_Point]]] = {}

    def add(self, bm: str, gm: str, rssi: float):
        """Record one RSSI reading for beacon `bm` seen by gateway `gm`."""
        gw_map = self.data.setdefault(bm, {})
        gw_map.setdefault(gm, []).append(_Point(time.time(), rssi))

    def prune(self):
        """Drop samples older than the window and remove emptied entries."""
        cutoff = time.time() - self.window_s
        for bm in list(self.data):
            gw_map = self.data[bm]
            for gm in list(gw_map):
                fresh = [p for p in gw_map[gm] if p.t >= cutoff]
                if fresh:
                    gw_map[gm] = fresh
                else:
                    del gw_map[gm]
            if not gw_map:
                del self.data[bm]

    def aggregate_features(self, bm: str, gws: List[str], agg: str, fill: float):
        """Build the (1, len(gws)) feature row for beacon `bm`.

        Each gateway column is the median (agg == "median") or mean of its
        buffered samples; gateways with no samples receive `fill`.  Returns
        the feature matrix and the count of gateways that had data.
        """
        samples_by_gw = self.data.get(bm, {})
        agg_fn = np.median if agg == "median" else np.mean
        row, seen = [], 0
        for gm in gws:
            values = [p.v for p in samples_by_gw.get(gm, [])]
            if values:
                row.append(agg_fn(values))
                seen += 1
            else:
                row.append(np.nan)
        X = np.array(row).reshape(1, -1)
        # NaN marks "gateway not heard"; substitute the training-time fill.
        return np.where(np.isnan(X), fill, X), seen
-
- # -------------------------
- # MAIN RUNNER
- # -------------------------
-
def run_infer(settings: Dict[str, Any]):
    """Run the inference loop forever.

    Collects RSSI readings over MQTT into a rolling window, periodically
    predicts each known beacon's (z, x, y) with the trained hierarchical
    model, and atomically rewrites the output CSV.

    `settings` should provide "infer", "api" and "mqtt" sections; defaults
    are applied for most keys.  Returns only if the model fails to load.
    """
    inf_c = settings.get("infer", {})
    api_c = settings.get("api", {})
    mqtt_c = settings.get("mqtt", {})

    log(f"INFER_MODE build tag={BUILD_TAG}")

    # Load the trained model package.  The gateway ordering and NaN fill
    # value are stored alongside the estimators at training time so the
    # inference feature layout matches exactly.
    try:
        model_pkg = joblib.load(inf_c.get("model_path", "/data/model/model.joblib"))
        gateways_ordered = model_pkg.get("gateways_order")
        if not gateways_ordered:
            # Fail fast with a clear message instead of a len(None) TypeError.
            raise ValueError("model package is missing 'gateways_order'")
        nan_fill = float(model_pkg.get("nan_fill", -110.0))
        log(f"Model loaded. Features: {len(gateways_ordered)}, Fill: {nan_fill}")
    except Exception as e:
        log(f"CRITICAL: Failed to load model: {e}")
        return

    gateways_set = set(gateways_ordered)
    rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))

    # --- OIDC token handling and beacon list refresh (best effort) ---
    token_cache = {"token": None, "expires_at": 0}

    def get_token():
        """Return a cached bearer token, refreshing it via the OIDC password
        grant when expired.  Returns None on any failure (logged)."""
        if time.time() < token_cache["expires_at"]:
            return token_cache["token"]
        try:
            # Credentials come from secrets.yaml, not from settings.
            with open("/config/secrets.yaml", "r") as f:
                sec = yaml.safe_load(f).get("oidc", {})
            payload = {
                "grant_type": "password", "client_id": api_c.get("client_id", "Fastapi"),
                "client_secret": sec.get("client_secret", ""),
                "username": sec.get("username", "core"), "password": sec.get("password", "")
            }
            # NOTE(review): verify=False disables TLS certificate checks —
            # acceptable only for a trusted internal endpoint; confirm.
            resp = requests.post(api_c["token_url"], data=payload, verify=False, timeout=10)
            if resp.status_code == 200:
                d = resp.json()
                token_cache["token"] = d["access_token"]
                # Refresh 30 s before the advertised expiry.
                token_cache["expires_at"] = time.time() + d.get("expires_in", 300) - 30
                return token_cache["token"]
        except Exception as e:
            # Best effort: token failures are logged, never fatal.  A bare
            # except here previously swallowed even KeyboardInterrupt.
            log(f"Token refresh failed: {e}")
        return None

    def fetch_beacons():
        """Fetch the list of beacon MACs from the API; [] on any failure."""
        token = get_token()
        if not token:
            return []
        try:
            headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
            resp = requests.get(api_c["get_beacons_url"], headers=headers, verify=False, timeout=10)
            return [it["mac"] for it in resp.json() if "mac" in it] if resp.status_code == 200 else []
        except Exception as e:
            log(f"Beacon fetch failed: {e}")
            return []

    def on_message(client, userdata, msg):
        """MQTT callback: the topic's last segment is the gateway MAC and
        the payload is a JSON list of {mac, rssi} items.  Messages from
        unknown gateways and malformed payloads are dropped."""
        gw = _norm_mac(msg.topic.split("/")[-1])
        if gw not in gateways_set:
            return
        try:
            items = json.loads(msg.payload.decode())
            for it in items:
                bm, rssi = _norm_mac(it.get("mac")), it.get("rssi")
                if bm and rssi is not None:
                    rolling.add(bm, gw, float(rssi))
        except Exception:
            # Occasional malformed payloads are expected; drop silently to
            # avoid flooding the log from this hot callback.
            pass

    mqtt_client = mqtt.Client()
    mqtt_client.on_message = on_message
    mqtt_client.connect(mqtt_c["host"], mqtt_c["port"])
    mqtt_client.subscribe(mqtt_c["topic"])
    mqtt_client.loop_start()

    last_predict, last_api_refresh = 0.0, 0.0
    beacons = []

    while True:
        now = time.time()
        rolling.prune()

        # Periodically refresh the beacon list from the API.
        if now - last_api_refresh >= float(api_c.get("refresh_seconds", 30)):
            beacons = fetch_beacons()
            last_api_refresh = now

        # Periodically run one prediction pass over all known beacons.
        if now - last_predict >= float(inf_c.get("refresh_seconds", 10.0)):
            rows, count_ok = [], 0
            for bm in beacons:
                bm_n = _norm_mac(bm)
                X, n_found = rolling.aggregate_features(
                    bm_n, gateways_ordered, inf_c.get("aggregate", "median"), nan_fill)

                z, x, y = -1, -1.0, -1.0
                # Only predict when enough gateways actually heard the beacon.
                if n_found >= int(inf_c.get("min_non_nan", 1)):
                    try:
                        z, x, y = _predict_xyz(model_pkg, X)
                        if z != -1:
                            count_ok += 1
                    except Exception as e:
                        log(f"Infer error {bm_n}: {e}")

                rows.append(f"{bm_n};{int(z)};{int(round(x))};{int(round(y))}")

            try:
                # Atomic rewrite: write a temp file, then os.replace over
                # the target so readers never see a partial CSV.
                out_p = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
                tmp_path = str(out_p) + ".tmp"
                with open(tmp_path, "w") as f:
                    f.write("mac;z;x;y\n")
                    for r in rows:
                        f.write(r + "\n")
                os.replace(tmp_path, out_p)
                log(f"CYCLE: {count_ok}/{len(rows)} localized")
            except Exception as e:
                log(f"File Error: {e}")

            last_predict = now
        time.sleep(0.5)
|