"""train_collect.py Modalità COLLECT_TRAIN: - attende che tutti i gateway del feature-set (gateway.csv) siano online (traffic MQTT) - prende job CSV da jobs_dir/pending/*.csv - per ogni job: apre una finestra di raccolta di window_seconds, aggrega RSSI per GW - scrive sample CSV in samples_dir Bugfix fondamentale (per i tuoi NAN): - matching interno su MAC in formato **compact** (12 hex senza ':'). """ from __future__ import annotations import os import time import shutil import glob import math import datetime from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import pandas as pd from .normalize import mac_to_compact, compact_to_colon from .mqtt_client import MqttSubscriber from .mqtt_parser import parse_publish_out from .fingerprint import FingerprintWindow from .logger_utils import log_msg as log def _ensure_dir(path: str) -> None: os.makedirs(path, exist_ok=True) def _read_delimited_csv(path: str, prefer_delim: str = ";") -> pd.DataFrame: for sep in [prefer_delim, ",", "\t"]: try: df = pd.read_csv(path, sep=sep, dtype=str, keep_default_na=False) if len(df.columns) >= 1: return df except Exception: continue return pd.read_csv(path, dtype=str, keep_default_na=False) def load_gateway_csv(path: str, delimiter: str = ";") -> Tuple[List[str], List[str]]: df = _read_delimited_csv(path, prefer_delim=delimiter) if df.empty: return [], [] mac_col = None for c in df.columns: if c.strip().lower() == "mac": mac_col = c break if mac_col is None: mac_col = df.columns[0] headers: List[str] = [] keys: List[str] = [] seen = set() invalid = 0 dup = 0 for raw in df[mac_col].tolist(): k = mac_to_compact(raw) if len(k) != 12: invalid += 1 continue if k in seen: dup += 1 continue seen.add(k) keys.append(k) headers.append(compact_to_colon(k)) log(f"[gateway.csv] loaded gateways={len(keys)} invalid={invalid} duplicates={dup}") return headers, keys @dataclass class TrainTarget: mac: str # compact x: float y: float z: float def read_job_csv(job_path: str, delimiter: str = ";") -> List[TrainTarget]: df = _read_delimited_csv(job_path, prefer_delim=delimiter) if df.empty: return [] cols = {c.strip().lower(): c for c in df.columns} def col(name: str) -> Optional[str]: return cols.get(name) mac_c = col("mac") x_c = col("x") y_c = col("y") z_c = col("z") if not mac_c: raise ValueError(f"Job CSV senza colonna 'mac': {job_path}") out: List[TrainTarget] = [] for _, row in df.iterrows(): m = mac_to_compact(row[mac_c]) if len(m) != 12: continue x = float(row[x_c]) if x_c else 0.0 y = float(row[y_c]) if y_c else 0.0 z = float(row[z_c]) if z_c else 0.0 out.append(TrainTarget(mac=m, x=x, y=y, z=z)) return out def _pick_collect_cfg(settings: Dict) -> Dict: if "collect_train" in settings and isinstance(settings["collect_train"], dict): return settings["collect_train"] if "training" in settings and isinstance(settings["training"], dict): log("WARNING: config usa 'training:' (alias). Consiglio: rinomina in 'collect_train:'") return settings["training"] return {} def run_collect_train(settings: Dict) -> None: ct = _pick_collect_cfg(settings) paths = settings.get("paths", {}) or {} mqtt_cfg = settings.get("mqtt", {}) or {} dbg = settings.get("debug", {}) or {} jobs_dir = str(ct.get("jobs_dir", "/data/train/jobs")) samples_dir = str(ct.get("samples_dir", "/data/train/samples")) job_glob = str(ct.get("job_glob", "*.csv")) poll_seconds = float(ct.get("poll_seconds", ct.get("poll_pending_seconds", 2))) window_seconds = float(ct.get("window_seconds", 10)) min_non_nan = int(ct.get("min_non_nan", 3)) aggregate = str(ct.get("aggregate", "median")).lower() rssi_min = float(ct.get("rssi_min", -110)) rssi_max = float(ct.get("rssi_max", -25)) outlier_method = str(ct.get("outlier_method", "none")).lower() mad_z = float(ct.get("mad_z", 3.5)) min_samples_per_gateway = int(ct.get("min_samples_per_gateway", 1)) max_stddev = ct.get("max_stddev", None) max_stddev = float(max_stddev) if max_stddev is not None else None gateway_ready_max_age_seconds = float(ct.get("gateway_ready_max_age_seconds", 30)) gw_ready_log_seconds = float(ct.get("gw_ready_log_seconds", 10)) gw_ready_sleep_seconds = float(ct.get("gw_ready_sleep_seconds", 5)) gw_ready_check_before_job = bool(ct.get("gw_ready_check_before_job", True)) csv_delim = str(paths.get("csv_delimiter", ";")) gateway_csv = str(paths.get("gateways_csv", "/data/config/gateway.csv")) # Debug opzionale durante finestra log_progress = bool(dbg.get("collect_train_log_samples", False)) log_first_seen = bool(dbg.get("collect_train_log_first_seen", False)) log_every_s = float(dbg.get("collect_train_log_every_seconds", 15)) pending_dir = os.path.join(jobs_dir, "pending") done_dir = os.path.join(jobs_dir, "done") error_dir = os.path.join(jobs_dir, "error") _ensure_dir(pending_dir) _ensure_dir(done_dir) _ensure_dir(error_dir) _ensure_dir(samples_dir) gateway_headers, gateway_keys = load_gateway_csv(gateway_csv, delimiter=csv_delim) if not gateway_keys: log("ERROR: Nessun gateway valido nel gateway.csv -> non posso partire.") return mqtt_host = str(mqtt_cfg.get("host", "mosquitto")) mqtt_port = int(mqtt_cfg.get("port", 1883)) mqtt_topic = str(mqtt_cfg.get("topic", "publish_out/#")) mqtt_proto = str(mqtt_cfg.get("protocol", "mqttv311")).lower() client_id = str(mqtt_cfg.get("client_id", "ble-ai-localizer")) keepalive = int(mqtt_cfg.get("keepalive", 60)) qos = int(mqtt_cfg.get("qos", 0)) username = str(mqtt_cfg.get("username", "")) password = str(mqtt_cfg.get("password", "")) last_seen: Dict[str, float] = {} active_window: Optional[FingerprintWindow] = None active_logged_pairs: set = set() def on_mqtt_message(topic: str, payload: bytes) -> None: nonlocal active_window, active_logged_pairs events = parse_publish_out(topic, payload) now = time.time() for gw_key, b_key, rssi, _ts in events: if len(gw_key) == 12: last_seen[gw_key] = now if active_window is None: continue accepted = active_window.add(gw_key, b_key, rssi) if accepted and log_first_seen: pair = (b_key, gw_key) if pair not in active_logged_pairs: active_logged_pairs.add(pair) log(f"SEEN target beacon={b_key} gw={compact_to_colon(gw_key)} rssi={rssi:.1f}") sub = MqttSubscriber( host=mqtt_host, port=mqtt_port, topic=mqtt_topic, mqtt_proto=mqtt_proto, client_id=client_id, keepalive=keepalive, qos=qos, username=username if username else None, password=password if password else None, ) import threading t = threading.Thread(target=sub.start_forever, args=(on_mqtt_message,), daemon=True) t.start() log("MQTT thread started (collect_train)") log( f"COLLECT_TRAIN config: gateway_csv={gateway_csv} gateways(feature-set)={len(gateway_keys)} " f"window_seconds={window_seconds:.1f} poll_seconds={poll_seconds:.1f} " f"jobs_dir={jobs_dir} pending_dir={pending_dir} done_dir={done_dir} error_dir={error_dir} " f"samples_dir={samples_dir} mqtt={mqtt_host}:{mqtt_port} topic={mqtt_topic}" ) def gateways_online() -> Tuple[int, List[str]]: now = time.time() missing: List[str] = [] for gk, hdr in zip(gateway_keys, gateway_headers): last = last_seen.get(gk) if last is None or (now - last) > gateway_ready_max_age_seconds: missing.append(hdr) return len(missing), missing def wait_for_gateways() -> None: last_log = 0.0 while True: miss_n, missing = gateways_online() if miss_n == 0: log(f"GW READY: online={len(gateway_keys)}/{len(gateway_keys)} (max_age_s={gateway_ready_max_age_seconds:.1f})") return now = time.time() if now - last_log >= gw_ready_log_seconds: last_log = now log( f"WAIT gateways online ({miss_n} missing, seen={len(gateway_keys)-miss_n}/{len(gateway_keys)}): {missing} " f"(max_age_s={gateway_ready_max_age_seconds:.1f})" ) time.sleep(gw_ready_sleep_seconds) while True: jobs = sorted(glob.glob(os.path.join(pending_dir, job_glob))) if not jobs: time.sleep(poll_seconds) continue for job_path in jobs: job_name = os.path.basename(job_path) try: if gw_ready_check_before_job: wait_for_gateways() targets = read_job_csv(job_path, delimiter=csv_delim) if not targets: raise RuntimeError("job CSV vuoto o senza MAC validi") beacon_keys = [t.mac for t in targets] log(f"TRAIN job START: {job_name} beacons={len(beacon_keys)}") active_logged_pairs = set() active_window = FingerprintWindow( beacon_keys=beacon_keys, gateway_headers=gateway_headers, gateway_keys=gateway_keys, rssi_min=rssi_min, rssi_max=rssi_max, outlier_method=outlier_method, mad_z=mad_z, min_samples_per_gateway=min_samples_per_gateway, max_stddev=max_stddev, ) t0 = time.time() next_log = t0 + log_every_s while True: elapsed = time.time() - t0 if elapsed >= window_seconds: break if log_progress and time.time() >= next_log: next_log = time.time() + log_every_s parts = [] for b in beacon_keys: tops = active_window.top_gateways(b, aggregate=aggregate, top_n=3) if not tops: parts.append(f"{b}:0gw") else: top_s = ",".join([f"{hdr}({n})" for n, hdr, _agg in tops]) parts.append(f"{b}:{top_s}") log(f"WINDOW progress {elapsed:.0f}/{window_seconds:.0f}s -> " + " | ".join(parts)) time.sleep(0.25) rows: List[Dict[str, object]] = [] for tt in targets: feats = active_window.features_for(tt.mac, aggregate=aggregate) non_nan = sum(0 if (isinstance(v, float) and math.isnan(v)) else 1 for v in feats.values()) if non_nan < min_non_nan: log(f"WARNING: beacon {tt.mac} low features non_nan={non_nan}") tops = active_window.top_gateways(tt.mac, aggregate=aggregate, top_n=5) if tops: top_s = ", ".join([f"{hdr} n={n} agg={agg:.1f}" for n, hdr, agg in tops]) log(f"SUMMARY beacon {tt.mac}: {top_s}") else: log(f"SUMMARY beacon {tt.mac}: no samples captured") row: Dict[str, object] = {"mac": tt.mac, "x": float(tt.x), "y": float(tt.y), "z": float(tt.z)} row.update(feats) rows.append(row) out_df = pd.DataFrame(rows) cols = ["mac", "x", "y", "z"] + gateway_headers out_df = out_df.reindex(columns=cols) epoch = int(time.time()) out_name = f"{os.path.splitext(job_name)[0]}__{epoch}.csv" out_path = os.path.join(samples_dir, out_name) out_df.to_csv(out_path, sep=csv_delim, index=False, float_format="%.1f", na_rep="nan") log(f"TRAIN job DONE: wrote {out_path} rows={len(out_df)}") shutil.move(job_path, os.path.join(done_dir, job_name)) except Exception as e: log(f"ERROR processing job {job_name}: {e}") try: shutil.move(job_path, os.path.join(error_dir, job_name)) except Exception: pass finally: active_window = None