|
|
|
@@ -9,28 +9,26 @@ import yaml |
|
|
|
import paho.mqtt.client as mqtt |
|
|
|
from dataclasses import dataclass |
|
|
|
from pathlib import Path |
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple |
|
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
from .logger_utils import log_msg as log |
|
|
|
|
|
|
|
# Importiamo la funzione per caricare i gateway per mantenere l'ordine delle feature |
|
|
|
from .csv_config import load_gateway_features_csv |
|
|
|
|
|
|
|
# Build tag surfaced in the startup log so deployed instances can be identified.
BUILD_TAG = "infer-mac-fix-v21"
|
|
|
|
|
|
|
# ------------------------- |
|
|
|
# UTILITIES |
|
|
|
# ------------------------- |
|
|
|
|
|
|
|
def _norm_mac(s: str) -> str: |
|
|
|
s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").upper() |
|
|
|
""" |
|
|
|
Forza il formato standard xx:xx:xx:xx:xx:xx in MINUSCOLO per matchare il modello. |
|
|
|
Esempio MQTT: 'AC233FC1DD4E' -> 'ac:23:3f:c1:dd:4e' |
|
|
|
""" |
|
|
|
s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").lower() |
|
|
|
if len(s) != 12: return s |
|
|
|
return ":".join([s[i:i+2] for i in range(0, 12, 2)]) |
|
|
|
|
|
|
|
def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]: |
|
|
|
floor_clf = model_pkg.get("floor_clf") |
|
|
|
if floor_clf is None: |
|
|
|
raise ValueError("Il pacchetto modello non contiene 'floor_clf'") |
|
|
|
|
|
|
|
z_pred = floor_clf.predict(X) |
|
|
|
z = int(z_pred[0]) |
|
|
|
|
|
|
|
@@ -42,7 +40,6 @@ def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, |
|
|
|
|
|
|
|
xy_pred = regressor.predict(X) |
|
|
|
x, y = xy_pred[0] |
|
|
|
|
|
|
|
return z, float(x), float(y) |
|
|
|
|
|
|
|
@dataclass |
|
|
|
@@ -56,6 +53,7 @@ class RollingRSSI: |
|
|
|
self.data: Dict[str, Dict[str, List[_Point]]] = {} |
|
|
|
|
|
|
|
def add(self, bm: str, gm: str, rssi: float) -> None:
    """Record one RSSI reading for beacon ``bm`` as seen by gateway ``gm``.

    Both MACs arrive already normalized by on_message, so they can be used
    directly as dict keys without re-normalizing here.
    """
    self.data.setdefault(bm, {}).setdefault(gm, []).append(_Point(time.time(), rssi))
|
|
|
|
|
|
|
def prune(self): |
|
|
|
@@ -71,6 +69,7 @@ class RollingRSSI: |
|
|
|
feats = [] |
|
|
|
found_count = 0 |
|
|
|
for gm in gws: |
|
|
|
# gm è già minuscolo con : perché preso da gateways_ordered (dal modello) |
|
|
|
vals = [p.v for p in per_gw.get(gm, [])] |
|
|
|
if vals: |
|
|
|
val = np.median(vals) if agg == "median" else np.mean(vals) |
|
|
|
@@ -93,10 +92,9 @@ def run_infer(settings: Dict[str, Any]): |
|
|
|
|
|
|
|
log(f"INFER_MODE build tag={BUILD_TAG}") |
|
|
|
|
|
|
|
# Variabili di stato per il caricamento dinamico |
|
|
|
model_pkg = None |
|
|
|
last_model_mtime = 0 |
|
|
|
gateways_ordered = [] |
|
|
|
gateways_ordered = [] # Saranno in formato aa:bb:cc... |
|
|
|
gateways_set = set() |
|
|
|
nan_fill = -110.0 |
|
|
|
|
|
|
|
@@ -105,34 +103,24 @@ def run_infer(settings: Dict[str, Any]): |
|
|
|
try: |
|
|
|
current_mtime = os.path.getmtime(model_path) |
|
|
|
if current_mtime != last_model_mtime: |
|
|
|
log(f"🧠 Caricamento/Aggiornamento modello: {model_path}") |
|
|
|
log(f"🧠 Aggiornamento modello: {model_path}") |
|
|
|
model_pkg = joblib.load(model_path) |
|
|
|
last_model_mtime = current_mtime |
|
|
|
|
|
|
|
# METADATI PER DEBUG |
|
|
|
gateways_ordered = model_pkg.get("gateways_order", []) |
|
|
|
# Il modello ha i gateway salvati come caricati nell'addestramento |
|
|
|
gateways_ordered = [gw.lower() for gw in model_pkg.get("gateways_order", [])] |
|
|
|
gateways_set = set(gateways_ordered) |
|
|
|
nan_fill = float(model_pkg.get("nan_fill", -110.0)) |
|
|
|
floors = list(model_pkg.get("xy_by_floor", {}).keys()) |
|
|
|
|
|
|
|
log(f"✅ MODELLO PRONTO: {len(gateways_ordered)} GW allenati.") |
|
|
|
log(f"🏢 Piani mappati: {floors}") |
|
|
|
log(f"🧪 Valore Fill (NaN): {nan_fill}") |
|
|
|
if len(gateways_ordered) > 0: |
|
|
|
log(f"📡 Primi 3 GW di riferimento: {gateways_ordered[:3]}") |
|
|
|
log(f"✅ MODELLO PRONTO: {len(gateways_ordered)} GW. Fill: {nan_fill}") |
|
|
|
return True |
|
|
|
except Exception as e: |
|
|
|
if model_pkg is None: |
|
|
|
log(f"⚠️ In attesa di un modello valido in {model_path}...") |
|
|
|
return False |
|
|
|
except: return False |
|
|
|
return True |
|
|
|
|
|
|
|
# Primo caricamento |
|
|
|
load_model_dynamic() |
|
|
|
|
|
|
|
rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0))) |
|
|
|
|
|
|
|
# --- Gestione Token e Beacons --- |
|
|
|
# Cache Beacons |
|
|
|
token_cache = {"token": None, "expires_at": 0} |
|
|
|
def get_token(): |
|
|
|
if time.time() < token_cache["expires_at"]: return token_cache["token"] |
|
|
|
@@ -159,18 +147,25 @@ def run_infer(settings: Dict[str, Any]): |
|
|
|
try: |
|
|
|
headers = {"Authorization": f"Bearer {token}", "accept": "application/json"} |
|
|
|
resp = requests.get(api_c["get_beacons_url"], headers=headers, verify=False, timeout=10) |
|
|
|
return [it["mac"] for it in resp.json() if "mac" in it] if resp.status_code == 200 else [] |
|
|
|
# Normalizziamo anche i beacon cercati per il confronto |
|
|
|
return [_norm_mac(it["mac"]) for it in resp.json() if "mac" in it] if resp.status_code == 200 else [] |
|
|
|
except: return [] |
|
|
|
|
|
|
|
def on_message(client, userdata, msg):
    """MQTT callback: feed per-gateway RSSI readings into the rolling window.

    ``msg.topic`` looks like ``publish_out/ac233fc1dd4e``; the last path
    segment is the gateway MAC. The payload is a JSON list of items with
    ``mac`` (beacon) and ``rssi`` keys. Readings from gateways the model
    was not trained on are dropped immediately.
    """
    # No model loaded yet -> no gateway whitelist to match against.
    if not gateways_set:
        return

    raw_gw = msg.topic.split("/")[-1]
    gw_norm = _norm_mac(raw_gw)
    if gw_norm not in gateways_set:
        return

    try:
        items = json.loads(msg.payload.decode())
        for it in items:
            raw_bm = it.get("mac")
            rssi = it.get("rssi")
            if raw_bm and rssi is not None:
                rolling.add(_norm_mac(raw_bm), gw_norm, float(rssi))
    except Exception:
        # Best-effort ingestion: a malformed payload must not kill the
        # MQTT loop, so it is skipped silently.
        pass
