# -*- coding: utf-8 -*-
"""BLE-beacon indoor-positioning inference runner.

Collects per-gateway RSSI readings from MQTT into a rolling time window,
periodically builds a feature vector ordered like the trained model's
gateway list, predicts (floor, x, y) per beacon, and atomically writes the
results to a CSV file. The model file is hot-reloaded whenever its mtime
changes.
"""

import json
import os
import time
import joblib
import numpy as np
import requests
import yaml
import paho.mqtt.client as mqtt
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

from .logger_utils import log_msg as log
# Import the gateway loader to keep the feature ordering available.
from .csv_config import load_gateway_features_csv

BUILD_TAG = "infer-debug-v20-autoreload"

# -------------------------
# UTILITIES
# -------------------------

def _norm_mac(s: str) -> str:
    """Normalize a MAC address to upper-case colon-separated form.

    Strips ``-``, ``:`` and ``.`` separators; anything that does not end up
    as 12 hex-ish characters is returned as-is (upper-cased, separators
    removed) so malformed input is passed through rather than mangled.
    """
    s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").upper()
    if len(s) != 12:
        return s
    return ":".join([s[i:i+2] for i in range(0, 12, 2)])


def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
    """Predict (floor, x, y) for a single feature row.

    The package must contain a ``floor_clf`` classifier and a
    ``xy_by_floor`` dict of per-floor XY regressors. If no regressor
    exists for the predicted floor, (z, -1.0, -1.0) is returned as a
    sentinel.

    Raises:
        ValueError: if the package has no ``floor_clf``.
    """
    floor_clf = model_pkg.get("floor_clf")
    if floor_clf is None:
        raise ValueError("Il pacchetto modello non contiene 'floor_clf'")
    z_pred = floor_clf.predict(X)
    z = int(z_pred[0])
    xy_models = model_pkg.get("xy_by_floor", {})
    regressor = xy_models.get(z)
    if regressor is None:
        return z, -1.0, -1.0
    xy_pred = regressor.predict(X)
    x, y = xy_pred[0]
    return z, float(x), float(y)


@dataclass
class _Point:
    # t: wall-clock timestamp (time.time()); v: RSSI value in dBm.
    t: float
    v: float


class RollingRSSI:
    """Sliding time-window store of RSSI samples, keyed beacon -> gateway."""

    def __init__(self, window_s: float):
        self.window_s = window_s
        # beacon MAC -> gateway MAC -> list of timestamped samples
        self.data: Dict[str, Dict[str, List[_Point]]] = {}

    def add(self, bm: str, gm: str, rssi: float):
        """Record one RSSI sample for beacon *bm* as seen by gateway *gm*."""
        self.data.setdefault(bm, {}).setdefault(gm, []).append(_Point(time.time(), rssi))

    def prune(self):
        """Drop samples older than the window; remove empty sub-dicts."""
        cutoff = time.time() - self.window_s
        for bm in list(self.data.keys()):
            for gm in list(self.data[bm].keys()):
                self.data[bm][gm] = [p for p in self.data[bm][gm] if p.t >= cutoff]
                if not self.data[bm][gm]:
                    self.data[bm].pop(gm)
            if not self.data[bm]:
                self.data.pop(bm)

    def aggregate_features(self, bm: str, gws: List[str], agg: str, fill: float):
        """Build a (1, len(gws)) feature row for beacon *bm*.

        Gateways with no samples in the window get *fill*; *agg* selects
        ``"median"`` or mean aggregation. Returns (X, found_count) where
        found_count is the number of gateways that had at least one sample.
        """
        per_gw = self.data.get(bm, {})
        feats = []
        found_count = 0
        for gm in gws:
            vals = [p.v for p in per_gw.get(gm, [])]
            if vals:
                val = np.median(vals) if agg == "median" else np.mean(vals)
                feats.append(val)
                found_count += 1
            else:
                feats.append(np.nan)
        X = np.array(feats).reshape(1, -1)
        return np.where(np.isnan(X), fill, X), found_count


# -------------------------
# MAIN RUNNER
# -------------------------

