# -*- coding: utf-8 -*-
"""Inference service: localize tracked beacons from gateway RSSI readings.

RSSI samples arrive over MQTT (the gateway MAC is the last topic segment), are
kept in a sliding time window per (beacon, gateway) pair, aggregated into a
feature row in the gateway order stored in the model package, and fed to a
floor classifier plus a per-floor x/y regressor. Results are written
atomically to a ``mac;z;x;y`` CSV file.
"""
import json
import math
import os
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import joblib
import numpy as np
import paho.mqtt.client as mqtt
import requests
import yaml

from .logger_utils import log_msg as log

BUILD_TAG = "infer-mac-fix-v21"


# -------------------------
# UTILITIES
# -------------------------
def _norm_mac(s: str) -> str:
    """Normalize a MAC address to lowercase ``xx:xx:xx:xx:xx:xx``.

    Gateway and beacon MACs come from MQTT topics, MQTT payloads, the REST API
    and the model package in different formats; every source goes through this
    function so that lookups between them match.

    The separators ``-``, ``:`` and ``.`` are removed first. If the remainder is
    not exactly 12 characters it is returned stripped and lowercased, but not
    reformatted.

    Example (MQTT): ``'AC233FC1DD4E'`` -> ``'ac:23:3f:c1:dd:4e'``
    """
    s = str(s) if s else ""
    s = s.strip().replace("-", "").replace(":", "").replace(".", "").lower()
    if len(s) != 12:
        return s
    return ":".join(s[i:i + 2] for i in range(0, 12, 2))


def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
    """Predict ``(z, x, y)`` for a single feature row ``X``.

    ``z`` comes from ``model_pkg["floor_clf"]``; ``x``/``y`` come from the
    regressor stored under that floor in ``model_pkg["xy_by_floor"]``. When no
    regressor exists for the predicted floor, ``x`` and ``y`` are ``-1.0``.

    Raises:
        ValueError: if the package has no ``floor_clf``. Errors raised by the
            estimators themselves propagate unchanged.
    """
    floor_clf = model_pkg.get("floor_clf")
    if floor_clf is None:
        raise ValueError("model package has no 'floor_clf'")
    z = int(floor_clf.predict(X)[0])
    regressor = (model_pkg.get("xy_by_floor") or {}).get(z)
    if regressor is None:
        return z, -1.0, -1.0
    x, y = regressor.predict(X)[0]
    return z, float(x), float(y)


def _coord_to_int(v: float) -> int:
    """Round a coordinate for the CSV; non-finite values become -1.

    ``round()`` raises on NaN/inf, which would otherwise abort the service loop.
    """
    return int(round(v)) if math.isfinite(v) else -1


def _format_row(mac: str, z: int, x: float, y: float) -> str:
    """Format one ``mac;z;x;y`` CSV row with coordinates rounded to int."""
    return f"{mac};{int(z)};{_coord_to_int(x)};{_coord_to_int(y)}"


