Non puoi selezionare più di 25 argomenti Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.
 
 
 
 

220 righe
8.0 KiB

  1. # -*- coding: utf-8 -*-
  2. import json
  3. import os
  4. import time
  5. import joblib
  6. import numpy as np
  7. import requests
  8. import yaml
  9. import paho.mqtt.client as mqtt
  10. from dataclasses import dataclass
  11. from pathlib import Path
  12. from typing import Any, Dict, List, Tuple
  13. from .logger_utils import log_msg as log
# Build identifier; run_infer() writes it to the log once at startup.
BUILD_TAG = "infer-mac-fix-v21"
  15. # -------------------------
  16. # UTILITIES
  17. # -------------------------
  18. def _norm_mac(s: str) -> str:
  19. """
  20. Forza il formato standard xx:xx:xx:xx:xx:xx in MINUSCOLO per matchare il modello.
  21. Esempio MQTT: 'AC233FC1DD4E' -> 'ac:23:3f:c1:dd:4e'
  22. """
  23. s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").lower()
  24. if len(s) != 12: return s
  25. return ":".join([s[i:i+2] for i in range(0, 12, 2)])
  26. def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
  27. floor_clf = model_pkg.get("floor_clf")
  28. z_pred = floor_clf.predict(X)
  29. z = int(z_pred[0])
  30. xy_models = model_pkg.get("xy_by_floor", {})
  31. regressor = xy_models.get(z)
  32. if regressor is None:
  33. return z, -1.0, -1.0
  34. xy_pred = regressor.predict(X)
  35. x, y = xy_pred[0]
  36. return z, float(x), float(y)
  37. @dataclass
  38. class _Point:
  39. t: float
  40. v: float
  41. class RollingRSSI:
  42. def __init__(self, window_s: float):
  43. self.window_s = window_s
  44. self.data: Dict[str, Dict[str, List[_Point]]] = {}
  45. def add(self, bm: str, gm: str, rssi: float):
  46. # bm e gm arrivano già normalizzati da on_message
  47. self.data.setdefault(bm, {}).setdefault(gm, []).append(_Point(time.time(), rssi))
  48. def prune(self):
  49. cutoff = time.time() - self.window_s
  50. for bm in list(self.data.keys()):
  51. for gm in list(self.data[bm].keys()):
  52. self.data[bm][gm] = [p for p in self.data[bm][gm] if p.t >= cutoff]
  53. if not self.data[bm][gm]: self.data[bm].pop(gm)
  54. if not self.data[bm]: self.data.pop(bm)
  55. def aggregate_features(self, bm: str, gws: List[str], agg: str, fill: float):
  56. per_gw = self.data.get(bm, {})
  57. feats = []
  58. found_count = 0
  59. for gm in gws:
  60. # gm è già minuscolo con : perché preso da gateways_ordered (dal modello)
  61. vals = [p.v for p in per_gw.get(gm, [])]
  62. if vals:
  63. val = np.median(vals) if agg == "median" else np.mean(vals)
  64. feats.append(val)
  65. found_count += 1
  66. else:
  67. feats.append(np.nan)
  68. X = np.array(feats).reshape(1, -1)
  69. return np.where(np.isnan(X), fill, X), found_count
  70. # -------------------------
  71. # MAIN RUNNER
  72. # -------------------------
  73. def run_infer(settings: Dict[str, Any]):
  74. inf_c = settings.get("infer", {})
  75. api_c = settings.get("api", {})
  76. mqtt_c = settings.get("mqtt", {})
  77. model_path = inf_c.get("model_path", "/data/model/model.joblib")
  78. log(f"INFER_MODE build tag={BUILD_TAG}")
  79. model_pkg = None
  80. last_model_mtime = 0
  81. gateways_ordered = [] # Saranno in formato aa:bb:cc...
  82. gateways_set = set()
  83. nan_fill = -110.0
  84. def load_model_dynamic():
  85. nonlocal model_pkg, last_model_mtime, gateways_ordered, gateways_set, nan_fill
  86. try:
  87. current_mtime = os.path.getmtime(model_path)
  88. if current_mtime != last_model_mtime:
  89. log(f"🧠 Aggiornamento modello: {model_path}")
  90. model_pkg = joblib.load(model_path)
  91. last_model_mtime = current_mtime
  92. # Il modello ha i gateway salvati come caricati nell'addestramento
  93. gateways_ordered = [gw.lower() for gw in model_pkg.get("gateways_order", [])]
  94. gateways_set = set(gateways_ordered)
  95. nan_fill = float(model_pkg.get("nan_fill", -110.0))
  96. log(f"✅ MODELLO PRONTO: {len(gateways_ordered)} GW. Fill: {nan_fill}")
  97. return True
  98. except: return False
  99. return True
  100. load_model_dynamic()
  101. rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))
  102. # Cache Beacons
  103. token_cache = {"token": None, "expires_at": 0}
  104. def get_token():
  105. if time.time() < token_cache["expires_at"]: return token_cache["token"]
  106. try:
  107. with open("/config/secrets.yaml", "r") as f:
  108. sec = yaml.safe_load(f).get("oidc", {})
  109. payload = {
  110. "grant_type": "password", "client_id": api_c.get("client_id", "Fastapi"),
  111. "client_secret": sec.get("client_secret", ""),
  112. "username": sec.get("username", "core"), "password": sec.get("password", "")
  113. }
  114. resp = requests.post(api_c["token_url"], data=payload, verify=False, timeout=10)
  115. if resp.status_code == 200:
  116. d = resp.json()
  117. token_cache["token"] = d["access_token"]
  118. token_cache["expires_at"] = time.time() + d.get("expires_in", 300) - 30
  119. return token_cache["token"]
  120. except: pass
  121. return None
  122. def fetch_beacons():
  123. token = get_token()
  124. if not token: return []
  125. try:
  126. headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
  127. resp = requests.get(api_c["get_beacons_url"], headers=headers, verify=False, timeout=10)
  128. # Normalizziamo anche i beacon cercati per il confronto
  129. return [_norm_mac(it["mac"]) for it in resp.json() if "mac" in it] if resp.status_code == 200 else []
  130. except: return []
  131. def on_message(client, userdata, msg):
  132. # msg.topic es: publish_out/ac233fc1dd4e
  133. raw_gw = msg.topic.split("/")[-1]
  134. gw_norm = _norm_mac(raw_gw)
  135. if gw_norm not in gateways_set: return
  136. try:
  137. items = json.loads(msg.payload.decode())
  138. for it in items:
  139. raw_bm = it.get("mac")
  140. rssi = it.get("rssi")
  141. if raw_bm and rssi is not None:
  142. bm_norm = _norm_mac(raw_bm)
  143. rolling.add(bm_norm, gw_norm, float(rssi))
  144. except: pass
  145. mqtt_client = mqtt.Client()
  146. mqtt_client.on_message = on_message
  147. try:
  148. mqtt_client.connect(mqtt_c["host"], mqtt_c["port"])
  149. mqtt_client.subscribe(mqtt_c["topic"])
  150. mqtt_client.loop_start()
  151. except Exception as e:
  152. log(f"MQTT Error: {e}")
  153. last_predict, last_api_refresh = 0.0, 0.0
  154. beacons_to_track = []
  155. while True:
  156. now = time.time()
  157. rolling.prune()
  158. load_model_dynamic()
  159. if now - last_api_refresh >= float(api_c.get("refresh_seconds", 30)):
  160. beacons_to_track = fetch_beacons()
  161. last_api_refresh = now
  162. if now - last_predict >= float(inf_c.get("refresh_seconds", 10.0)):
  163. rows, count_ok = [], 0
  164. if model_pkg and beacons_to_track:
  165. for bm_n in beacons_to_track:
  166. X, n_found = rolling.aggregate_features(bm_n, gateways_ordered, inf_c.get("aggregate", "median"), nan_fill)
  167. z, x, y = -1, -1.0, -1.0
  168. if n_found >= int(inf_c.get("min_non_nan", 1)):
  169. try:
  170. z, x, y = _predict_xyz(model_pkg, X)
  171. if z != -1: count_ok += 1
  172. except: pass
  173. rows.append(f"{bm_n};{int(z)};{int(round(x))};{int(round(y))}")
  174. try:
  175. out_p = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
  176. out_p.parent.mkdir(parents=True, exist_ok=True)
  177. with open(str(out_p) + ".tmp", "w") as f:
  178. f.write("mac;z;x;y\n")
  179. for r in rows: f.write(r + "\n")
  180. os.replace(str(out_p) + ".tmp", out_p)
  181. log(f"CYCLE: {count_ok}/{len(beacons_to_track)} localized (Input GW match: {len(gateways_set)})")
  182. except Exception as e: log(f"File Error: {e}")
  183. last_predict = now
  184. time.sleep(0.5)