|
|
|
|
|
|
|
mqtt_client = mqtt.Client() |
|
|
|
@@ -180,28 +175,25 @@ def run_infer(settings: Dict[str, Any]): |
|
|
|
mqtt_client.subscribe(mqtt_c["topic"]) |
|
|
|
mqtt_client.loop_start() |
|
|
|
except Exception as e: |
|
|
|
log(f"MQTT Connect Error: {e}") |
|
|
|
log(f"MQTT Error: {e}") |
|
|
|
|
|
|
|
last_predict, last_api_refresh = 0.0, 0.0 |
|
|
|
beacons = [] |
|
|
|
beacons_to_track = [] |
|
|
|
|
|
|
|
while True: |
|
|
|
now = time.time() |
|
|
|
rolling.prune() |
|
|
|
|
|
|
|
# Verifica se il link simbolico model.joblib è cambiato |
|
|
|
load_model_dynamic() |
|
|
|
|
|
|
|
if now - last_api_refresh >= float(api_c.get("refresh_seconds", 30)): |
|
|
|
beacons = fetch_beacons() |
|
|
|
beacons_to_track = fetch_beacons() |
|
|
|
last_api_refresh = now |
|
|
|
|
|
|
|
if now - last_predict >= float(inf_c.get("refresh_seconds", 10.0)): |
|
|
|
rows, count_ok = [], 0 |
|
|
|
|
|
|
|
if model_pkg: |
|
|
|
for bm in beacons: |
|
|
|
bm_n = _norm_mac(bm) |
|
|
|
if model_pkg and beacons_to_track: |
|
|
|
for bm_n in beacons_to_track: |
|
|
|
X, n_found = rolling.aggregate_features(bm_n, gateways_ordered, inf_c.get("aggregate", "median"), nan_fill) |
|
|
|
|
|
|
|
z, x, y = -1, -1.0, -1.0 |
|
|
|
@@ -209,8 +201,7 @@ def run_infer(settings: Dict[str, Any]): |
|
|
|
try: |
|
|
|
z, x, y = _predict_xyz(model_pkg, X) |
|
|
|
if z != -1: count_ok += 1 |
|
|
|
except Exception as e: |
|
|
|
pass |
|
|
|
except: pass |
|
|
|
|
|
|
|
rows.append(f"{bm_n};{int(z)};{int(round(x))};{int(round(y))}") |
|
|
|
|
|
|
|
@@ -221,7 +212,7 @@ def run_infer(settings: Dict[str, Any]): |
|
|
|
f.write("mac;z;x;y\n") |
|
|
|
for r in rows: f.write(r + "\n") |
|
|
|
os.replace(str(out_p) + ".tmp", out_p) |
|
|
|
log(f"CYCLE: {count_ok}/{len(rows)} localized") |
|
|
|
log(f"CYCLE: {count_ok}/{len(beacons_to_track)} localized (Input GW match: {len(gateways_set)})") |
|
|
|
except Exception as e: log(f"File Error: {e}") |
|
|
|
|
|
|
|
last_predict = now |
|
|
|
|