Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.
 
 
 
 

350 строки
13 KiB

  1. """train_collect.py
  2. Modalità COLLECT_TRAIN:
  3. - attende che tutti i gateway del feature-set (gateway.csv) siano online (traffic MQTT)
  4. - prende job CSV da jobs_dir/pending/*.csv
  5. - per ogni job: apre una finestra di raccolta di window_seconds, aggrega RSSI per GW
  6. - scrive sample CSV in samples_dir
  7. Bugfix fondamentale (per i tuoi NAN):
  8. - matching interno su MAC in formato **compact** (12 hex senza ':').
  9. """
  10. from __future__ import annotations
  11. import os
  12. import time
  13. import shutil
  14. import glob
  15. import math
  16. import datetime
  17. from dataclasses import dataclass
  18. from typing import Dict, List, Optional, Tuple
  19. import pandas as pd
  20. from .normalize import mac_to_compact, compact_to_colon
  21. from .mqtt_client import MqttSubscriber
  22. from .mqtt_parser import parse_publish_out
  23. from .fingerprint import FingerprintWindow
  24. from .logger_utils import log_msg as log
  25. def _ensure_dir(path: str) -> None:
  26. os.makedirs(path, exist_ok=True)
  27. def _read_delimited_csv(path: str, prefer_delim: str = ";") -> pd.DataFrame:
  28. for sep in [prefer_delim, ",", "\t"]:
  29. try:
  30. df = pd.read_csv(path, sep=sep, dtype=str, keep_default_na=False)
  31. if len(df.columns) >= 1:
  32. return df
  33. except Exception:
  34. continue
  35. return pd.read_csv(path, dtype=str, keep_default_na=False)
  36. def load_gateway_csv(path: str, delimiter: str = ";") -> Tuple[List[str], List[str]]:
  37. df = _read_delimited_csv(path, prefer_delim=delimiter)
  38. if df.empty:
  39. return [], []
  40. mac_col = None
  41. for c in df.columns:
  42. if c.strip().lower() == "mac":
  43. mac_col = c
  44. break
  45. if mac_col is None:
  46. mac_col = df.columns[0]
  47. headers: List[str] = []
  48. keys: List[str] = []
  49. seen = set()
  50. invalid = 0
  51. dup = 0
  52. for raw in df[mac_col].tolist():
  53. k = mac_to_compact(raw)
  54. if len(k) != 12:
  55. invalid += 1
  56. continue
  57. if k in seen:
  58. dup += 1
  59. continue
  60. seen.add(k)
  61. keys.append(k)
  62. headers.append(compact_to_colon(k))
  63. log(f"[gateway.csv] loaded gateways={len(keys)} invalid={invalid} duplicates={dup}")
  64. return headers, keys
