|
- # -*- coding: utf-8 -*-
- import json
- import os
- import time
- import joblib
- import numpy as np
- import requests
- import yaml
- import paho.mqtt.client as mqtt
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any, Dict, List, Tuple
- from .logger_utils import log_msg as log
-
# Build identifier logged at startup (see run_infer) so the version of a
# deployed container can be verified from its logs.
BUILD_TAG = "infer-mac-fix-v21"

# -------------------------
# UTILITIES
# -------------------------
-
- def _norm_mac(s: str) -> str:
- """
- Forza il formato standard xx:xx:xx:xx:xx:xx in MINUSCOLO per matchare il modello.
- Esempio MQTT: 'AC233FC1DD4E' -> 'ac:23:3f:c1:dd:4e'
- """
- s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").lower()
- if len(s) != 12: return s
- return ":".join([s[i:i+2] for i in range(0, 12, 2)])
-
- def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
- floor_clf = model_pkg.get("floor_clf")
- z_pred = floor_clf.predict(X)
- z = int(z_pred[0])
-
- xy_models = model_pkg.get("xy_by_floor", {})
- regressor = xy_models.get(z)
-
- if regressor is None:
- return z, -1.0, -1.0
-
- xy_pred = regressor.predict(X)
- x, y = xy_pred[0]
- return z, float(x), float(y)
-
@dataclass
class _Point:
    # One RSSI observation.
    t: float  # capture time (epoch seconds)
    v: float  # RSSI value


class RollingRSSI:
    """Sliding time-window store of RSSI samples, keyed beacon -> gateway."""

    def __init__(self, window_s: float):
        # window_s: seconds of history to retain when pruning.
        self.window_s = window_s
        # data[beacon_mac][gateway_mac] -> list of _Point samples
        self.data: Dict[str, Dict[str, List[_Point]]] = {}

    def add(self, bm: str, gm: str, rssi: float):
        """Record one sample (bm and gm arrive already normalized by on_message)."""
        gw_map = self.data.setdefault(bm, {})
        gw_map.setdefault(gm, []).append(_Point(time.time(), rssi))

    def prune(self):
        """Discard samples older than the window; drop buckets left empty."""
        cutoff = time.time() - self.window_s
        for bm in list(self.data):
            gw_map = self.data[bm]
            for gm in list(gw_map):
                fresh = [p for p in gw_map[gm] if p.t >= cutoff]
                if fresh:
                    gw_map[gm] = fresh
                else:
                    del gw_map[gm]
            if not gw_map:
                del self.data[bm]

    def aggregate_features(self, bm: str, gws: List[str], agg: str, fill: float):
        """Build the 1xN feature row for beacon *bm* over gateways *gws*.

        Each gateway contributes the median (agg == "median") or mean of its
        samples; gateways with no samples get *fill*. Returns (X, found_count)
        where X has shape (1, len(gws)).
        """
        per_gw = self.data.get(bm, {})
        row = []
        hits = 0
        for gm in gws:
            # gm is already lowercase-with-colons (taken from gateways_ordered).
            samples = per_gw.get(gm, [])
            if samples:
                values = [p.v for p in samples]
                row.append(np.median(values) if agg == "median" else np.mean(values))
                hits += 1
            else:
                row.append(np.nan)
        X = np.asarray(row).reshape(1, -1)
        return np.where(np.isnan(X), fill, X), hits