def run_infer(settings: Dict[str, Any]):
    """Run the inference service loop forever.

    Args:
        settings: configuration dict with ``infer``, ``api`` and ``mqtt``
            sections (model path, refresh intervals, endpoints, broker).

    Side effects: connects to MQTT, polls an HTTP API for the beacon list,
    and rewrites the output CSV atomically on every prediction cycle.
    """
    inf_c = settings.get("infer", {})
    api_c = settings.get("api", {})
    mqtt_c = settings.get("mqtt", {})
    model_path = inf_c.get("model_path", "/data/model/model.joblib")
    log(f"INFER_MODE build tag={BUILD_TAG}")

    # Mutable state for dynamic model (hot-)reloading.
    model_pkg = None
    last_model_mtime = 0
    gateways_ordered: List[str] = []
    gateways_set: set = set()
    nan_fill = -110.0

    def load_model_dynamic() -> bool:
        """Reload the model package if the file's mtime changed.

        Returns True when a valid model is (already) loaded for this path,
        False when the file is missing/unreadable and no model was ever
        loaded. (Fixed: removed an unreachable trailing ``return True``
        after the try/except, which both already returned.)
        """
        nonlocal model_pkg, last_model_mtime, gateways_ordered, gateways_set, nan_fill
        try:
            current_mtime = os.path.getmtime(model_path)
            if current_mtime != last_model_mtime:
                log(f"🧠 Caricamento/Aggiornamento modello: {model_path}")
                model_pkg = joblib.load(model_path)
                last_model_mtime = current_mtime
                # Metadata, logged for debugging.
                gateways_ordered = model_pkg.get("gateways_order", [])
                gateways_set = set(gateways_ordered)
                nan_fill = float(model_pkg.get("nan_fill", -110.0))
                floors = list(model_pkg.get("xy_by_floor", {}).keys())
                log(f"✅ MODELLO PRONTO: {len(gateways_ordered)} GW allenati.")
                log(f"🏢 Piani mappati: {floors}")
                log(f"🧪 Valore Fill (NaN): {nan_fill}")
                if len(gateways_ordered) > 0:
                    log(f"📡 Primi 3 GW di riferimento: {gateways_ordered[:3]}")
            return True
        except Exception:
            # Best-effort: the model file may not exist yet at startup.
            if model_pkg is None:
                log(f"⚠️ In attesa di un modello valido in {model_path}...")
            return False

    # First load attempt (may legitimately fail until a model is produced).
    load_model_dynamic()

    rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))

    # --- Token and beacon-list handling ---
    token_cache = {"token": None, "expires_at": 0}

    def get_token() -> Optional[str]:
        """Return a cached OIDC access token, refreshing it when expired.

        Returns None when authentication fails; the failure is logged
        instead of being silently swallowed (was a bare ``except: pass``).
        """
        if time.time() < token_cache["expires_at"]:
            return token_cache["token"]
        try:
            with open("/config/secrets.yaml", "r") as f:
                sec = yaml.safe_load(f).get("oidc", {})
            payload = {
                "grant_type": "password",
                "client_id": api_c.get("client_id", "Fastapi"),
                "client_secret": sec.get("client_secret", ""),
                "username": sec.get("username", "core"),
                "password": sec.get("password", ""),
            }
            # NOTE(security): verify=False disables TLS certificate checks;
            # acceptable only for internal/self-signed endpoints — confirm.
            resp = requests.post(api_c["token_url"], data=payload, verify=False, timeout=10)
            if resp.status_code == 200:
                d = resp.json()
                token_cache["token"] = d["access_token"]
                # Refresh 30 s before the server-side expiry.
                token_cache["expires_at"] = time.time() + d.get("expires_in", 300) - 30
                return token_cache["token"]
        except Exception as e:
            log(f"Token Error: {e}")
        return None

    def fetch_beacons() -> List[str]:
        """Fetch the list of beacon MACs from the API; [] on any failure."""
        token = get_token()
        if not token:
            return []
        try:
            headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
            resp = requests.get(api_c["get_beacons_url"], headers=headers, verify=False, timeout=10)
            if resp.status_code != 200:
                return []
            return [it["mac"] for it in resp.json() if "mac" in it]
        except Exception as e:
            log(f"Beacon Fetch Error: {e}")
            return []

    def on_message(client, userdata, msg):
        """MQTT callback: ingest RSSI samples from trained gateways only.

        Topic's last segment is assumed to be the gateway MAC; payload is a
        JSON list of {"mac": ..., "rssi": ...} items — TODO confirm schema.
        """
        if not gateways_set:
            return
        gw = _norm_mac(msg.topic.split("/")[-1])
        if gw not in gateways_set:
            return
        try:
            items = json.loads(msg.payload.decode())
            for it in items:
                bm, rssi = _norm_mac(it.get("mac")), it.get("rssi")
                if bm and rssi is not None:
                    rolling.add(bm, gw, float(rssi))
        except Exception:
            # Malformed payloads are expected broker noise; drop silently.
            pass

    mqtt_client = mqtt.Client()
    mqtt_client.on_message = on_message
    try:
        mqtt_client.connect(mqtt_c["host"], mqtt_c["port"])
        mqtt_client.subscribe(mqtt_c["topic"])
        mqtt_client.loop_start()
    except Exception as e:
        log(f"MQTT Connect Error: {e}")

    last_predict, last_api_refresh = 0.0, 0.0
    beacons: List[str] = []

    while True:
        now = time.time()
        rolling.prune()

        # Re-check whether the model.joblib symlink/file has changed.
        load_model_dynamic()

        if now - last_api_refresh >= float(api_c.get("refresh_seconds", 30)):
            beacons = fetch_beacons()
            last_api_refresh = now

        if now - last_predict >= float(inf_c.get("refresh_seconds", 10.0)):
            rows, count_ok = [], 0
            if model_pkg:
                for bm in beacons:
                    bm_n = _norm_mac(bm)
                    X, n_found = rolling.aggregate_features(
                        bm_n, gateways_ordered, inf_c.get("aggregate", "median"), nan_fill
                    )
                    z, x, y = -1, -1.0, -1.0
                    if n_found >= int(inf_c.get("min_non_nan", 1)):
                        try:
                            z, x, y = _predict_xyz(model_pkg, X)
                            if z != -1:
                                count_ok += 1
                        except Exception:
                            # Prediction failure leaves the -1 sentinel row.
                            pass
                    rows.append(f"{bm_n};{int(z)};{int(round(x))};{int(round(y))}")
            try:
                out_p = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
                out_p.parent.mkdir(parents=True, exist_ok=True)
                # Write to a temp file then os.replace() for atomicity.
                with open(str(out_p) + ".tmp", "w") as f:
                    f.write("mac;z;x;y\n")
                    for r in rows:
                        f.write(r + "\n")
                os.replace(str(out_p) + ".tmp", out_p)
                log(f"CYCLE: {count_ok}/{len(rows)} localized")
            except Exception as e:
                log(f"File Error: {e}")
            last_predict = now

        time.sleep(0.5)