Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.
 
 
 
 

203 строки
7.4 KiB

# -*- coding: utf-8 -*-
import json
import os
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import joblib
import numpy as np
import paho.mqtt.client as mqtt
import requests
import yaml

from .logger_utils import log_msg as log
# Import the gateway loader so that the feature order is preserved.
from .csv_config import load_gateway_features_csv
  16. BUILD_TAG = "infer-debug-v19-hierarchical-final"
  17. # -------------------------
  18. # UTILITIES
  19. # -------------------------
  20. def _norm_mac(s: str) -> str:
  21. s = (s or "").strip().replace("-", "").replace(":", "").replace(".", "").upper()
  22. if len(s) != 12: return s
  23. return ":".join([s[i:i+2] for i in range(0, 12, 2)])
  24. def _predict_xyz(model_pkg: Dict[str, Any], X: np.ndarray) -> Tuple[int, float, float]:
  25. """
  26. Logica di predizione basata su train_mode.py:
  27. 1. Predice Z (piano) tramite KNeighborsClassifier
  28. 2. Predice X, Y tramite KNeighborsRegressor specifico per quel piano
  29. """
  30. # 1. Predizione del Piano (Z)
  31. floor_clf = model_pkg.get("floor_clf")
  32. if floor_clf is None:
  33. raise ValueError("Il pacchetto modello non contiene 'floor_clf'")
  34. z_pred = floor_clf.predict(X)
  35. z = int(z_pred[0])
  36. # 2. Predizione X, Y (Coordinate)
  37. xy_models = model_pkg.get("xy_by_floor", {})
  38. regressor = xy_models.get(z)
  39. if regressor is None:
  40. # Se il piano predetto non ha un regressore XY, restituiamo solo il piano
  41. return z, -1.0, -1.0
  42. xy_pred = regressor.predict(X)
  43. x, y = xy_pred[0] # L'output del regressore multi-output è [x, y]
  44. return z, float(x), float(y)
  45. @dataclass
  46. class _Point:
  47. t: float
  48. v: float
  49. class RollingRSSI:
  50. def __init__(self, window_s: float):
  51. self.window_s = window_s
  52. self.data: Dict[str, Dict[str, List[_Point]]] = {}
  53. def add(self, bm: str, gm: str, rssi: float):
  54. self.data.setdefault(bm, {}).setdefault(gm, []).append(_Point(time.time(), rssi))
  55. def prune(self):
  56. cutoff = time.time() - self.window_s
  57. for bm in list(self.data.keys()):
  58. for gm in list(self.data[bm].keys()):
  59. self.data[bm][gm] = [p for p in self.data[bm][gm] if p.t >= cutoff]
  60. if not self.data[bm][gm]: self.data[bm].pop(gm)
  61. if not self.data[bm]: self.data.pop(bm)
  62. def aggregate_features(self, bm: str, gws: List[str], agg: str, fill: float):
  63. per_gw = self.data.get(bm, {})
  64. feats = []
  65. found_count = 0
  66. for gm in gws:
  67. vals = [p.v for p in per_gw.get(gm, [])]
  68. if vals:
  69. val = np.median(vals) if agg == "median" else np.mean(vals)
  70. feats.append(val)
  71. found_count += 1
  72. else:
  73. feats.append(np.nan)
  74. X = np.array(feats).reshape(1, -1)
  75. # Sostituisce i NaN con il valore di riempimento definito nel training
  76. return np.where(np.isnan(X), fill, X), found_count
  77. # -------------------------
  78. # MAIN RUNNER
  79. # -------------------------
  80. def run_infer(settings: Dict[str, Any]):
  81. inf_c = settings.get("infer", {})
  82. api_c = settings.get("api", {})
  83. mqtt_c = settings.get("mqtt", {})
  84. log(f"INFER_MODE build tag={BUILD_TAG}")
  85. # Caricamento Modello e Configurazione Training
  86. try:
  87. model_pkg = joblib.load(inf_c.get("model_path", "/data/model/model.joblib"))
  88. # Recuperiamo l'ordine dei gateway e il nan_fill direttamente dal modello salvato
  89. gateways_ordered = model_pkg.get("gateways_order")
  90. nan_fill = float(model_pkg.get("nan_fill", -110.0))
  91. log(f"Model loaded. Features: {len(gateways_ordered)}, Fill: {nan_fill}")
  92. except Exception as e:
  93. log(f"CRITICAL: Failed to load model: {e}")
  94. return
  95. gateways_set = set(gateways_ordered)
  96. rolling = RollingRSSI(float(inf_c.get("window_seconds", 5.0)))
  97. # --- Gestione Token e Beacons (identica a prima) ---
  98. token_cache = {"token": None, "expires_at": 0}
  99. def get_token():
  100. if time.time() < token_cache["expires_at"]: return token_cache["token"]
  101. try:
  102. # Recupero credenziali da secrets.yaml
  103. with open("/config/secrets.yaml", "r") as f:
  104. sec = yaml.safe_load(f).get("oidc", {})
  105. payload = {
  106. "grant_type": "password", "client_id": api_c.get("client_id", "Fastapi"),
  107. "client_secret": sec.get("client_secret", ""),
  108. "username": sec.get("username", "core"), "password": sec.get("password", "")
  109. }
  110. resp = requests.post(api_c["token_url"], data=payload, verify=False, timeout=10)
  111. if resp.status_code == 200:
  112. d = resp.json()
  113. token_cache["token"] = d["access_token"]
  114. token_cache["expires_at"] = time.time() + d.get("expires_in", 300) - 30
  115. return token_cache["token"]
  116. except: pass
  117. return None
  118. def fetch_beacons():
  119. token = get_token()
  120. if not token: return []
  121. try:
  122. headers = {"Authorization": f"Bearer {token}", "accept": "application/json"}
  123. resp = requests.get(api_c["get_beacons_url"], headers=headers, verify=False, timeout=10)
  124. return [it["mac"] for it in resp.json() if "mac" in it] if resp.status_code == 200 else []
  125. except: return []
  126. def on_message(client, userdata, msg):
  127. gw = _norm_mac(msg.topic.split("/")[-1])
  128. if gw not in gateways_set: return
  129. try:
  130. items = json.loads(msg.payload.decode())
  131. for it in items:
  132. bm, rssi = _norm_mac(it.get("mac")), it.get("rssi")
  133. if bm and rssi is not None: rolling.add(bm, gw, float(rssi))
  134. except: pass
  135. mqtt_client = mqtt.Client()
  136. mqtt_client.on_message = on_message
  137. mqtt_client.connect(mqtt_c["host"], mqtt_c["port"])
  138. mqtt_client.subscribe(mqtt_c["topic"])
  139. mqtt_client.loop_start()
  140. last_predict, last_api_refresh = 0.0, 0.0
  141. beacons = []
  142. while True:
  143. now = time.time()
  144. rolling.prune()
  145. if now - last_api_refresh >= float(api_c.get("refresh_seconds", 30)):
  146. beacons = fetch_beacons()
  147. last_api_refresh = now
  148. if now - last_predict >= float(inf_c.get("refresh_seconds", 10.0)):
  149. rows, count_ok = [], 0
  150. for bm in beacons:
  151. bm_n = _norm_mac(bm)
  152. X, n_found = rolling.aggregate_features(bm_n, gateways_ordered, inf_c.get("aggregate", "median"), nan_fill)
  153. z, x, y = -1, -1.0, -1.0
  154. if n_found >= int(inf_c.get("min_non_nan", 1)):
  155. try:
  156. z, x, y = _predict_xyz(model_pkg, X)
  157. if z != -1: count_ok += 1
  158. except Exception as e:
  159. log(f"Infer error {bm_n}: {e}")
  160. rows.append(f"{bm_n};{int(z)};{int(round(x))};{int(round(y))}")
  161. try:
  162. out_p = Path(inf_c.get("output_csv", "/data/infer/infer.csv"))
  163. with open(str(out_p) + ".tmp", "w") as f:
  164. f.write("mac;z;x;y\n")
  165. for r in rows: f.write(r + "\n")
  166. os.replace(str(out_p) + ".tmp", out_p)
  167. log(f"CYCLE: {count_ok}/{len(rows)} localized")
  168. except Exception as e: log(f"File Error: {e}")
  169. last_predict = now
  170. time.sleep(0.5)