You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

818 regels
27 KiB

  1. from .logger_utils import setup_global_logging, log_msg as log
  2. import csv
  3. import io
  4. import json
  5. import os
  6. import ssl
  7. import time
  8. import traceback
  9. from dataclasses import dataclass
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional, Tuple
  12. import hashlib
  13. import re
  14. import math
  15. import statistics
  16. import pandas as pd
  17. import numpy as np
  18. import requests
  19. import joblib
  20. import paho.mqtt.client as mqtt
  21. # Import locali corretti
  22. from .settings import load_settings
  23. def build_info() -> str:
  24. return "infer-debug-v19-fixed"
def main() -> None:
    """Entry point: load settings, configure logging, log startup context.

    NOTE(review): a second ``main`` is defined later in this module and
    shadows this one at import time.
    """
    # 1. Load configuration from disk.
    settings = load_settings()
    # 2. Set up logging and silencers immediately (BEFORE anything else).
    setup_global_logging(settings)
    # 3. From here on, logging is synchronized and clean.
    cfg_file = settings.get("_config_file", "/config/config.yaml")
    keys = [k for k in settings.keys() if not str(k).startswith("_")]
    log(f"Settings loaded from {cfg_file}. Keys: {keys}")
    log(f"BUILD: {build_info()}")
  35. def mac_plain(s: str) -> str:
  36. """Normalizza MAC a 12 hex uppercase senza separatori."""
  37. return re.sub(r"[^0-9A-Fa-f]", "", (s or "")).upper()
  38. def mac_colon(s: str) -> str:
  39. """MAC in formato AA:BB:CC:DD:EE:FF."""
  40. p = mac_plain(s)
  41. if len(p) != 12:
  42. return p
  43. return ":".join(p[i:i+2] for i in range(0, 12, 2))
  44. def fmt_rssi(v, decimals: int) -> str:
  45. """Formatta RSSI come stringa, evitando '-82.0' quando decimals=0."""
  46. if v is None:
  47. return "nan"
  48. try:
  49. fv = float(v)
  50. except Exception:
  51. return "nan"
  52. if math.isnan(fv):
  53. return "nan"
  54. if decimals <= 0:
  55. return str(int(round(fv)))
  56. return f"{round(fv, decimals):.{decimals}f}"
  57. # -----------------------------
  58. # Build info (printed at startup for traceability)
  59. BUILD_ID = "ble-ai-localizer main.py 2026-01-30 build-floatagg-v1"
  60. def build_info() -> str:
  61. """Return a short build identifier for logs (no external deps, no git required)."""
  62. try:
  63. p = Path(__file__)
  64. data = p.read_bytes()
  65. sha = hashlib.sha256(data).hexdigest()[:12]
  66. size = p.stat().st_size
  67. return f"{BUILD_ID} sha256={sha} size={size}"
  68. except Exception:
  69. return f"{BUILD_ID} sha256=? size=?"
  70. # Settings
  71. # -----------------------------
  72. def load_settings() -> Dict[str, Any]:
  73. cfg = os.environ.get("CONFIG", "/config/config.yaml")
  74. import yaml
  75. with open(cfg, "r", encoding="utf-8") as f:
  76. data = yaml.safe_load(f) or {}
  77. data["_config_file"] = cfg
  78. # Normalize config sections: prefer collect_train
  79. if "collect_train" not in data and "training" in data:
  80. log("WARNING: config usa 'training:' (alias). Consiglio: rinomina in 'collect_train:'")
  81. data["collect_train"] = data.get("training", {}) or {}
  82. return data
  83. # -----------------------------
  84. # MAC helpers
  85. # -----------------------------
  86. def norm_mac(mac: str) -> str:
  87. """Return MAC as AA:BB:CC:DD:EE:FF (upper), ignoring separators."""
  88. m = (mac or "").strip().replace("-", "").replace(":", "").replace(".", "")
  89. m = m.upper()
  90. if len(m) != 12:
  91. return mac.strip().upper()
  92. return ":".join(m[i:i+2] for i in range(0, 12, 2))
  93. # -----------------------------
  94. # CSV write helpers
  95. # -----------------------------
  96. def safe_write_csv(
  97. path: Path,
  98. header: List[str],
  99. rows: List[Dict[str, Any]],
  100. delimiter: str = ";",
  101. rssi_decimals: int = 0,
  102. ):
  103. """Scrive CSV in modo atomico e formattazione 'umana'.
  104. - numeri interi: senza decimali (es. -82 invece di -82.0)
  105. - RSSI: arrotondamento controllato da rssi_decimals (0 -> intero, >0 -> N cifre decimali)
  106. *si applica solo alle colonne RSSI (dopo mac/x/y/z)*
  107. - NaN: 'nan'
  108. - colonna 'mac': normalizzata in formato con ':' (es. C3:00:00:57:B9:E7) se passa un MAC valido
  109. """
  110. tmp = path.with_suffix(path.suffix + ".tmp")
  111. # csv.writer richiede un singolo carattere come delimiter
  112. if not isinstance(delimiter, str) or len(delimiter) != 1:
  113. delimiter = ";"
  114. try:
  115. rssi_decimals = int(rssi_decimals)
  116. except Exception:
  117. rssi_decimals = 0
  118. if rssi_decimals < 0:
  119. rssi_decimals = 0
  120. def fmt_cell(v: Any, col: str, idx: int) -> str:
  121. if v is None:
  122. return "nan"
  123. # MAC normalizzato con ':'
  124. if col.lower() == "mac" and isinstance(v, str):
  125. v2 = mac_colon(v)
  126. return v2
  127. # NaN float
  128. if isinstance(v, float):
  129. if math.isnan(v):
  130. return "nan"
  131. # colonne RSSI (dopo mac/x/y/z)
  132. if idx >= 4:
  133. if rssi_decimals == 0:
  134. return str(int(round(v)))
  135. return f"{v:.{rssi_decimals}f}"
  136. # altre colonne: compatta i (quasi) interi
  137. if abs(v - round(v)) < 1e-9:
  138. return str(int(round(v)))
  139. return str(v)
  140. # int / numpy int
  141. if isinstance(v, (int, np.integer)):
  142. # RSSI columns (after mac/x/y/z): respect rssi_decimals even for integer values
  143. if idx >= 4:
  144. if rssi_decimals == 0:
  145. return str(int(v))
  146. return f"{float(v):.{rssi_decimals}f}"
  147. return str(int(v))
  148. # numpy float
  149. if isinstance(v, np.floating):
  150. fv = float(v)
  151. if math.isnan(fv):
  152. return "nan"
  153. if idx >= 4:
  154. if rssi_decimals == 0:
  155. return str(int(round(fv)))
  156. return f"{fv:.{rssi_decimals}f}"
  157. if abs(fv - round(fv)) < 1e-9:
  158. return str(int(round(fv)))
  159. return str(fv)
  160. return str(v)
  161. with tmp.open("w", newline="") as f:
  162. w = csv.writer(f, delimiter=delimiter)
  163. w.writerow(header)
  164. for row in rows:
  165. w.writerow([fmt_cell(row.get(col), col, idx) for idx, col in enumerate(header)])
  166. tmp.replace(path)
  167. def _coord_token(v: float) -> str:
  168. # Stable token for filenames from coordinates.
  169. # - if integer-ish -> '123'
  170. # - else keep up to 3 decimals, strip trailing zeros, replace '.' with '_'
  171. try:
  172. fv=float(v)
  173. except Exception:
  174. return str(v)
  175. if abs(fv - round(fv)) < 1e-9:
  176. return str(int(round(fv)))
  177. s=f"{fv:.3f}".rstrip('0').rstrip('.')
  178. return s.replace('.', '_')
  179. def read_job_csv(job_path: Path, delimiter: str) -> List[Dict[str, Any]]:
  180. """Legge job CSV supportando due formati:
  181. 1) Legacy:
  182. mac;x;y;z
  183. C3000057B9F4;1200;450;0
  184. 2) Esteso (storico):
  185. Position;Floor;RoomName;X;Y;Z;BeaconName;MAC
  186. A21;1;P1-NETW;800;1050;1;BC-21;C3:00:00:57:B9:E6
  187. Estrae solo X,Y,Z,MAC e normalizza MAC in formato compatto (senza ':', uppercase).
  188. """
  189. text = job_path.read_text(encoding="utf-8", errors="replace")
  190. if not text.strip():
  191. return []
  192. first_line = next((ln for ln in text.splitlines() if ln.strip()), "")
  193. use_delim = delimiter
  194. if use_delim not in first_line:
  195. if ";" in first_line and "," not in first_line:
  196. use_delim = ";"
  197. elif "," in first_line and ";" not in first_line:
  198. use_delim = ","
  199. def hnorm(h: str) -> str:
  200. h = (h or "").strip().lower()
  201. h = re_sub_non_alnum(h)
  202. return h
  203. f = io.StringIO(text)
  204. r = csv.reader(f, delimiter=use_delim)
  205. header = next(r, None)
  206. if not header:
  207. return []
  208. header_norm = [hnorm(h) for h in header]
  209. idx = {name: i for i, name in enumerate(header_norm) if name}
  210. def find_idx(names: List[str]) -> Optional[int]:
  211. for n in names:
  212. if n in idx:
  213. return idx[n]
  214. return None
  215. mac_i = find_idx(["mac", "beaconmac", "beacon_mac", "trackermac", "tracker_mac", "device", "devicemac"])
  216. x_i = find_idx(["x"])
  217. y_i = find_idx(["y"])
  218. z_i = find_idx(["z"])
  219. if mac_i is None or x_i is None or y_i is None or z_i is None:
  220. raise ValueError(
  221. f"Job CSV header non riconosciuto: {header}. "
  222. f"Attesi campi MAC/X/Y/Z (case-insensitive)."
  223. )
  224. rows: List[Dict[str, Any]] = []
  225. for cols in r:
  226. if not cols:
  227. continue
  228. if len(cols) <= max(mac_i, x_i, y_i, z_i):
  229. continue
  230. mac_raw = (cols[mac_i] or "").strip()
  231. if not mac_raw:
  232. continue
  233. mac_compact = norm_mac(mac_raw).replace(":", "")
  234. try:
  235. x = float((cols[x_i] or "").strip())
  236. y = float((cols[y_i] or "").strip())
  237. z = float((cols[z_i] or "").strip())
  238. except Exception:
  239. continue
  240. rows.append({"mac": mac_compact, "x": x, "y": y, "z": z})
  241. return rows
  242. def re_sub_non_alnum(s: str) -> str:
  243. out = []
  244. for ch in s:
  245. if ("a" <= ch <= "z") or ("0" <= ch <= "9"):
  246. out.append(ch)
  247. return "".join(out)
  248. def write_samples_csv(
  249. out_path: Path,
  250. sample_rows: List[Dict[str, Any]],
  251. gateway_macs: List[str],
  252. *,
  253. delimiter: str = ";",
  254. rssi_decimals: int = 0,
  255. ) -> None:
  256. header = ["mac", "x", "y", "z"] + gateway_macs
  257. safe_write_csv(out_path, header, sample_rows, delimiter=delimiter, rssi_decimals=rssi_decimals)
  258. def load_gateway_csv(path: Path, delimiter: str = ";") -> Tuple[List[str], int, int]:
  259. df = pd.read_csv(path, delimiter=delimiter)
  260. cols = [c.strip().lower() for c in df.columns]
  261. df.columns = cols
  262. invalid = 0
  263. macs: List[str] = []
  264. seen = set()
  265. if "mac" not in df.columns:
  266. raise ValueError(f"gateway.csv must have a 'mac' column, got columns={list(df.columns)}")
  267. for v in df["mac"].astype(str).tolist():
  268. nm = norm_mac(v)
  269. if len(nm.replace(":", "")) != 12:
  270. invalid += 1
  271. continue
  272. if nm in seen:
  273. continue
  274. seen.add(nm)
  275. macs.append(nm)
  276. duplicates = max(0, len(df) - invalid - len(macs))
  277. return macs, invalid, duplicates
