Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.
 
 
 
 

229 lignes
8.2 KiB

  1. # -*- coding: utf-8 -*-
  2. import json
  3. import os
  4. import time
  5. import joblib
  6. import numpy as np
  7. import requests
  8. import yaml
  9. import paho.mqtt.client as mqtt
  10. from dataclasses import dataclass
  11. from pathlib import Path
  12. from typing import Any, Callable, Dict, List, Optional, Tuple
  13. from .logger_utils import log_msg as log
  14. # Importiamo la funzione per caricare i gateway per mantenere l'ordine delle feature
  15. from .csv_config import load_gateway_features_csv
  16. BUILD_TAG = "infer-debug-v20-autoreload"
  17. # -------------------------
  18. # UTILITIES
  19. # -------------------------
  20. def _norm_mac(s: str) -> str:
  21. s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").upper()
  22. if len(s) != 12: return s
  23. return ":".join([s[i:i+2] for i in range(0, 12, 2)])
  24. def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
  25. floor_clf = model_pkg.get("floor_clf")
  26. if floor_clf is None:
  27. raise ValueError("Il pacchetto modello non contiene 'floor_clf'")
  28. z_pred = floor_clf.predict(X)
  29. z = int(z_pred[0])
  30. xy_models = model_pkg.get("xy_by_floor", {})
  31. regressor = xy_models.get(z)
  32. if regressor is None:
  33. return z, -1.0, -1.0
  34. xy_pred = regressor.predict(X)
  35. x, y = xy_pred[0]
  36. return z, float(x), float(y)
  37. @dataclass
  38. class _Point:
  39. t: float
  40. v: float
  41. class RollingRSSI:
  42. def __init__(self, window_s: float):
  43. self.window_s = window_s
  44. self.data: Dict[str, Dict[str, List[_Point]]] = {}
  45. def add(self, bm: str, gm: str, rssi: float):
  46. self.data.setdefault(bm, {}).setdefault(gm, []).append(_Point(time.time(), rssi))
  47. def prune(self):
  48. cutoff = time.time() - self.window_s
  49. for bm in list(self.data.keys()):
  50. for gm in list(self.data[bm].keys()):
  51. self.data[bm][gm] = [p for p in self.data[bm][gm] if p.t >= cutoff]
  52. if not self.data[bm][gm]: self.data[bm].pop(gm)
  53. if not self.data[bm]: self.data.pop(bm)
  54. def aggregate_features(self, bm: str, gws: List[str], agg: str, fill: float):
  55. per_gw = self.data.get(bm, {})
  56. feats = []
  57. found_count = 0
  58. for gm in gws:
  59. vals = [p.v for p in per_gw.get(gm, [])]
  60. if vals:
  61. val = np.median(vals) if agg == "median" else np.mean(vals)
  62. feats.append(val)
  63. found_count += 1
  64. else:
  65. feats.append(np.nan)
  66. X = np.array(feats).reshape(1, -1)
  67. return np.where(np.isnan(X), fill, X), found_count
  68. # -------------------------
  69. # MAIN RUNNER
  70. # -------------------------
  71. def run_infer(settings: Dict[str, Any]):
  72. inf_c = settings.get("infer", {})
  73. api_c = settings.get("api", {})
  74. mqtt_c = settings.get("mqtt", {})
  75. model_path = inf_c.get("model_path", "/data/model/model.joblib")
  76. log(f"INFER_MODE build tag={BUILD_TAG}")
  77. # Variabili di stato per il caricamento dinamico
  78. model_pkg = None
  79. last_model_mtime = 0
  80. gateways_ordered = []
  81. gateways_set = set()
  82. nan_fill = -110.0
  83. def load_model_dynamic():
  84. nonlocal model_pkg, last_model_mtime, gateways_ordered, gateways_set, nan_fill
  85. try:
  86. current_mtime = os.path.getmtime(model_path)
  87. if current_mtime != last_model_mtime:
  88. log(f"🧠 Caricamento/Aggiornamento modello: {model_path}")
  89. model_pkg = joblib.load(model_path)
  90. last_model_mtime = current_mtime
  91. # METADATI PER DEBUG
  92. gateways_ordered = model_pkg.get("gateways_order", [])
  93. gateways_set = set(gateways_ordered)
  94. nan_fill = float(model_pkg.get("nan_fill", -110.0))
  95. floors = list(model_pkg.get("xy_by_floor", {}).keys())
  96. log(f"✅ MODELLO PRONTO: {len(gateways_ordered)} GW allenati.")
  97. log(f"🏢 Piani mappati: {floors}")
  98. log(f"🧪 Valore Fill (NaN): {nan_fill}")
  99. if len(gateways_ordered) > 0:
  100. log(f"📡 Primi 3 GW di riferimento: {gateways_ordered[:3]}")
  101. return True
  102. except Exception as e:
  103. if model_pkg is None:
  104. log(f"⚠️ In attesa di un modello valido in {model_path}...")
  105. return False
  106. return True
  107. # Primo caricamento
  108. load_model_dynamic()
  109. rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))
  110. # --- Gestione Token e Beacons ---
  111. token_cache = {"token": None, "expires_at": 0}
  112. def get_token():
  113. if time.time() < token_cache["expires_at"]: return token_cache["token"]
  114. try:
  115. with open("/config/secrets.yaml", "r") as f:
  116. sec = yaml.safe_load(f).get("oidc", {})
  117. payload = {
  118. "grant_type": "password", "client_id": api_c.get("client_id", "Fastapi"),
  119. "client_secret": sec.get("client_secret", ""),
  120. "username": sec.get("username", "core"), "password": sec.get("password", "")
  121. }
  122. resp = requests.post(api_c["token_url"], data=payload, verify=False, timeout=10)
  123. if resp.status_code == 200:
  124. d = resp.json()
  125. token_cache["token"] = d["access_token"]
  126. token_cache["expires_at"] = time.time() + d.get("expires_in", 300) - 30
  127. return token_cache["token"]
  128. except: pass
  129. return None
  130. def fetch_beacons():
  131. token = get_token()
  132. if not token: return []
  133. try:
  134. headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
  135. resp = requests.get(api_c["get_beacons_url"], headers=headers, verify=False, timeout=10)
  136. return [it["mac"] for it in resp.json() if "mac" in it] if resp.status_code == 200 else []
  137. except: return []
  138. def on_message(client, userdata, msg):
  139. if not gateways_set: return
  140. gw = _norm_mac(msg.topic.split("/")[-1])
  141. if gw not in gateways_set: return
  142. try:
  143. items = json.loads(msg.payload.decode())
  144. for it in items:
  145. bm, rssi = _norm_mac(it.get("mac")), it.get("rssi")
  146. if bm and rssi is not None: rolling.add(bm, gw, float(rssi))
  147. except: pass
  148. mqtt_client = mqtt.Client()
  149. mqtt_client.on_message = on_message
  150. try:
  151. mqtt_client.connect(mqtt_c["host"], mqtt_c["port"])
  152. mqtt_client.subscribe(mqtt_c["topic"])
  153. mqtt_client.loop_start()
  154. except Exception as e:
  155. log(f"MQTT Connect Error: {e}")
  156. last_predict, last_api_refresh = 0.0, 0.0
  157. beacons = []
  158. while True:
  159. now = time.time()
  160. rolling.prune()
  161. # Verifica se il link simbolico model.joblib è cambiato
  162. load_model_dynamic()
  163. if now - last_api_refresh >= float(api_c.get("refresh_seconds", 30)):
  164. beacons = fetch_beacons()
  165. last_api_refresh = now
  166. if now - last_predict >= float(inf_c.get("refresh_seconds", 10.0)):
  167. rows, count_ok = [], 0
  168. if model_pkg:
  169. for bm in beacons:
  170. bm_n = _norm_mac(bm)
  171. X, n_found = rolling.aggregate_features(bm_n, gateways_ordered, inf_c.get("aggregate", "median"), nan_fill)
  172. z, x, y = -1, -1.0, -1.0
  173. if n_found >= int(inf_c.get("min_non_nan", 1)):
  174. try:
  175. z, x, y = _predict_xyz(model_pkg, X)
  176. if z != -1: count_ok += 1
  177. except Exception as e:
  178. pass
  179. rows.append(f"{bm_n};{int(z)};{int(round(x))};{int(round(y))}")
  180. try:
  181. out_p = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
  182. out_p.parent.mkdir(parents=True, exist_ok=True)
  183. with open(str(out_p) + ".tmp", "w") as f:
  184. f.write("mac;z;x;y\n")
  185. for r in rows: f.write(r + "\n")
  186. os.replace(str(out_p) + ".tmp", out_p)
  187. log(f"CYCLE: {count_ok}/{len(rows)} localized")
  188. except Exception as e: log(f"File Error: {e}")
  189. last_predict = now
  190. time.sleep(0.5)