@dataclass
class TrainTarget:
    """One beacon to fingerprint, as listed in a training job CSV."""

    mac: str  # beacon MAC in compact form (12 hex characters, no ':')
    x: float  # x coordinate taken from the job row (0.0 when the column is absent)
    y: float  # y coordinate taken from the job row (0.0 when the column is absent)
    z: float  # z coordinate taken from the job row (0.0 when the column is absent)
  71. def read_job_csv(job_path: str, delimiter: str = ";") -> List[TrainTarget]:
  72. df = _read_delimited_csv(job_path, prefer_delim=delimiter)
  73. if df.empty:
  74. return []
  75. cols = {c.strip().lower(): c for c in df.columns}
  76. def col(name: str) -> Optional[str]:
  77. return cols.get(name)
  78. mac_c = col("mac")
  79. x_c = col("x")
  80. y_c = col("y")
  81. z_c = col("z")
  82. if not mac_c:
  83. raise ValueError(f"Job CSV senza colonna 'mac': {job_path}")
  84. out: List[TrainTarget] = []
  85. for _, row in df.iterrows():
  86. m = mac_to_compact(row[mac_c])
  87. if len(m) != 12:
  88. continue
  89. x = float(row[x_c]) if x_c else 0.0
  90. y = float(row[y_c]) if y_c else 0.0
  91. z = float(row[z_c]) if z_c else 0.0
  92. out.append(TrainTarget(mac=m, x=x, y=y, z=z))
  93. return out
  94. def _pick_collect_cfg(settings: Dict) -> Dict:
  95. if "collect_train" in settings and isinstance(settings["collect_train"], dict):
  96. return settings["collect_train"]
  97. if "training" in settings and isinstance(settings["training"], dict):
  98. log("WARNING: config usa 'training:' (alias). Consiglio: rinomina in 'collect_train:'")
  99. return settings["training"]
  100. return {}
  101. def run_collect_train(settings: Dict) -> None:
  102. ct = _pick_collect_cfg(settings)
  103. paths = settings.get("paths", {}) or {}
  104. mqtt_cfg = settings.get("mqtt", {}) or {}
  105. dbg = settings.get("debug", {}) or {}
  106. jobs_dir = str(ct.get("jobs_dir", "/data/train/jobs"))
  107. samples_dir = str(ct.get("samples_dir", "/data/train/samples"))
  108. job_glob = str(ct.get("job_glob", "*.csv"))
  109. poll_seconds = float(ct.get("poll_seconds", ct.get("poll_pending_seconds", 2)))
  110. window_seconds = float(ct.get("window_seconds", 10))
  111. min_non_nan = int(ct.get("min_non_nan", 3))
  112. aggregate = str(ct.get("aggregate", "median")).lower()
  113. rssi_min = float(ct.get("rssi_min", -110))
  114. rssi_max = float(ct.get("rssi_max", -25))
  115. outlier_method = str(ct.get("outlier_method", "none")).lower()
  116. mad_z = float(ct.get("mad_z", 3.5))
  117. min_samples_per_gateway = int(ct.get("min_samples_per_gateway", 1))
  118. max_stddev = ct.get("max_stddev", None)
  119. max_stddev = float(max_stddev) if max_stddev is not None else None
  120. gateway_ready_max_age_seconds = float(ct.get("gateway_ready_max_age_seconds", 30))
  121. gw_ready_log_seconds = float(ct.get("gw_ready_log_seconds", 10))
  122. gw_ready_sleep_seconds = float(ct.get("gw_ready_sleep_seconds", 5))
  123. gw_ready_check_before_job = bool(ct.get("gw_ready_check_before_job", True))
  124. csv_delim = str(paths.get("csv_delimiter", ";"))
  125. gateway_csv = str(paths.get("gateways_csv", "/data/config/gateway.csv"))
  126. # Debug opzionale durante finestra
  127. log_progress = bool(dbg.get("collect_train_log_samples", False))
  128. log_first_seen = bool(dbg.get("collect_train_log_first_seen", False))
  129. log_every_s = float(dbg.get("collect_train_log_every_seconds", 15))
  130. pending_dir = os.path.join(jobs_dir, "pending")
  131. done_dir = os.path.join(jobs_dir, "done")
  132. error_dir = os.path.join(jobs_dir, "error")
  133. _ensure_dir(pending_dir)
  134. _ensure_dir(done_dir)
  135. _ensure_dir(error_dir)
  136. _ensure_dir(samples_dir)
  137. gateway_headers, gateway_keys = load_gateway_csv(gateway_csv, delimiter=csv_delim)
  138. if not gateway_keys:
  139. log("ERROR: Nessun gateway valido nel gateway.csv -> non posso partire.")
  140. return
  141. mqtt_host = str(mqtt_cfg.get("host", "mosquitto"))
  142. mqtt_port = int(mqtt_cfg.get("port", 1883))
  143. mqtt_topic = str(mqtt_cfg.get("topic", "publish_out/#"))
  144. mqtt_proto = str(mqtt_cfg.get("protocol", "mqttv311")).lower()
  145. client_id = str(mqtt_cfg.get("client_id", "ble-ai-localizer"))
  146. keepalive = int(mqtt_cfg.get("keepalive", 60))
  147. qos = int(mqtt_cfg.get("qos", 0))
  148. username = str(mqtt_cfg.get("username", ""))
  149. password = str(mqtt_cfg.get("password", ""))
  150. last_seen: Dict[str, float] = {}
  151. active_window: Optional[FingerprintWindow] = None
  152. active_logged_pairs: set = set()
  153. def on_mqtt_message(topic: str, payload: bytes) -> None:
  154. nonlocal active_window, active_logged_pairs
  155. events = parse_publish_out(topic, payload)
  156. now = time.time()
  157. for gw_key, b_key, rssi, _ts in events:
  158. if len(gw_key) == 12:
  159. last_seen[gw_key] = now
  160. if active_window is None:
  161. continue
  162. accepted = active_window.add(gw_key, b_key, rssi)
  163. if accepted and log_first_seen:
  164. pair = (b_key, gw_key)
  165. if pair not in active_logged_pairs:
  166. active_logged_pairs.add(pair)
  167. log(f"SEEN target beacon={b_key} gw={compact_to_colon(gw_key)} rssi={rssi:.1f}")
  168. sub = MqttSubscriber(
  169. host=mqtt_host,
  170. port=mqtt_port,
  171. topic=mqtt_topic,
  172. mqtt_proto=mqtt_proto,
  173. client_id=client_id,
  174. keepalive=keepalive,
  175. qos=qos,
  176. username=username if username else None,
  177. password=password if password else None,
  178. )
  179. import threading
  180. t = threading.Thread(target=sub.start_forever, args=(on_mqtt_message,), daemon=True)
  181. t.start()
  182. log("MQTT thread started (collect_train)")
  183. log(
  184. f"COLLECT_TRAIN config: gateway_csv={gateway_csv} gateways(feature-set)={len(gateway_keys)} "
  185. f"window_seconds={window_seconds:.1f} poll_seconds={poll_seconds:.1f} "
  186. f"jobs_dir={jobs_dir} pending_dir={pending_dir} done_dir={done_dir} error_dir={error_dir} "
  187. f"samples_dir={samples_dir} mqtt={mqtt_host}:{mqtt_port} topic={mqtt_topic}"
  188. )
  189. def gateways_online() -> Tuple[int, List[str]]:
  190. now = time.time()
  191. missing: List[str] = []
  192. for gk, hdr in zip(gateway_keys, gateway_headers):
  193. last = last_seen.get(gk)
  194. if last is None or (now - last) > gateway_ready_max_age_seconds:
  195. missing.append(hdr)
  196. return len(missing), missing
  197. def wait_for_gateways() -> None:
  198. last_log = 0.0
  199. while True:
  200. miss_n, missing = gateways_online()
  201. if miss_n == 0:
  202. log(f"GW READY: online={len(gateway_keys)}/{len(gateway_keys)} (max_age_s={gateway_ready_max_age_seconds:.1f})")
  203. return
  204. now = time.time()
  205. if now - last_log >= gw_ready_log_seconds:
  206. last_log = now
  207. log(
  208. f"WAIT gateways online ({miss_n} missing, seen={len(gateway_keys)-miss_n}/{len(gateway_keys)}): {missing} "
  209. f"(max_age_s={gateway_ready_max_age_seconds:.1f})"
  210. )
  211. time.sleep(gw_ready_sleep_seconds)
  212. while True:
  213. jobs = sorted(glob.glob(os.path.join(pending_dir, job_glob)))
  214. if not jobs:
  215. time.sleep(poll_seconds)
  216. continue
  217. for job_path in jobs:
  218. job_name = os.path.basename(job_path)
  219. try:
  220. if gw_ready_check_before_job:
  221. wait_for_gateways()
  222. targets = read_job_csv(job_path, delimiter=csv_delim)
  223. if not targets:
  224. raise RuntimeError("job CSV vuoto o senza MAC validi")
  225. beacon_keys = [t.mac for t in targets]
  226. log(f"TRAIN job START: {job_name} beacons={len(beacon_keys)}")
  227. active_logged_pairs = set()
  228. active_window = FingerprintWindow(
  229. beacon_keys=beacon_keys,
  230. gateway_headers=gateway_headers,
  231. gateway_keys=gateway_keys,
  232. rssi_min=rssi_min,
  233. rssi_max=rssi_max,
  234. outlier_method=outlier_method,
  235. mad_z=mad_z,
  236. min_samples_per_gateway=min_samples_per_gateway,
  237. max_stddev=max_stddev,
  238. )
  239. t0 = time.time()
  240. next_log = t0 + log_every_s
  241. while True:
  242. elapsed = time.time() - t0
  243. if elapsed >= window_seconds:
  244. break
  245. if log_progress and time.time() >= next_log:
  246. next_log = time.time() + log_every_s
  247. parts = []
  248. for b in beacon_keys:
  249. tops = active_window.top_gateways(b, aggregate=aggregate, top_n=3)
  250. if not tops:
  251. parts.append(f"{b}:0gw")
  252. else:
  253. top_s = ",".join([f"{hdr}({n})" for n, hdr, _agg in tops])
  254. parts.append(f"{b}:{top_s}")
  255. log(f"WINDOW progress {elapsed:.0f}/{window_seconds:.0f}s -> " + " | ".join(parts))
  256. time.sleep(0.25)
  257. rows: List[Dict[str, object]] = []
  258. for tt in targets:
  259. feats = active_window.features_for(tt.mac, aggregate=aggregate)
  260. non_nan = sum(0 if (isinstance(v, float) and math.isnan(v)) else 1 for v in feats.values())
  261. if non_nan < min_non_nan:
  262. log(f"WARNING: beacon {tt.mac} low features non_nan={non_nan}")
  263. tops = active_window.top_gateways(tt.mac, aggregate=aggregate, top_n=5)
  264. if tops:
  265. top_s = ", ".join([f"{hdr} n={n} agg={agg:.1f}" for n, hdr, agg in tops])
  266. log(f"SUMMARY beacon {tt.mac}: {top_s}")
  267. else:
  268. log(f"SUMMARY beacon {tt.mac}: no samples captured")
  269. row: Dict[str, object] = {"mac": tt.mac, "x": float(tt.x), "y": float(tt.y), "z": float(tt.z)}
  270. row.update(feats)
  271. rows.append(row)
  272. out_df = pd.DataFrame(rows)
  273. cols = ["mac", "x", "y", "z"] + gateway_headers
  274. out_df = out_df.reindex(columns=cols)
  275. epoch = int(time.time())
  276. out_name = f"{os.path.splitext(job_name)[0]}__{epoch}.csv"
  277. out_path = os.path.join(samples_dir, out_name)
  278. out_df.to_csv(out_path, sep=csv_delim, index=False, float_format="%.1f", na_rep="nan")
  279. log(f"TRAIN job DONE: wrote {out_path} rows={len(out_df)}")
  280. shutil.move(job_path, os.path.join(done_dir, job_name))
  281. except Exception as e:
  282. log(f"ERROR processing job {job_name}: {e}")
  283. try:
  284. shutil.move(job_path, os.path.join(error_dir, job_name))
  285. except Exception:
  286. pass
  287. finally:
  288. active_window = None