-
- # -------------------------
- # MAIN RUNNER
- # -------------------------
-
def run_infer(settings: Dict[str, Any]):
    """Main inference service loop.

    Consumes RSSI readings over MQTT, aggregates them per beacon over a
    sliding window, predicts (floor, x, y) with the loaded model package,
    and atomically writes the results to a CSV file. Runs forever.

    settings: nested configuration dict with "infer", "api" and "mqtt"
    sections (all optional; defaults applied below).
    """
    inf_c = settings.get("infer", {})
    api_c = settings.get("api", {})
    mqtt_c = settings.get("mqtt", {})
    model_path = inf_c.get("model_path", "/data/model/model.joblib")

    log(f"INFER_MODE build tag={BUILD_TAG}")

    # Model state, hot-reloaded by load_model_dynamic() when the file changes.
    model_pkg = None
    last_model_mtime = 0
    gateways_ordered = []  # Will be in aa:bb:cc... format
    gateways_set = set()
    nan_fill = -110.0

    def load_model_dynamic():
        # Reload the model package if its mtime changed; True on success /
        # no-op, False on any error.
        nonlocal model_pkg, last_model_mtime, gateways_ordered, gateways_set, nan_fill
        try:
            current_mtime = os.path.getmtime(model_path)
            if current_mtime != last_model_mtime:
                log(f"🧠 Aggiornamento modello: {model_path}")
                model_pkg = joblib.load(model_path)
                last_model_mtime = current_mtime

                # The model stores the gateways as they were loaded at training time
                gateways_ordered = [gw.lower() for gw in model_pkg.get("gateways_order", [])]
                gateways_set = set(gateways_ordered)
                nan_fill = float(model_pkg.get("nan_fill", -110.0))

                log(f"✅ MODELLO PRONTO: {len(gateways_ordered)} GW. Fill: {nan_fill}")
                return True
        # NOTE(review): bare except hides all failures (missing file, corrupt
        # model, joblib errors) — should at least log the exception.
        except: return False
        # Reached when the model file is unchanged.
        return True

    load_model_dynamic()
    rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))

    # Cache Beacons
    token_cache = {"token": None, "expires_at": 0}
    def get_token():
        # Return a cached OIDC access token, refreshing via the password grant
        # when expired; None on any failure (best effort).
        if time.time() < token_cache["expires_at"]: return token_cache["token"]
        try:
            with open("/config/secrets.yaml", "r") as f:
                sec = yaml.safe_load(f).get("oidc", {})
            payload = {
                "grant_type": "password", "client_id": api_c.get("client_id", "Fastapi"),
                "client_secret": sec.get("client_secret", ""),
                "username": sec.get("username", "core"), "password": sec.get("password", "")
            }
            # NOTE(review): verify=False disables TLS certificate checks —
            # presumably an internal endpoint with a self-signed cert; confirm.
            resp = requests.post(api_c["token_url"], data=payload, verify=False, timeout=10)
            if resp.status_code == 200:
                d = resp.json()
                token_cache["token"] = d["access_token"]
                # Refresh 30s before the reported expiry to avoid token races.
                token_cache["expires_at"] = time.time() + d.get("expires_in", 300) - 30
                return token_cache["token"]
        except: pass
        return None

    def fetch_beacons():
        # Fetch the list of beacon MACs to track from the REST API;
        # empty list on any failure.
        token = get_token()
        if not token: return []
        try:
            headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
            resp = requests.get(api_c["get_beacons_url"], headers=headers, verify=False, timeout=10)
            # Also normalize the fetched beacon MACs so comparisons match
            return [_norm_mac(it["mac"]) for it in resp.json() if "mac" in it] if resp.status_code == 200 else []
        except: return []

    def on_message(client, userdata, msg):
        # msg.topic e.g.: publish_out/ac233fc1dd4e
        raw_gw = msg.topic.split("/")[-1]
        gw_norm = _norm_mac(raw_gw)

        # Ignore readings from gateways the model was not trained with.
        if gw_norm not in gateways_set: return

        try:
            items = json.loads(msg.payload.decode())
            for it in items:
                raw_bm = it.get("mac")
                rssi = it.get("rssi")
                if raw_bm and rssi is not None:
                    bm_norm = _norm_mac(raw_bm)
                    rolling.add(bm_norm, gw_norm, float(rssi))
        # NOTE(review): malformed payloads are silently dropped.
        except: pass

    mqtt_client = mqtt.Client()
    mqtt_client.on_message = on_message
    try:
        mqtt_client.connect(mqtt_c["host"], mqtt_c["port"])
        mqtt_client.subscribe(mqtt_c["topic"])
        mqtt_client.loop_start()  # network loop runs on a background thread
    except Exception as e:
        log(f"MQTT Error: {e}")

    last_predict, last_api_refresh = 0.0, 0.0
    beacons_to_track = []

    while True:
        now = time.time()
        rolling.prune()
        load_model_dynamic()  # hot-reload the model if the file changed

        # Periodically refresh the beacon list from the API.
        if now - last_api_refresh >= float(api_c.get("refresh_seconds", 30)):
            beacons_to_track = fetch_beacons()
            last_api_refresh = now

        if now - last_predict >= float(inf_c.get("refresh_seconds", 10.0)):
            rows, count_ok = [], 0

            if model_pkg and beacons_to_track:
                for bm_n in beacons_to_track:
                    X, n_found = rolling.aggregate_features(bm_n, gateways_ordered, inf_c.get("aggregate", "median"), nan_fill)

                    # Defaults mean "not localized" in the output CSV.
                    z, x, y = -1, -1.0, -1.0
                    if n_found >= int(inf_c.get("min_non_nan", 1)):
                        try:
                            z, x, y = _predict_xyz(model_pkg, X)
                            if z != -1: count_ok += 1
                        # NOTE(review): prediction errors are silently ignored;
                        # the beacon is emitted as not-localized.
                        except: pass

                    rows.append(f"{bm_n};{int(z)};{int(round(x))};{int(round(y))}")

            try:
                # Atomic write: tmp file + os.replace so readers never see a
                # partially written CSV.
                out_p = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
                out_p.parent.mkdir(parents=True, exist_ok=True)
                with open(str(out_p) + ".tmp", "w") as f:
                    f.write("mac;z;x;y\n")
                    for r in rows: f.write(r + "\n")
                os.replace(str(out_p) + ".tmp", out_p)
                log(f"CYCLE: {count_ok}/{len(beacons_to_track)} localized (Input GW match: {len(gateways_set)})")
            except Exception as e: log(f"File Error: {e}")

            last_predict = now
        time.sleep(0.5)
|