def _write_csv_atomic(out_path: Path, rows: List[str]) -> None:
    """Write the header plus ``rows`` to ``<out_path>.tmp``, then ``os.replace`` it.

    The temp file lives next to the target, so the final rename stays on the
    same filesystem and readers never see a half-written CSV.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = Path(str(out_path) + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write("mac;z;x;y\n")
        f.writelines(f"{r}\n" for r in rows)
    os.replace(tmp_path, out_path)


@dataclass
class _Point:
    """One RSSI sample."""

    t: float  # time.monotonic() at reception
    v: float  # RSSI value as received


class RollingRSSI:
    """Sliding time window of RSSI samples keyed by beacon MAC, then gateway MAC.

    ``add`` is called from the MQTT network thread (started by ``loop_start``)
    while ``prune`` and ``aggregate_features`` run on the main loop, so every
    access to ``data`` is serialized by an internal lock.
    """

    def __init__(self, window_s: float):
        self.window_s = window_s
        self.data: Dict[str, Dict[str, List[_Point]]] = {}
        self._lock = threading.Lock()

    def add(self, bm: str, gm: str, rssi: float):
        """Record one sample; ``bm`` and ``gm`` must already be normalized."""
        point = _Point(time.monotonic(), rssi)
        with self._lock:
            self.data.setdefault(bm, {}).setdefault(gm, []).append(point)

    def prune(self):
        """Drop samples older than ``window_s`` and remove emptied entries."""
        cutoff = time.monotonic() - self.window_s
        with self._lock:
            for bm in list(self.data):
                per_gw = self.data[bm]
                for gm in list(per_gw):
                    kept = [p for p in per_gw[gm] if p.t >= cutoff]
                    if kept:
                        per_gw[gm] = kept
                    else:
                        del per_gw[gm]
                if not per_gw:
                    del self.data[bm]

    def aggregate_features(
        self, bm: str, gws: List[str], agg: str, fill: float
    ) -> Tuple[np.ndarray, int]:
        """Build the feature row for beacon ``bm``.

        Returns ``(X, found)``. ``X`` has shape ``(1, len(gws))`` and holds one
        value per gateway in ``gws`` order. That value is the median of the
        windowed samples when ``agg == "median"`` and the mean otherwise.
        Gateways without samples get ``fill``. ``found`` counts the gateways
        that had at least one sample.
        """
        with self._lock:
            per_gw = self.data.get(bm, {})
            # Copy values under the lock; the numeric work happens outside it.
            samples = [[p.v for p in per_gw.get(gm, [])] for gm in gws]
        feats = []
        found_count = 0
        for vals in samples:
            if vals:
                feats.append(np.median(vals) if agg == "median" else np.mean(vals))
                found_count += 1
            else:
                feats.append(np.nan)
        X = np.array(feats, dtype=float).reshape(1, -1)
        return np.where(np.isnan(X), fill, X), found_count


# -------------------------
# MAIN RUNNER
# -------------------------
def run_infer(settings: Dict[str, Any]):
    """Run the inference service forever (until interrupted).

    The service reads the ``infer``, ``api`` and ``mqtt`` sections of
    ``settings``. The loop runs every 0.5 s and does the following:

    - On every iteration, it prunes the RSSI window and hot-reloads the model
      when its file mtime changes.
    - Every ``api.refresh_seconds``, it refreshes the list of beacons to track.
      On failure it keeps the previous list.
    - Every ``infer.refresh_seconds``, it predicts all tracked beacons and
      rewrites ``infer.output_csv``.

    Optional settings (defaults keep the previous behavior):
        api.verify_ssl: TLS certificate verification for API calls (False).
        mqtt.port: MQTT broker port (1883 when absent).
    """
    inf_c = settings.get("infer", {})
    api_c = settings.get("api", {})
    mqtt_c = settings.get("mqtt", {})
    model_path = inf_c.get("model_path", "/data/model/model.joblib")

    # Loop-invariant configuration, read once.
    api_refresh_s = float(api_c.get("refresh_seconds", 30))
    predict_refresh_s = float(inf_c.get("refresh_seconds", 10.0))
    aggregate = inf_c.get("aggregate", "median")
    min_non_nan = int(inf_c.get("min_non_nan", 1))
    out_path = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
    # NOTE(review): certificate verification was hard-disabled; False remains the
    # default for compatibility, but set api.verify_ssl where the API has a valid cert.
    verify_ssl = api_c.get("verify_ssl", False)

    log(f"INFER_MODE build tag={BUILD_TAG}")

    model_pkg = None
    last_model_mtime = 0
    last_attempted_mtime = None  # mtime of the last load attempt (limits log spam)
    last_model_error = None  # (mtime, message) of the last logged load failure
    gateways_ordered: List[str] = []  # _norm_mac-normalized, in model feature order
    gateways_set = set()
    nan_fill = -110.0

    def model_load_failed(mtime, err) -> bool:
        """Log a model load failure once per (mtime, message); return False."""
        nonlocal last_model_error
        key = (mtime, str(err))
        if key != last_model_error:
            log(f"Model load error ({model_path}): {err}")
            last_model_error = key
        return False

    def load_model_dynamic() -> bool:
        """(Re)load the model when its file mtime changes; False on failure."""
        nonlocal model_pkg, last_model_mtime, last_attempted_mtime, last_model_error
        nonlocal gateways_ordered, gateways_set, nan_fill
        try:
            current_mtime = os.path.getmtime(model_path)
        except OSError as e:
            return model_load_failed(None, e)
        if current_mtime == last_model_mtime:
            return True
        if current_mtime != last_attempted_mtime:
            log(f"🧠 Aggiornamento modello: {model_path}")
            last_attempted_mtime = current_mtime
        try:
            # SECURITY: joblib.load unpickles; model_path must be a trusted file.
            pkg = joblib.load(model_path)
            # Normalize exactly like on_message so model gateways match MQTT topics.
            order = [_norm_mac(gw) for gw in pkg.get("gateways_order", [])]
            fill = float(pkg.get("nan_fill", -110.0))
        except Exception as e:  # file may be mid-write: retried on the next iteration
            return model_load_failed(current_mtime, e)
        # Commit all model state together so it never mixes two model versions.
        model_pkg = pkg
        last_model_mtime = current_mtime
        gateways_ordered = order
        gateways_set = set(order)
        nan_fill = fill
        last_model_error = None
        log(f"✅ MODELLO PRONTO: {len(gateways_ordered)} GW. Fill: {nan_fill}")
        return True

    load_model_dynamic()
    rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))

    token_cache: Dict[str, Any] = {"token": None, "expires_at": 0}

    def get_token() -> Optional[str]:
        """Return a cached OIDC access token, requesting a new one when expired."""
        if time.time() < token_cache["expires_at"]:
            return token_cache["token"]
        try:
            with open("/config/secrets.yaml", "r", encoding="utf-8") as f:
                secrets = yaml.safe_load(f)
            sec = secrets.get("oidc", {}) if isinstance(secrets, dict) else {}
            if not isinstance(sec, dict):
                sec = {}
            payload = {
                "grant_type": "password",
                "client_id": api_c.get("client_id", "Fastapi"),
                "client_secret": sec.get("client_secret", ""),
                "username": sec.get("username", "core"),
                "password": sec.get("password", ""),
            }
            resp = requests.post(api_c["token_url"], data=payload, verify=verify_ssl, timeout=10)
            if resp.status_code != 200:
                log(f"Token request failed: HTTP {resp.status_code}")
                return None
            data = resp.json()
            token = data["access_token"]
            expires_in = float(data.get("expires_in", 300))
        except (OSError, yaml.YAMLError, KeyError, TypeError, ValueError,
                AttributeError, requests.RequestException) as e:
            log(f"Token error: {e}")
            return None
        token_cache["token"] = token
        # Renew 30 s before the server-side expiry.
        token_cache["expires_at"] = time.time() + expires_in - 30
        return token

    def fetch_beacons() -> Optional[List[str]]:
        """Return normalized beacon MACs from the API, or None if the fetch failed."""
        token = get_token()
        if not token:
            return None
        headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
        try:
            resp = requests.get(api_c["get_beacons_url"], headers=headers,
                                verify=verify_ssl, timeout=10)
            if resp.status_code == 401:
                token_cache["expires_at"] = 0  # token rejected: re-authenticate next time
                return None
            if resp.status_code != 200:
                log(f"Beacons request failed: HTTP {resp.status_code}")
                return None
            items = resp.json()
        except (KeyError, ValueError, requests.RequestException) as e:
            log(f"Beacons error: {e}")
            return None
        if not isinstance(items, list):
            return None
        # Normalize beacon MACs too, so they compare equal to MQTT payload MACs.
        return [_norm_mac(it["mac"]) for it in items if isinstance(it, dict) and it.get("mac")]

    def on_message(client, userdata, msg):
        """Store each valid ``{"mac", "rssi"}`` item of a gateway message.

        Runs on paho's network thread. Malformed payloads and items are skipped
        individually instead of discarding the whole batch.
        """
        # Topic example: publish_out/ac233fc1dd4e -> last segment is the gateway MAC.
        gw_norm = _norm_mac(msg.topic.split("/")[-1])
        if gw_norm not in gateways_set:
            return
        try:
            items = json.loads(msg.payload.decode())
        except ValueError:  # includes UnicodeDecodeError and JSONDecodeError
            return
        if not isinstance(items, list):
            return
        for it in items:
            if not isinstance(it, dict):
                continue
            raw_bm = it.get("mac")
            rssi = it.get("rssi")
            if not raw_bm or rssi is None:
                continue
            try:
                rssi_f = float(rssi)
            except (TypeError, ValueError):
                continue
            if math.isfinite(rssi_f):  # a NaN sample would poison the median/mean
                rolling.add(_norm_mac(raw_bm), gw_norm, rssi_f)

    mqtt_topic = mqtt_c.get("topic")

    def on_connect(client, userdata, flags, rc, properties=None):
        """Subscribe after every (re)connect.

        With a clean session the broker forgets subscriptions on disconnect, so
        a one-time subscribe would stop all messages after the first reconnect.
        The signature accepts both the paho 1.x and 2.x callback forms.
        """
        if getattr(rc, "is_failure", rc != 0):
            log(f"MQTT connect failed: {rc}")
            return
        client.subscribe(mqtt_topic)
        log(f"MQTT connected, subscribed to {mqtt_topic}")

    # paho-mqtt >= 2.0 requires an explicit callback API version; 1.x has no such enum.
    callback_api = getattr(mqtt, "CallbackAPIVersion", None)
    mqtt_client = mqtt.Client(callback_api.VERSION2) if callback_api else mqtt.Client()
    mqtt_client.on_connect = on_connect
    mqtt_client.on_message = on_message
    try:
        if mqtt_topic is None:
            raise KeyError("topic")
        # connect_async + loop_start: the network thread retries the first connection
        # too, so a broker that is down at startup no longer disables MQTT for good.
        mqtt_client.connect_async(mqtt_c["host"], int(mqtt_c.get("port", 1883)))
        mqtt_client.loop_start()
    except Exception as e:
        log(f"MQTT Error: {e}")

    last_predict = last_api_refresh = float("-inf")
    beacons_to_track: List[str] = []
    try:
        while True:
            now = time.monotonic()
            rolling.prune()
            load_model_dynamic()

            if now - last_api_refresh >= api_refresh_s:
                fetched = fetch_beacons()
                if fetched is not None:  # keep the last known list on API failure
                    beacons_to_track = fetched
                last_api_refresh = now

            if now - last_predict >= predict_refresh_s:
                rows: List[str] = []
                count_ok = 0
                n_errors = 0
                last_error: Optional[Exception] = None
                if model_pkg and beacons_to_track:
                    for bm_n in beacons_to_track:
                        X, n_found = rolling.aggregate_features(
                            bm_n, gateways_ordered, aggregate, nan_fill)
                        z, x, y = -1, -1.0, -1.0
                        if n_found >= min_non_nan:
                            try:
                                z, x, y = _predict_xyz(model_pkg, X)
                            except Exception as e:  # one bad beacon must not stop the cycle
                                n_errors += 1
                                last_error = e
                            else:
                                if z != -1:
                                    count_ok += 1
                        rows.append(_format_row(bm_n, z, x, y))
                if n_errors:
                    log(f"PREDICT errors: {n_errors} (last: {last_error})")
                try:
                    _write_csv_atomic(out_path, rows)
                    log(f"CYCLE: {count_ok}/{len(beacons_to_track)} localized (Input GW match: {len(gateways_set)})")
                except OSError as e:
                    log(f"File Error: {e}")
                last_predict = now
            time.sleep(0.5)
    finally:
        mqtt_client.disconnect()
        mqtt_client.loop_stop()