# -----------------------------
# Fingerprint collector
# -----------------------------
@dataclass
class FingerprintStats:
    """Snapshot of collected RSSI statistics.

    counts: beacon MAC -> gateway MAC -> number of RSSI samples seen.
    last:   beacon MAC -> gateway MAC -> most recent RSSI (NaN if none).
    """
    counts: Dict[str, Dict[str, int]]
    last: Dict[str, Dict[str, float]]
  285. class FingerprintCollector:
  286. def __init__(self) -> None:
  287. self._lock = None
  288. try:
  289. import threading
  290. self._lock = threading.Lock()
  291. except Exception:
  292. self._lock = None
  293. # beacon_norm -> gw_norm -> list of rssi
  294. self.rssi: Dict[str, Dict[str, List[float]]] = {}
  295. self.last_seen_gw: Dict[str, float] = {}
  296. self.last_seen_beacon: Dict[str, float] = {}
  297. def _with_lock(self):
  298. if self._lock is None:
  299. class Dummy:
  300. def __enter__(self): return None
  301. def __exit__(self, *a): return False
  302. return Dummy()
  303. return self._lock
  304. def update(self, gw_mac: str, beacon_mac: str, rssi: float) -> None:
  305. gw = norm_mac(gw_mac)
  306. b = norm_mac(beacon_mac)
  307. now = time.time()
  308. with self._with_lock():
  309. self.last_seen_gw[gw] = now
  310. self.last_seen_beacon[b] = now
  311. self.rssi.setdefault(b, {}).setdefault(gw, []).append(float(rssi))
  312. def stats(self, beacons: List[str], gateways: List[str]) -> FingerprintStats:
  313. with self._with_lock():
  314. counts: Dict[str, Dict[str, int]] = {b: {g: 0 for g in gateways} for b in beacons}
  315. last: Dict[str, Dict[str, float]] = {b: {g: float("nan") for g in gateways} for b in beacons}
  316. for b in beacons:
  317. bm = norm_mac(b)
  318. for g in gateways:
  319. gm = norm_mac(g)
  320. vals = self.rssi.get(bm, {}).get(gm, [])
  321. counts[bm][gm] = len(vals)
  322. if vals:
  323. last[bm][gm] = vals[-1]
  324. return FingerprintStats(counts=counts, last=last)
  325. def feature_row(
  326. self,
  327. beacon_mac: str,
  328. gateways: List[str],
  329. aggregate: str,
  330. rssi_min: float,
  331. rssi_max: float,
  332. min_samples_per_gateway: int,
  333. outlier_method: str,
  334. mad_z: float,
  335. iqr_k: float,
  336. max_stddev: Optional[float],
  337. ) -> Dict[str, float]:
  338. b = norm_mac(beacon_mac)
  339. out: Dict[str, float] = {}
  340. with self._with_lock():
  341. for g in gateways:
  342. gm = norm_mac(g)
  343. vals = list(self.rssi.get(b, {}).get(gm, []))
  344. # hard clamp
  345. vals = [v for v in vals if (rssi_min <= v <= rssi_max)]
  346. if len(vals) < min_samples_per_gateway:
  347. out[gm] = float("nan")
  348. continue
  349. # outlier removal
  350. vals2 = vals
  351. if outlier_method == "mad":
  352. vals2 = mad_filter(vals2, z=mad_z)
  353. elif outlier_method == "iqr":
  354. vals2 = iqr_filter(vals2, k=iqr_k)
  355. if len(vals2) < min_samples_per_gateway:
  356. out[gm] = float("nan")
  357. continue
  358. if max_stddev is not None:
  359. import statistics
  360. try:
  361. sd = statistics.pstdev(vals2)
  362. if sd > max_stddev:
  363. out[gm] = float("nan")
  364. continue
  365. except Exception:
  366. pass
  367. # Aggregate: mantieni float (niente cast a int) per poter usare rssi_decimals.
  368. if aggregate == "median":
  369. out[gm] = float(statistics.median(vals2))
  370. elif aggregate == "median_low":
  371. out[gm] = float(statistics.median_low(sorted(vals2)))
  372. elif aggregate == "median_high":
  373. out[gm] = float(statistics.median_high(sorted(vals2)))
  374. elif aggregate == "mean":
  375. out[gm] = float(statistics.fmean(vals2))
  376. else:
  377. out[gm] = float(statistics.median(vals2))
  378. return out
  379. def mad_filter(vals: List[float], z: float = 3.5) -> List[float]:
  380. if not vals:
  381. return vals
  382. s = pd.Series(vals)
  383. med = s.median()
  384. mad = (s - med).abs().median()
  385. if mad == 0:
  386. return vals
  387. mz = 0.6745 * (s - med).abs() / mad
  388. return [float(v) for v, keep in zip(vals, (mz <= z).tolist()) if keep]
  389. def iqr_filter(vals: List[float], k: float = 1.5) -> List[float]:
  390. if not vals:
  391. return vals
  392. s = pd.Series(vals)
  393. q1 = s.quantile(0.25)
  394. q3 = s.quantile(0.75)
  395. iqr = q3 - q1
  396. if iqr == 0:
  397. return vals
  398. lo = q1 - k * iqr
  399. hi = q3 + k * iqr
  400. return [float(v) for v in vals if lo <= v <= hi]
  401. # -----------------------------
  402. # MQTT parsing
  403. # -----------------------------
  404. def parse_topic_gateway(topic: str) -> Optional[str]:
  405. # expected: publish_out/<gwmac>
  406. parts = (topic or "").split("/")
  407. if len(parts) < 2:
  408. return None
  409. return parts[-1]
  410. def parse_payload_list(payload: bytes) -> Optional[List[Dict[str, Any]]]:
  411. try:
  412. obj = json.loads(payload.decode("utf-8", errors="replace"))
  413. if isinstance(obj, list):
  414. return obj
  415. return None
  416. except Exception:
  417. return None
  418. def is_gateway_announce(item: Dict[str, Any]) -> bool:
  419. return str(item.get("type", "")).strip().lower() == "gateway" and "mac" in item
