- """train_collect.py
-
- Modalità COLLECT_TRAIN:
- - attende che tutti i gateway del feature-set (gateway.csv) siano online (traffic MQTT)
- - prende job CSV da jobs_dir/pending/*.csv
- - per ogni job: apre una finestra di raccolta di window_seconds, aggrega RSSI per GW
- - scrive sample CSV in samples_dir
-
- Bugfix fondamentale (per i tuoi NAN):
- - matching interno su MAC in formato **compact** (12 hex senza ':').
- """

from __future__ import annotations

import os
import time
import shutil
import glob
import math
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import pandas as pd

from .normalize import mac_to_compact, compact_to_colon
from .mqtt_client import MqttSubscriber
from .mqtt_parser import parse_publish_out
from .fingerprint import FingerprintWindow
from .logger_utils import log_msg as log


def _ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)


def _read_delimited_csv(path: str, prefer_delim: str = ";") -> pd.DataFrame:
    # A wrong separator leaves the whole line in a single column, so a parse
    # only "wins" when it yields more than one column; genuinely single-column
    # files fall through to the plain read below.
    for sep in [prefer_delim, ",", "\t"]:
        try:
            df = pd.read_csv(path, sep=sep, dtype=str, keep_default_na=False)
            if len(df.columns) > 1:
                return df
        except Exception:
            continue
    return pd.read_csv(path, dtype=str, keep_default_na=False)


def load_gateway_csv(path: str, delimiter: str = ";") -> Tuple[List[str], List[str]]:
    df = _read_delimited_csv(path, prefer_delim=delimiter)
    if df.empty:
        return [], []

    mac_col = None
    for c in df.columns:
        if c.strip().lower() == "mac":
            mac_col = c
            break
    if mac_col is None:
        mac_col = df.columns[0]

    headers: List[str] = []
    keys: List[str] = []
    seen = set()
    invalid = 0
    dup = 0

    for raw in df[mac_col].tolist():
        k = mac_to_compact(raw)
        if len(k) != 12:
            invalid += 1
            continue
        if k in seen:
            dup += 1
            continue
        seen.add(k)
        keys.append(k)
        headers.append(compact_to_colon(k))

    log(f"[gateway.csv] loaded gateways={len(keys)} invalid={invalid} duplicates={dup}")
    return headers, keys
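
# Example gateway.csv accepted above (illustrative addresses; header matching is
# case-insensitive, and without a 'mac' header the first column is used):
#
#   mac
#   AA:BB:CC:DD:EE:01
#   AA:BB:CC:DD:EE:02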


@dataclass
class TrainTarget:
    mac: str  # compact form (12 hex chars, no ':')
    x: float
    y: float
    z: float


def read_job_csv(job_path: str, delimiter: str = ";") -> List[TrainTarget]:
    df = _read_delimited_csv(job_path, prefer_delim=delimiter)
    if df.empty:
        return []

    cols = {c.strip().lower(): c for c in df.columns}

    def col(name: str) -> Optional[str]:
        return cols.get(name)

    def num(row: pd.Series, c: Optional[str]) -> float:
        # Empty or malformed coordinate cells fall back to 0.0 instead of
        # crashing the whole job.
        if not c:
            return 0.0
        try:
            return float(row[c])
        except (TypeError, ValueError):
            return 0.0

    mac_c = col("mac")
    if not mac_c:
        raise ValueError(f"Job CSV without a 'mac' column: {job_path}")
    x_c, y_c, z_c = col("x"), col("y"), col("z")

    out: List[TrainTarget] = []
    for _, row in df.iterrows():
        m = mac_to_compact(row[mac_c])
        if len(m) != 12:
            continue
        out.append(TrainTarget(mac=m, x=num(row, x_c), y=num(row, y_c), z=num(row, z_c)))
    return out
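
# Example job CSV (semicolon-delimited, matching the default csv_delimiter;
# coordinates are illustrative):
#
#   mac;x;y;z
#   AA:BB:CC:DD:EE:10;1.50;2.25;0.90
#   AA:BB:CC:DD:EE:11;4.00;0.75;0.90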


def _pick_collect_cfg(settings: Dict) -> Dict:
    if "collect_train" in settings and isinstance(settings["collect_train"], dict):
        return settings["collect_train"]
    if "training" in settings and isinstance(settings["training"], dict):
        log("WARNING: config uses 'training:' (alias). Recommended: rename it to 'collect_train:'")
        return settings["training"]
    return {}
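
# Minimal settings dict consumed by run_collect_train (a sketch: every key is
# optional, and the values shown are the code's own defaults):
#
#   settings = {
#       "collect_train": {"jobs_dir": "/data/train/jobs",
#                         "samples_dir": "/data/train/samples",
#                         "window_seconds": 10, "aggregate": "median"},
#       "paths": {"csv_delimiter": ";", "gateways_csv": "/data/config/gateway.csv"},
#       "mqtt": {"host": "mosquitto", "port": 1883, "topic": "publish_out/#"},
#       "debug": {"collect_train_log_samples": False},
#   }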


def run_collect_train(settings: Dict) -> None:
    ct = _pick_collect_cfg(settings)
    paths = settings.get("paths", {}) or {}
    mqtt_cfg = settings.get("mqtt", {}) or {}
    dbg = settings.get("debug", {}) or {}

    jobs_dir = str(ct.get("jobs_dir", "/data/train/jobs"))
    samples_dir = str(ct.get("samples_dir", "/data/train/samples"))
    job_glob = str(ct.get("job_glob", "*.csv"))
    poll_seconds = float(ct.get("poll_seconds", ct.get("poll_pending_seconds", 2)))
    window_seconds = float(ct.get("window_seconds", 10))
    min_non_nan = int(ct.get("min_non_nan", 3))

    aggregate = str(ct.get("aggregate", "median")).lower()
    rssi_min = float(ct.get("rssi_min", -110))
    rssi_max = float(ct.get("rssi_max", -25))
    outlier_method = str(ct.get("outlier_method", "none")).lower()
    mad_z = float(ct.get("mad_z", 3.5))
    min_samples_per_gateway = int(ct.get("min_samples_per_gateway", 1))
    max_stddev = ct.get("max_stddev", None)
    max_stddev = float(max_stddev) if max_stddev is not None else None

    gateway_ready_max_age_seconds = float(ct.get("gateway_ready_max_age_seconds", 30))
    gw_ready_log_seconds = float(ct.get("gw_ready_log_seconds", 10))
    gw_ready_sleep_seconds = float(ct.get("gw_ready_sleep_seconds", 5))
    gw_ready_check_before_job = bool(ct.get("gw_ready_check_before_job", True))

    csv_delim = str(paths.get("csv_delimiter", ";"))
    gateway_csv = str(paths.get("gateways_csv", "/data/config/gateway.csv"))

    # Optional debug logging during the collection window
    log_progress = bool(dbg.get("collect_train_log_samples", False))
    log_first_seen = bool(dbg.get("collect_train_log_first_seen", False))
    log_every_s = float(dbg.get("collect_train_log_every_seconds", 15))

    pending_dir = os.path.join(jobs_dir, "pending")
    done_dir = os.path.join(jobs_dir, "done")
    error_dir = os.path.join(jobs_dir, "error")

    _ensure_dir(pending_dir)
    _ensure_dir(done_dir)
    _ensure_dir(error_dir)
    _ensure_dir(samples_dir)

    gateway_headers, gateway_keys = load_gateway_csv(gateway_csv, delimiter=csv_delim)
    if not gateway_keys:
        log("ERROR: no valid gateway in gateway.csv -> cannot start.")
        return

    mqtt_host = str(mqtt_cfg.get("host", "mosquitto"))
    mqtt_port = int(mqtt_cfg.get("port", 1883))
    mqtt_topic = str(mqtt_cfg.get("topic", "publish_out/#"))
    mqtt_proto = str(mqtt_cfg.get("protocol", "mqttv311")).lower()
    client_id = str(mqtt_cfg.get("client_id", "ble-ai-localizer"))
    keepalive = int(mqtt_cfg.get("keepalive", 60))
    qos = int(mqtt_cfg.get("qos", 0))
    username = str(mqtt_cfg.get("username", ""))
    password = str(mqtt_cfg.get("password", ""))

    last_seen: Dict[str, float] = {}
    active_window: Optional[FingerprintWindow] = None
    active_logged_pairs: set = set()
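
    # Shared state: the MQTT callback thread (started below) updates last_seen
    # and feeds the current active_window; the main loop swaps windows in and
    # out per job. CPython's GIL keeps the individual dict/set operations
    # atomic, so no extra locking is used here.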
    def on_mqtt_message(topic: str, payload: bytes) -> None:
        events = parse_publish_out(topic, payload)
        now = time.time()
        # Snapshot the window once: the main thread may clear active_window
        # while this callback is still running.
        win = active_window
        for gw_key, b_key, rssi, _ts in events:
            if len(gw_key) == 12:
                last_seen[gw_key] = now

            if win is None:
                continue

            accepted = win.add(gw_key, b_key, rssi)
            if accepted and log_first_seen:
                pair = (b_key, gw_key)
                if pair not in active_logged_pairs:
                    active_logged_pairs.add(pair)
                    log(f"SEEN target beacon={b_key} gw={compact_to_colon(gw_key)} rssi={rssi:.1f}")

    sub = MqttSubscriber(
        host=mqtt_host,
        port=mqtt_port,
        topic=mqtt_topic,
        mqtt_proto=mqtt_proto,
        client_id=client_id,
        keepalive=keepalive,
        qos=qos,
        username=username if username else None,
        password=password if password else None,
    )

    t = threading.Thread(target=sub.start_forever, args=(on_mqtt_message,), daemon=True)
    t.start()

    log("MQTT thread started (collect_train)")
    log(
        f"COLLECT_TRAIN config: gateway_csv={gateway_csv} gateways(feature-set)={len(gateway_keys)} "
        f"window_seconds={window_seconds:.1f} poll_seconds={poll_seconds:.1f} "
        f"jobs_dir={jobs_dir} pending_dir={pending_dir} done_dir={done_dir} error_dir={error_dir} "
        f"samples_dir={samples_dir} mqtt={mqtt_host}:{mqtt_port} topic={mqtt_topic}"
    )

    def gateways_online() -> Tuple[int, List[str]]:
        now = time.time()
        missing: List[str] = []
        for gk, hdr in zip(gateway_keys, gateway_headers):
            last = last_seen.get(gk)
            if last is None or (now - last) > gateway_ready_max_age_seconds:
                missing.append(hdr)
        return len(missing), missing

    def wait_for_gateways() -> None:
        last_log = 0.0
        while True:
            miss_n, missing = gateways_online()
            if miss_n == 0:
                log(f"GW READY: online={len(gateway_keys)}/{len(gateway_keys)} (max_age_s={gateway_ready_max_age_seconds:.1f})")
                return
            now = time.time()
            if now - last_log >= gw_ready_log_seconds:
                last_log = now
                log(
                    f"WAIT gateways online ({miss_n} missing, seen={len(gateway_keys)-miss_n}/{len(gateway_keys)}): {missing} "
                    f"(max_age_s={gateway_ready_max_age_seconds:.1f})"
                )
            time.sleep(gw_ready_sleep_seconds)

    while True:
        jobs = sorted(glob.glob(os.path.join(pending_dir, job_glob)))
        if not jobs:
            time.sleep(poll_seconds)
            continue

        for job_path in jobs:
            job_name = os.path.basename(job_path)
            try:
                if gw_ready_check_before_job:
                    wait_for_gateways()

                targets = read_job_csv(job_path, delimiter=csv_delim)
                if not targets:
                    raise RuntimeError("job CSV is empty or has no valid MACs")

                beacon_keys = [t.mac for t in targets]
                log(f"TRAIN job START: {job_name} beacons={len(beacon_keys)}")

                active_logged_pairs = set()
                active_window = FingerprintWindow(
                    beacon_keys=beacon_keys,
                    gateway_headers=gateway_headers,
                    gateway_keys=gateway_keys,
                    rssi_min=rssi_min,
                    rssi_max=rssi_max,
                    outlier_method=outlier_method,
                    mad_z=mad_z,
                    min_samples_per_gateway=min_samples_per_gateway,
                    max_stddev=max_stddev,
                )

                t0 = time.time()
                next_log = t0 + log_every_s

                while True:
                    elapsed = time.time() - t0
                    if elapsed >= window_seconds:
                        break
                    if log_progress and time.time() >= next_log:
                        next_log = time.time() + log_every_s
                        parts = []
                        for b in beacon_keys:
                            tops = active_window.top_gateways(b, aggregate=aggregate, top_n=3)
                            if not tops:
                                parts.append(f"{b}:0gw")
                            else:
                                top_s = ",".join([f"{hdr}({n})" for n, hdr, _agg in tops])
                                parts.append(f"{b}:{top_s}")
                        log(f"WINDOW progress {elapsed:.0f}/{window_seconds:.0f}s -> " + " | ".join(parts))
                    time.sleep(0.25)

                # Close the window before aggregating so the MQTT thread stops
                # feeding samples into it mid-summary.
                win, active_window = active_window, None

                rows: List[Dict[str, object]] = []
                for tt in targets:
                    feats = win.features_for(tt.mac, aggregate=aggregate)
                    non_nan = sum(0 if (isinstance(v, float) and math.isnan(v)) else 1 for v in feats.values())
                    if non_nan < min_non_nan:
                        log(f"WARNING: beacon {tt.mac} low features non_nan={non_nan}")

                    tops = win.top_gateways(tt.mac, aggregate=aggregate, top_n=5)
                    if tops:
                        top_s = ", ".join([f"{hdr} n={n} agg={agg:.1f}" for n, hdr, agg in tops])
                        log(f"SUMMARY beacon {tt.mac}: {top_s}")
                    else:
                        log(f"SUMMARY beacon {tt.mac}: no samples captured")

                    row: Dict[str, object] = {"mac": tt.mac, "x": float(tt.x), "y": float(tt.y), "z": float(tt.z)}
                    row.update(feats)
                    rows.append(row)

                out_df = pd.DataFrame(rows)
                cols = ["mac", "x", "y", "z"] + gateway_headers
                out_df = out_df.reindex(columns=cols)

                epoch = int(time.time())
                out_name = f"{os.path.splitext(job_name)[0]}__{epoch}.csv"
                out_path = os.path.join(samples_dir, out_name)
                out_df.to_csv(out_path, sep=csv_delim, index=False, float_format="%.1f", na_rep="nan")

                log(f"TRAIN job DONE: wrote {out_path} rows={len(out_df)}")

                shutil.move(job_path, os.path.join(done_dir, job_name))

            except Exception as e:
                log(f"ERROR processing job {job_name}: {e}")
                try:
                    shutil.move(job_path, os.path.join(error_dir, job_name))
                except Exception:
                    pass
            finally:
                active_window = None
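

# Typical invocation (a sketch; the real runner and config loader live elsewhere
# in this package, and "settings.yaml" below is a hypothetical path):
#
#   import yaml
#   with open("/data/config/settings.yaml") as fh:
#       run_collect_train(yaml.safe_load(fh))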