# -----------------------------
# Collect train
# -----------------------------
def run_collect_train(settings: Dict[str, Any]) -> None:
    """Collect fingerprint training samples driven by job CSV files.

    Subscribes to MQTT, waits until every configured gateway is online,
    then repeatedly: picks the oldest job CSV from the pending directory,
    collects RSSI for a time window, aggregates one feature row per beacon,
    writes one sample CSV per row, and moves the job to done/ (or error/).
    This function never returns (infinite job loop).
    """
    cfg = settings.get("collect_train", {}) or {}
    paths = settings.get("paths", {}) or {}
    mqtt_cfg = settings.get("mqtt", {}) or {}
    debug = settings.get("debug", {}) or {}
    # Collection-window and loop timing parameters.
    window_seconds = float(cfg.get("window_seconds", 180))
    poll_seconds = float(cfg.get("poll_seconds", 2))
    min_non_nan = int(cfg.get("min_non_nan", 3))
    min_samples_per_gateway = int(cfg.get("min_samples_per_gateway", 5))
    aggregate = str(cfg.get("aggregate", "median"))
    # Number of decimal places for RSSI values in sample files (0 = integer).
    try:
        rssi_decimals = int(cfg.get("rssi_decimals", 0))
    except Exception:
        rssi_decimals = 0
    if rssi_decimals < 0:
        rssi_decimals = 0
    # RSSI clamping and outlier-filtering parameters.
    rssi_min = float(cfg.get("rssi_min", -110))
    rssi_max = float(cfg.get("rssi_max", -25))
    outlier_method = str(cfg.get("outlier_method", "mad"))
    mad_z = float(cfg.get("mad_z", 3.5))
    iqr_k = float(cfg.get("iqr_k", 1.5))
    max_stddev = cfg.get("max_stddev", None)
    max_stddev = float(max_stddev) if max_stddev is not None else None
    # Filesystem layout: gateway list, job queue directories, sample output.
    gateway_csv = Path(paths.get("gateways_csv", "/data/config/gateway.csv"))
    csv_delimiter = str(paths.get("csv_delimiter", ";"))
    jobs_dir = Path(cfg.get("jobs_dir", "/data/train/jobs"))
    pending_dir = jobs_dir / "pending"
    done_dir = jobs_dir / "done"
    error_dir = jobs_dir / "error"
    samples_dir = Path(cfg.get("samples_dir", "/data/train/samples"))
    pending_dir.mkdir(parents=True, exist_ok=True)
    done_dir.mkdir(parents=True, exist_ok=True)
    error_dir.mkdir(parents=True, exist_ok=True)
    samples_dir.mkdir(parents=True, exist_ok=True)
    # Gateway-readiness logging/poll intervals.
    gw_ready_log_seconds = float(cfg.get("gw_ready_log_seconds", 10))
    gw_ready_sleep_seconds = float(cfg.get("gw_ready_sleep_seconds", 5))
    gw_ready_check_before_job = bool(cfg.get("gw_ready_check_before_job", True))
    online_max_age_s = float(debug.get("online_check_seconds", 30))
    progress_log_seconds = float(cfg.get("wait_all_gateways_log_seconds", 30))
    gateway_macs, invalid, duplicates = load_gateway_csv(gateway_csv, delimiter=csv_delimiter)
    log(f"[gateway.csv] loaded gateways={len(gateway_macs)} invalid={invalid} duplicates={duplicates}")
    log(
        "COLLECT_TRAIN config: gateway_csv=%s gateways(feature-set)=%d window_seconds=%.1f poll_seconds=%.1f rssi_decimals=%d jobs_dir=%s "
        "pending_dir=%s done_dir=%s error_dir=%s samples_dir=%s mqtt=%s:%s topic=%s"
        % (
            gateway_csv,
            len(gateway_macs),
            window_seconds,
            poll_seconds,
            rssi_decimals,
            jobs_dir,
            pending_dir,
            done_dir,
            error_dir,
            samples_dir,
            mqtt_cfg.get("host", ""),
            mqtt_cfg.get("port", ""),
            mqtt_cfg.get("topic", "publish_out/#"),
        )
    )
    fp = FingerprintCollector()
    # MQTT setup
    host = mqtt_cfg.get("host", "127.0.0.1")
    port = int(mqtt_cfg.get("port", 1883))
    topic = mqtt_cfg.get("topic", "publish_out/#")
    client_id = mqtt_cfg.get("client_id", "ble-ai-localizer")
    keepalive = int(mqtt_cfg.get("keepalive", 60))
    proto = mqtt.MQTTv311

    def on_connect(client, userdata, flags, rc):
        # (Re)subscribe on every connect so reconnects keep working.
        log(f"MQTT connected rc={rc}, subscribed to {topic}")
        client.subscribe(topic)

    def on_message(client, userdata, msg):
        # Runs on the MQTT network thread; feeds the shared collector.
        gw_from_topic = parse_topic_gateway(msg.topic)
        if not gw_from_topic:
            return
        payload_list = parse_payload_list(msg.payload)
        if not payload_list:
            return
        for it in payload_list:
            if not isinstance(it, dict):
                continue
            if is_gateway_announce(it):
                # Gateway heartbeat: only refresh its last-seen timestamp.
                gwm = it.get("mac", gw_from_topic)
                fp.last_seen_gw[norm_mac(gwm)] = time.time()
                continue
            bmac = it.get("mac")
            rssi = it.get("rssi")
            if not bmac or rssi is None:
                continue
            try:
                fp.update(gw_from_topic, bmac, float(rssi))
            except Exception:
                # Malformed rssi value: skip the item, keep the stream alive.
                continue

    client = mqtt.Client(client_id=client_id, protocol=proto)
    client.on_connect = on_connect
    client.on_message = on_message
    username = str(mqtt_cfg.get("username", "") or "")
    password = str(mqtt_cfg.get("password", "") or "")
    if username:
        client.username_pw_set(username, password)
    tls = bool(mqtt_cfg.get("tls", False))
    if tls:
        # NOTE(review): certificate verification is disabled here —
        # acceptable on a trusted LAN only; confirm this is intended.
        client.tls_set(cert_reqs=ssl.CERT_NONE)
        client.tls_insecure_set(True)
    log("MQTT thread started (collect_train)")
    client.connect(host, port, keepalive=keepalive)
    client.loop_start()
    # Wait until every configured gateway has been seen recently.
    last_ready_log = 0.0
    while True:
        now = time.time()
        online = 0
        missing = []
        for g in gateway_macs:
            seen = fp.last_seen_gw.get(norm_mac(g))
            if seen is not None and (now - seen) <= online_max_age_s:
                online += 1
            else:
                missing.append(norm_mac(g))
        if online == len(gateway_macs):
            log(f"GW READY: online={online}/{len(gateway_macs)} (max_age_s={online_max_age_s:.1f})")
            break
        if now - last_ready_log >= gw_ready_log_seconds:
            log(f"WAIT gateways online ({len(missing)} missing, seen={online}/{len(gateway_macs)}): {missing} (max_age_s={online_max_age_s:.1f})")
            last_ready_log = now
        time.sleep(gw_ready_sleep_seconds)
    # Job loop: process pending job CSVs forever.
    while True:
        try:
            # Periodic gateway-readiness log.
            now = time.time()
            if now - last_ready_log >= gw_ready_log_seconds:
                online = 0
                for g in gateway_macs:
                    seen = fp.last_seen_gw.get(norm_mac(g))
                    if seen is not None and (now - seen) <= online_max_age_s:
                        online += 1
                log(f"GW READY: online={online}/{len(gateway_macs)} (max_age_s={online_max_age_s:.1f})")
                last_ready_log = now
            # Pick the oldest pending job (sorted by filename).
            job_files = sorted(pending_dir.glob("*.csv"))
            if not job_files:
                time.sleep(poll_seconds)
                continue
            # NOTE(review): if an exception fires before job_path/job_name are
            # assigned, the except handler below references them unbound.
            job_path = job_files[0]
            job_name = job_path.name
            rows = read_job_csv(job_path, delimiter=csv_delimiter)
            if not rows:
                # Move empty/bad jobs to the error directory.
                log(f"TRAIN job ERROR: {job_name} err=EmptyJob: no valid rows")
                job_path.rename(error_dir / job_path.name)
                continue
            # Normalize beacon MACs for use as stats keys.
            job_beacons_norm = [norm_mac(r["mac"]) for r in rows]
            # Optionally re-check that all gateways are online before the window.
            if gw_ready_check_before_job:
                while True:
                    now = time.time()
                    online = 0
                    missing = []
                    for g in gateway_macs:
                        seen = fp.last_seen_gw.get(norm_mac(g))
                        if seen is not None and (now - seen) <= online_max_age_s:
                            online += 1
                        else:
                            missing.append(norm_mac(g))
                    if online == len(gateway_macs):
                        break
                    log(f"WAIT gateways online before job ({len(missing)} missing, seen={online}/{len(gateway_macs)}): {missing}")
                    time.sleep(1.0)
            log(f"TRAIN job START: {job_name} beacons={len(rows)}")
            start = time.time()
            deadline = start + window_seconds
            next_progress = start + progress_log_seconds
            # Collection window: MQTT callbacks accumulate samples meanwhile.
            while time.time() < deadline:
                time.sleep(0.5)
                if progress_log_seconds > 0 and time.time() >= next_progress:
                    st = fp.stats(job_beacons_norm, gateway_macs)
                    parts = []
                    for b in job_beacons_norm:
                        total = sum(st.counts[b].values())
                        gw_seen = sum(1 for g in gateway_macs if st.counts[b][g] > 0)
                        parts.append(f"{b.replace(':','')}: total={total} gw={gw_seen}/{len(gateway_macs)}")
                    elapsed = int(time.time() - start)
                    log(f"COLLECT progress: {elapsed}s/{int(window_seconds)}s " + " | ".join(parts))
                    next_progress = time.time() + progress_log_seconds
            # Aggregate one feature row per beacon.
            out_rows: List[Dict[str, Any]] = []
            st = fp.stats(job_beacons_norm, gateway_macs)
            for r, b_norm in zip(rows, job_beacons_norm):
                feats = fp.feature_row(
                    beacon_mac=b_norm,
                    gateways=gateway_macs,
                    aggregate=aggregate,
                    rssi_min=rssi_min,
                    rssi_max=rssi_max,
                    min_samples_per_gateway=min_samples_per_gateway,
                    outlier_method=outlier_method,
                    mad_z=mad_z,
                    iqr_k=iqr_k,
                    max_stddev=max_stddev,
                )
                # NaN != NaN, so `x == x` counts the non-NaN features.
                non_nan = sum(1 for g in gateway_macs if feats.get(g) == feats.get(g))
                if non_nan < min_non_nan:
                    sample_info = []
                    for g in gateway_macs:
                        c = st.counts[b_norm][g]
                        if c > 0:
                            sample_info.append(f"{g} n={c} last={st.last[b_norm][g]}")
                    preview = ", ".join(sample_info[:8]) + (" ..." if len(sample_info) > 8 else "")
                    log(
                        f"WARNING: beacon {b_norm.replace(':','')} low features non_nan={non_nan} "
                        f"(seen_gw={sum(1 for g in gateway_macs if st.counts[b_norm][g]>0)}) [{preview}]"
                    )
                out_row: Dict[str, Any] = {
                    "mac": r["mac"],  # MAC always compact, without ':'
                    "x": float(r["x"]),
                    "y": float(r["y"]),
                    "z": float(r["z"]),
                }
                out_row.update(feats)
                out_rows.append(out_row)
            # Write one sample file per row, then archive the job.
            written = []
            for out_row in out_rows:
                # File name: Z_X_Y.csv (Z, X, Y taken from the job).
                zt = _coord_token(out_row.get("z"))
                xt = _coord_token(out_row.get("x"))
                yt = _coord_token(out_row.get("y"))
                base_name = f"{zt}_{xt}_{yt}.csv"
                out_path = samples_dir / base_name
                write_samples_csv(out_path, [out_row], gateway_macs, delimiter=csv_delimiter, rssi_decimals=rssi_decimals)
                written.append(out_path.name)
            job_path.rename(done_dir / job_path.name)
            if written:
                shown = ", ".join(written[:10])
                more = "" if len(written) <= 10 else f" (+{len(written)-10} altri)"
                log(f"TRAIN job DONE: wrote {len(written)} sample files to {samples_dir}: {shown}{more}")
            else:
                log(f"TRAIN job DONE: no output rows (empty job?)")
        except Exception as e:
            # Any failure moves the job to error/ and keeps the loop alive.
            log(f"TRAIN job ERROR: {job_name} err={type(e).__name__}: {e}")
            try:
                job_path.rename(error_dir / job_path.name)
            except Exception:
                pass
            time.sleep(0.5)
  669. def main() -> None:
  670. settings = load_settings()
  671. cfg_file = settings.get("_config_file", "")
  672. keys = [k for k in settings.keys() if not str(k).startswith("_")]
  673. log(f"Settings loaded from {cfg_file}. Keys: {keys}")
  674. log(f"BUILD: {build_info()}")
  675. mode = str(settings.get("mode", "collect_train")).strip().lower()
  676. if mode == "collect_train":
  677. run_collect_train(settings)
  678. return
  679. if mode == "train":
  680. from .train_mode import run_train
  681. run_train(settings)
  682. return
  683. if mode == "infer":
  684. from .infer_mode import run_infer
  685. run_infer(settings)
  686. return
  687. raise ValueError(f"unknown mode: {mode}")
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()