|
- from .logger_utils import setup_global_logging, log_msg as log
- import csv
- import io
- import json
- import os
- import ssl
- import time
- import traceback
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple
- import hashlib
- import re
- import math
- import statistics
- import pandas as pd
- import numpy as np
- import requests
- import joblib
- import paho.mqtt.client as mqtt
-
# Correct local imports
- from .settings import load_settings
-
def build_info() -> str:
    """Return a static build tag string.

    NOTE(review): this definition is shadowed by a second `build_info`
    (the BUILD_ID/sha256 variant) defined later in this module, so it is
    never the one bound at runtime — confirm which is intended.
    """
    return "infer-debug-v19-fixed"
-
def main() -> None:
    """Load settings, configure logging, then log startup information.

    NOTE(review): this definition is shadowed by a second `def main()` near
    the end of this module; that later one (mode dispatch) is the entry point
    actually bound at import time — confirm which is intended.
    """
    # 1. Load settings
    settings = load_settings()

    # 2. Set up logging and silencers immediately (BEFORE anything else)
    setup_global_logging(settings)

    # 3. From here on logging is synchronized and clean
    cfg_file = settings.get("_config_file", "/config/config.yaml")
    keys = [k for k in settings.keys() if not str(k).startswith("_")]

    log(f"Settings loaded from {cfg_file}. Keys: {keys}")
    log(f"BUILD: {build_info()}")
-
def mac_plain(s: str) -> str:
    """Normalize a MAC address to uppercase hex digits with separators removed."""
    hex_only = re.sub(r"[^0-9A-Fa-f]", "", s or "")
    return hex_only.upper()
-
def mac_colon(s: str) -> str:
    """Render a MAC as AA:BB:CC:DD:EE:FF.

    Inputs that do not normalize to exactly 12 hex digits are returned as the
    stripped hex digits, unchanged (no colons inserted).
    """
    digits = re.sub(r"[^0-9A-Fa-f]", "", s or "").upper()
    if len(digits) == 12:
        return ":".join(digits[i:i + 2] for i in range(0, 12, 2))
    return digits
-
def fmt_rssi(v, decimals: int) -> str:
    """Format an RSSI value as text.

    Returns 'nan' for None, unparseable, or NaN values. With decimals <= 0
    the value is rendered as a plain integer (avoids '-82.0'); otherwise it
    is rounded and rendered with exactly `decimals` fractional digits.
    """
    try:
        value = float(v)
    except (TypeError, ValueError):
        return "nan"
    if math.isnan(value):
        return "nan"
    if decimals <= 0:
        return str(int(round(value)))
    return f"{round(value, decimals):.{decimals}f}"
-
-
- # -----------------------------
-
- # Build info (printed at startup for traceability)
# Build info (printed at startup for traceability)
BUILD_ID = "ble-ai-localizer main.py 2026-01-30 build-floatagg-v1"

def build_info() -> str:
    """Return a short build identifier for logs (no external deps, no git required)."""
    try:
        source = Path(__file__)
        digest = hashlib.sha256(source.read_bytes()).hexdigest()[:12]
        return f"{BUILD_ID} sha256={digest} size={source.stat().st_size}"
    except Exception:
        # Unable to stat/read our own file: still return a usable tag.
        return f"{BUILD_ID} sha256=? size=?"
-
- # Settings
- # -----------------------------
def load_settings() -> Dict[str, Any]:
    """Load YAML settings from $CONFIG (default /config/config.yaml).

    Stores the config file path under the private key "_config_file" and
    aliases a legacy 'training:' section into 'collect_train:'.

    NOTE(review): this definition shadows the `load_settings` imported from
    `.settings` at the top of the module — confirm which one is intended.
    """
    cfg = os.environ.get("CONFIG", "/config/config.yaml")
    import yaml  # local import: PyYAML only needed when settings are read
    with open(cfg, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    data["_config_file"] = cfg

    # Normalize config sections: prefer collect_train
    if "collect_train" not in data and "training" in data:
        log("WARNING: config usa 'training:' (alias). Consiglio: rinomina in 'collect_train:'")
        data["collect_train"] = data.get("training", {}) or {}
    return data
-
-
- # -----------------------------
- # MAC helpers
- # -----------------------------
def norm_mac(mac: str) -> str:
    """Return MAC as AA:BB:CC:DD:EE:FF (upper), ignoring separators.

    When the input does not contain exactly 12 hex-ish characters, the
    stripped, uppercased original is returned unchanged.
    """
    compact = (mac or "").strip()
    for sep in ("-", ":", "."):
        compact = compact.replace(sep, "")
    compact = compact.upper()
    if len(compact) != 12:
        return mac.strip().upper()
    pairs = (compact[i:i + 2] for i in range(0, 12, 2))
    return ":".join(pairs)
-
-
- # -----------------------------
- # CSV write helpers
- # -----------------------------
def safe_write_csv(
    path: Path,
    header: List[str],
    rows: List[Dict[str, Any]],
    delimiter: str = ";",
    rssi_decimals: int = 0,
):
    """Atomically write a CSV with human-friendly number formatting.

    Formatting rules:
    - integer-valued numbers: no decimals (e.g. -82 instead of -82.0)
    - RSSI columns (every column after mac/x/y/z): rounding controlled by
      rssi_decimals (0 -> integer, >0 -> N decimal places)
    - NaN values: written as 'nan'
    - 'mac' column: normalized to colon form (e.g. C3:00:00:57:B9:E7) when
      the value is a valid MAC string

    Atomicity: writes to '<path>.tmp' then renames over the target.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")
    # csv.writer requires a single-character delimiter; fall back to ';'
    if not isinstance(delimiter, str) or len(delimiter) != 1:
        delimiter = ";"

    try:
        rssi_decimals = int(rssi_decimals)
    except Exception:
        rssi_decimals = 0
    if rssi_decimals < 0:
        rssi_decimals = 0

    def fmt_cell(v: Any, col: str, idx: int) -> str:
        # Format one cell; columns with idx >= 4 are RSSI (after mac/x/y/z).
        if v is None:
            return "nan"

        # MAC normalized with ':'
        if col.lower() == "mac" and isinstance(v, str):
            return mac_colon(v)

        # plain float
        if isinstance(v, float):
            if math.isnan(v):
                return "nan"
            if idx >= 4:
                if rssi_decimals == 0:
                    return str(int(round(v)))
                return f"{v:.{rssi_decimals}f}"
            # other columns: compact (near-)integers
            if abs(v - round(v)) < 1e-9:
                return str(int(round(v)))
            return str(v)

        # int / numpy int
        if isinstance(v, (int, np.integer)):
            # RSSI columns: respect rssi_decimals even for integer values
            if idx >= 4:
                if rssi_decimals == 0:
                    return str(int(v))
                return f"{float(v):.{rssi_decimals}f}"
            return str(int(v))

        # numpy float
        if isinstance(v, np.floating):
            fv = float(v)
            if math.isnan(fv):
                return "nan"
            if idx >= 4:
                if rssi_decimals == 0:
                    return str(int(round(fv)))
                return f"{fv:.{rssi_decimals}f}"
            if abs(fv - round(fv)) < 1e-9:
                return str(int(round(fv)))
            return str(fv)

        return str(v)

    # BUGFIX: pin the output encoding instead of relying on the platform
    # default (which can vary by locale and silently corrupt non-ASCII).
    with tmp.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f, delimiter=delimiter)
        w.writerow(header)
        for row in rows:
            w.writerow([fmt_cell(row.get(col), col, idx) for idx, col in enumerate(header)])

    tmp.replace(path)
-
-
-
-
- def _coord_token(v: float) -> str:
- # Stable token for filenames from coordinates.
- # - if integer-ish -> '123'
- # - else keep up to 3 decimals, strip trailing zeros, replace '.' with '_'
- try:
- fv=float(v)
- except Exception:
- return str(v)
- if abs(fv - round(fv)) < 1e-9:
- return str(int(round(fv)))
- s=f"{fv:.3f}".rstrip('0').rstrip('.')
- return s.replace('.', '_')
def read_job_csv(job_path: Path, delimiter: str) -> List[Dict[str, Any]]:
    """Read a job CSV, supporting two formats:

    1) Legacy:
       mac;x;y;z
       C3000057B9F4;1200;450;0

    2) Extended (historical):
       Position;Floor;RoomName;X;Y;Z;BeaconName;MAC
       A21;1;P1-NETW;800;1050;1;BC-21;C3:00:00:57:B9:E6

    Extracts only X, Y, Z, MAC and normalizes the MAC to compact form
    (no ':', uppercase). Rows with a missing MAC or non-numeric
    coordinates are skipped silently.
    """
    text = job_path.read_text(encoding="utf-8", errors="replace")
    if not text.strip():
        return []

    # Sniff the delimiter from the first non-empty line when the configured
    # one does not appear there (handles ';' vs ',' job files).
    first_line = next((ln for ln in text.splitlines() if ln.strip()), "")
    use_delim = delimiter
    if use_delim not in first_line:
        if ";" in first_line and "," not in first_line:
            use_delim = ";"
        elif "," in first_line and ";" not in first_line:
            use_delim = ","

    def hnorm(h: str) -> str:
        # Normalize a header cell: lowercase, then strip non-alphanumerics.
        h = (h or "").strip().lower()
        h = re_sub_non_alnum(h)
        return h

    f = io.StringIO(text)
    r = csv.reader(f, delimiter=use_delim)
    header = next(r, None)
    if not header:
        return []

    header_norm = [hnorm(h) for h in header]
    idx = {name: i for i, name in enumerate(header_norm) if name}

    def find_idx(names: List[str]) -> Optional[int]:
        # Return the column index of the first matching header alias.
        for n in names:
            if n in idx:
                return idx[n]
        return None

    # NOTE(review): hnorm strips underscores, so the 'beacon_mac' and
    # 'tracker_mac' aliases below can never match a normalized header —
    # the underscore-free forms already cover them; verify intent.
    mac_i = find_idx(["mac", "beaconmac", "beacon_mac", "trackermac", "tracker_mac", "device", "devicemac"])
    x_i = find_idx(["x"])
    y_i = find_idx(["y"])
    z_i = find_idx(["z"])

    if mac_i is None or x_i is None or y_i is None or z_i is None:
        raise ValueError(
            f"Job CSV header non riconosciuto: {header}. "
            f"Attesi campi MAC/X/Y/Z (case-insensitive)."
        )

    rows: List[Dict[str, Any]] = []
    for cols in r:
        if not cols:
            continue
        # Skip short rows that cannot contain all the required columns.
        if len(cols) <= max(mac_i, x_i, y_i, z_i):
            continue
        mac_raw = (cols[mac_i] or "").strip()
        if not mac_raw:
            continue

        # Compact MAC form (no ':'), as used by the sample files.
        mac_compact = norm_mac(mac_raw).replace(":", "")

        try:
            x = float((cols[x_i] or "").strip())
            y = float((cols[y_i] or "").strip())
            z = float((cols[z_i] or "").strip())
        except Exception:
            continue

        rows.append({"mac": mac_compact, "x": x, "y": y, "z": z})

    return rows
-
-
def re_sub_non_alnum(s: str) -> str:
    """Return *s* with every non-ASCII-alphanumeric character removed.

    BUGFIX: the original kept only lowercase letters and digits, silently
    dropping uppercase letters despite the function's name. Callers in this
    module lowercase their input first, so accepting A-Z as well is
    backward-compatible for them while making the helper behave as named.
    """
    out = []
    for ch in s:
        if ("a" <= ch <= "z") or ("A" <= ch <= "Z") or ("0" <= ch <= "9"):
            out.append(ch)
    return "".join(out)
-
-
def write_samples_csv(
    out_path: Path,
    sample_rows: List[Dict[str, Any]],
    gateway_macs: List[str],
    *,
    delimiter: str = ";",
    rssi_decimals: int = 0,
) -> None:
    """Write one samples CSV: fixed mac/x/y/z columns, then one RSSI column per gateway."""
    columns = ["mac", "x", "y", "z"]
    columns.extend(gateway_macs)
    safe_write_csv(out_path, columns, sample_rows, delimiter=delimiter, rssi_decimals=rssi_decimals)
-
def load_gateway_csv(path: Path, delimiter: str = ";") -> Tuple[List[str], int, int]:
    """Load gateway MACs from CSV.

    Returns (unique normalized MACs in file order, invalid count, duplicate count).
    Requires a 'mac' column (header matched case-insensitively).
    """
    df = pd.read_csv(path, delimiter=delimiter)
    df.columns = [c.strip().lower() for c in df.columns]

    if "mac" not in df.columns:
        raise ValueError(f"gateway.csv must have a 'mac' column, got columns={list(df.columns)}")

    macs: List[str] = []
    seen = set()
    invalid = 0
    for raw in df["mac"].astype(str):
        candidate = norm_mac(raw)
        # A valid MAC normalizes to exactly 12 hex digits.
        if len(candidate.replace(":", "")) != 12:
            invalid += 1
        elif candidate not in seen:
            seen.add(candidate)
            macs.append(candidate)

    duplicates = max(0, len(df) - invalid - len(macs))
    return macs, invalid, duplicates
-
-
- # -----------------------------
- # Fingerprint collector
- # -----------------------------
@dataclass
class FingerprintStats:
    """Snapshot of collected RSSI data for a set of beacons and gateways."""
    # counts[beacon][gateway] -> number of RSSI samples collected so far
    counts: Dict[str, Dict[str, int]]
    # last[beacon][gateway] -> most recent RSSI value (NaN when none seen)
    last: Dict[str, Dict[str, float]]
-
-
class FingerprintCollector:
    """Accumulates per-(beacon, gateway) RSSI readings and derives feature rows.

    Mutations and reads happen under an internal lock so the MQTT callback
    thread and the main loop can share one instance; if threading is
    unavailable a no-op context manager is used instead.
    """

    def __init__(self) -> None:
        self._lock = None
        try:
            import threading
            self._lock = threading.Lock()
        except Exception:
            # Extremely defensive: degrade to unsynchronized operation.
            self._lock = None

        # beacon_norm -> gw_norm -> list of rssi readings (append-only)
        self.rssi: Dict[str, Dict[str, List[float]]] = {}
        # normalized gateway MAC -> unix timestamp of last message
        self.last_seen_gw: Dict[str, float] = {}
        # normalized beacon MAC -> unix timestamp of last reading
        self.last_seen_beacon: Dict[str, float] = {}

    def _with_lock(self):
        """Return the real lock, or a no-op context manager when unavailable."""
        if self._lock is None:
            class Dummy:
                def __enter__(self): return None
                def __exit__(self, *a): return False
            return Dummy()
        return self._lock

    def update(self, gw_mac: str, beacon_mac: str, rssi: float) -> None:
        """Record one RSSI reading and refresh both last-seen timestamps."""
        gw = norm_mac(gw_mac)
        b = norm_mac(beacon_mac)
        now = time.time()
        with self._with_lock():
            self.last_seen_gw[gw] = now
            self.last_seen_beacon[b] = now
            self.rssi.setdefault(b, {}).setdefault(gw, []).append(float(rssi))

    def stats(self, beacons: List[str], gateways: List[str]) -> FingerprintStats:
        """Snapshot per-(beacon, gateway) sample counts and last values.

        NOTE: assumes the given beacon/gateway names already normalize to
        themselves (callers pass pre-normalized MACs).
        """
        with self._with_lock():
            counts: Dict[str, Dict[str, int]] = {b: {g: 0 for g in gateways} for b in beacons}
            last: Dict[str, Dict[str, float]] = {b: {g: float("nan") for g in gateways} for b in beacons}
            for b in beacons:
                bm = norm_mac(b)
                for g in gateways:
                    gm = norm_mac(g)
                    vals = self.rssi.get(bm, {}).get(gm, [])
                    counts[bm][gm] = len(vals)
                    if vals:
                        last[bm][gm] = vals[-1]
            return FingerprintStats(counts=counts, last=last)

    def feature_row(
        self,
        beacon_mac: str,
        gateways: List[str],
        aggregate: str,
        rssi_min: float,
        rssi_max: float,
        min_samples_per_gateway: int,
        outlier_method: str,
        mad_z: float,
        iqr_k: float,
        max_stddev: Optional[float],
    ) -> Dict[str, float]:
        """Aggregate one beacon's readings into a per-gateway feature dict.

        A gateway maps to NaN when it lacks enough clean samples after
        clamping, outlier removal, and the optional stddev stability gate.
        """
        b = norm_mac(beacon_mac)
        out: Dict[str, float] = {}
        with self._with_lock():
            for g in gateways:
                gm = norm_mac(g)
                vals = list(self.rssi.get(b, {}).get(gm, []))

                # hard clamp to the plausible RSSI range
                vals = [v for v in vals if (rssi_min <= v <= rssi_max)]
                if len(vals) < min_samples_per_gateway:
                    out[gm] = float("nan")
                    continue

                # outlier removal
                vals2 = vals
                if outlier_method == "mad":
                    vals2 = mad_filter(vals2, z=mad_z)
                elif outlier_method == "iqr":
                    vals2 = iqr_filter(vals2, k=iqr_k)

                if len(vals2) < min_samples_per_gateway:
                    out[gm] = float("nan")
                    continue

                # optional stability gate on the cleaned samples
                # (BUGFIX: dropped a redundant local `import statistics` that
                # shadowed the module-level import)
                if max_stddev is not None:
                    try:
                        sd = statistics.pstdev(vals2)
                        if sd > max_stddev:
                            out[gm] = float("nan")
                            continue
                    except Exception:
                        pass

                # Aggregate: keep float (no int cast) so rssi_decimals applies later.
                # (BUGFIX: removed redundant sorted() — median_low/median_high
                # already handle unsorted data.)
                if aggregate == "median":
                    out[gm] = float(statistics.median(vals2))
                elif aggregate == "median_low":
                    out[gm] = float(statistics.median_low(vals2))
                elif aggregate == "median_high":
                    out[gm] = float(statistics.median_high(vals2))
                elif aggregate == "mean":
                    out[gm] = float(statistics.fmean(vals2))
                else:
                    out[gm] = float(statistics.median(vals2))
        return out
-
-
def mad_filter(vals: List[float], z: float = 3.5) -> List[float]:
    """Drop values whose modified z-score (median absolute deviation) exceeds *z*.

    When the MAD is zero the input is returned unchanged (no basis for a score).
    Kept values are returned as floats, preserving input order.
    """
    if not vals:
        return vals
    center = statistics.median(vals)
    deviations = [abs(v - center) for v in vals]
    mad = statistics.median(deviations)
    if mad == 0:
        return vals
    return [float(v) for v, dev in zip(vals, deviations) if 0.6745 * dev / mad <= z]
-
-
def iqr_filter(vals: List[float], k: float = 1.5) -> List[float]:
    """Keep values within [Q1 - k*IQR, Q3 + k*IQR] (Tukey fence).

    Quartiles use linear interpolation (same method as pandas' default).
    When the IQR is zero the input is returned unchanged.
    """
    if not vals:
        return vals
    q1, q3 = np.percentile(vals, [25.0, 75.0])
    spread = q3 - q1
    if spread == 0:
        return vals
    lo = q1 - k * spread
    hi = q3 + k * spread
    return [float(v) for v in vals if lo <= v <= hi]
-
-
- # -----------------------------
- # MQTT parsing
- # -----------------------------
def parse_topic_gateway(topic: str) -> Optional[str]:
    """Extract the gateway id from a 'publish_out/<gwmac>' topic.

    Returns the last path segment, or None when the topic has no '/'.
    """
    segments = (topic or "").split("/")
    return segments[-1] if len(segments) >= 2 else None
-
-
- def parse_payload_list(payload: bytes) -> Optional[List[Dict[str, Any]]]:
- try:
- obj = json.loads(payload.decode("utf-8", errors="replace"))
- if isinstance(obj, list):
- return obj
- return None
- except Exception:
- return None
-
-
def is_gateway_announce(item: Dict[str, Any]) -> bool:
    """True when *item* is a gateway self-announcement: type == 'gateway' plus a MAC."""
    kind = str(item.get("type", "")).strip().lower()
    return kind == "gateway" and "mac" in item
-
-
- # -----------------------------
- # Collect train
- # -----------------------------
def _gateways_online(fp: "FingerprintCollector", gateway_macs: List[str], max_age_s: float) -> Tuple[int, List[str]]:
    """Count gateways seen within *max_age_s* seconds; return (online, missing_normalized)."""
    now = time.time()
    online = 0
    missing: List[str] = []
    for g in gateway_macs:
        seen = fp.last_seen_gw.get(norm_mac(g))
        if seen is not None and (now - seen) <= max_age_s:
            online += 1
        else:
            missing.append(norm_mac(g))
    return online, missing


def run_collect_train(settings: Dict[str, Any]) -> None:
    """Collect fingerprint training samples from MQTT-driven jobs.

    Subscribes to the gateway MQTT feed, waits for all configured gateways
    to come online, then processes job CSVs from the pending directory:
    for each job it collects RSSI readings for a time window, aggregates
    them into per-gateway features, and writes one samples CSV per beacon
    position. Jobs are moved to done/ on success and error/ on failure.
    Runs forever.
    """
    cfg = settings.get("collect_train", {}) or {}
    paths = settings.get("paths", {}) or {}
    mqtt_cfg = settings.get("mqtt", {}) or {}
    debug = settings.get("debug", {}) or {}

    window_seconds = float(cfg.get("window_seconds", 180))
    poll_seconds = float(cfg.get("poll_seconds", 2))
    min_non_nan = int(cfg.get("min_non_nan", 3))
    min_samples_per_gateway = int(cfg.get("min_samples_per_gateway", 5))
    aggregate = str(cfg.get("aggregate", "median"))
    # Number of decimal places for RSSI values in sample files (0 = integer)
    try:
        rssi_decimals = int(cfg.get("rssi_decimals", 0))
    except Exception:
        rssi_decimals = 0
    if rssi_decimals < 0:
        rssi_decimals = 0
    rssi_min = float(cfg.get("rssi_min", -110))
    rssi_max = float(cfg.get("rssi_max", -25))
    outlier_method = str(cfg.get("outlier_method", "mad"))
    mad_z = float(cfg.get("mad_z", 3.5))
    iqr_k = float(cfg.get("iqr_k", 1.5))
    max_stddev = cfg.get("max_stddev", None)
    max_stddev = float(max_stddev) if max_stddev is not None else None

    gateway_csv = Path(paths.get("gateways_csv", "/data/config/gateway.csv"))
    csv_delimiter = str(paths.get("csv_delimiter", ";"))

    jobs_dir = Path(cfg.get("jobs_dir", "/data/train/jobs"))
    pending_dir = jobs_dir / "pending"
    done_dir = jobs_dir / "done"
    error_dir = jobs_dir / "error"
    samples_dir = Path(cfg.get("samples_dir", "/data/train/samples"))

    pending_dir.mkdir(parents=True, exist_ok=True)
    done_dir.mkdir(parents=True, exist_ok=True)
    error_dir.mkdir(parents=True, exist_ok=True)
    samples_dir.mkdir(parents=True, exist_ok=True)

    gw_ready_log_seconds = float(cfg.get("gw_ready_log_seconds", 10))
    gw_ready_sleep_seconds = float(cfg.get("gw_ready_sleep_seconds", 5))
    gw_ready_check_before_job = bool(cfg.get("gw_ready_check_before_job", True))
    online_max_age_s = float(debug.get("online_check_seconds", 30))
    progress_log_seconds = float(cfg.get("wait_all_gateways_log_seconds", 30))

    gateway_macs, invalid, duplicates = load_gateway_csv(gateway_csv, delimiter=csv_delimiter)
    log(f"[gateway.csv] loaded gateways={len(gateway_macs)} invalid={invalid} duplicates={duplicates}")

    log(
        "COLLECT_TRAIN config: gateway_csv=%s gateways(feature-set)=%d window_seconds=%.1f poll_seconds=%.1f rssi_decimals=%d jobs_dir=%s "
        "pending_dir=%s done_dir=%s error_dir=%s samples_dir=%s mqtt=%s:%s topic=%s"
        % (
            gateway_csv,
            len(gateway_macs),
            window_seconds,
            poll_seconds,
            rssi_decimals,
            jobs_dir,
            pending_dir,
            done_dir,
            error_dir,
            samples_dir,
            mqtt_cfg.get("host", ""),
            mqtt_cfg.get("port", ""),
            mqtt_cfg.get("topic", "publish_out/#"),
        )
    )

    fp = FingerprintCollector()

    # MQTT setup
    host = mqtt_cfg.get("host", "127.0.0.1")
    port = int(mqtt_cfg.get("port", 1883))
    topic = mqtt_cfg.get("topic", "publish_out/#")
    client_id = mqtt_cfg.get("client_id", "ble-ai-localizer")
    keepalive = int(mqtt_cfg.get("keepalive", 60))
    proto = mqtt.MQTTv311

    def on_connect(client, userdata, flags, rc):
        log(f"MQTT connected rc={rc}, subscribed to {topic}")
        client.subscribe(topic)

    def on_message(client, userdata, msg):
        # Feed RSSI readings (and gateway announcements) into the collector.
        gw_from_topic = parse_topic_gateway(msg.topic)
        if not gw_from_topic:
            return
        payload_list = parse_payload_list(msg.payload)
        if not payload_list:
            return

        for it in payload_list:
            if not isinstance(it, dict):
                continue
            if is_gateway_announce(it):
                gwm = it.get("mac", gw_from_topic)
                fp.last_seen_gw[norm_mac(gwm)] = time.time()
                continue

            bmac = it.get("mac")
            rssi = it.get("rssi")
            if not bmac or rssi is None:
                continue
            try:
                fp.update(gw_from_topic, bmac, float(rssi))
            except Exception:
                continue

    client = mqtt.Client(client_id=client_id, protocol=proto)
    client.on_connect = on_connect
    client.on_message = on_message

    username = str(mqtt_cfg.get("username", "") or "")
    password = str(mqtt_cfg.get("password", "") or "")
    if username:
        client.username_pw_set(username, password)

    tls = bool(mqtt_cfg.get("tls", False))
    if tls:
        client.tls_set(cert_reqs=ssl.CERT_NONE)
        client.tls_insecure_set(True)

    log("MQTT thread started (collect_train)")
    client.connect(host, port, keepalive=keepalive)
    client.loop_start()

    # Wait for all gateways to come online before entering the job loop
    last_ready_log = 0.0
    while True:
        online, missing = _gateways_online(fp, gateway_macs, online_max_age_s)
        if online == len(gateway_macs):
            log(f"GW READY: online={online}/{len(gateway_macs)} (max_age_s={online_max_age_s:.1f})")
            break
        now = time.time()
        if now - last_ready_log >= gw_ready_log_seconds:
            log(f"WAIT gateways online ({len(missing)} missing, seen={online}/{len(gateway_macs)}): {missing} (max_age_s={online_max_age_s:.1f})")
            last_ready_log = now
        time.sleep(gw_ready_sleep_seconds)

    # Job loop (runs forever)
    # BUGFIX: job_path/job_name were unbound in the except handler when a
    # failure happened before a job was picked (NameError masked the real
    # error), and a stale job_path from the previous iteration could be
    # renamed by mistake. Reset them at the top of each iteration.
    job_path: Optional[Path] = None
    job_name = ""
    while True:
        try:
            job_path = None
            job_name = ""

            # periodic gw ready log
            now = time.time()
            if now - last_ready_log >= gw_ready_log_seconds:
                online, _missing = _gateways_online(fp, gateway_macs, online_max_age_s)
                log(f"GW READY: online={online}/{len(gateway_macs)} (max_age_s={online_max_age_s:.1f})")
                last_ready_log = now

            # pick the first pending job (lexicographic order)
            job_files = sorted(pending_dir.glob("*.csv"))
            if not job_files:
                time.sleep(poll_seconds)
                continue

            job_path = job_files[0]
            job_name = job_path.name

            rows = read_job_csv(job_path, delimiter=csv_delimiter)
            if not rows:
                # move empty/bad jobs to error
                log(f"TRAIN job ERROR: {job_name} err=EmptyJob: no valid rows")
                job_path.rename(error_dir / job_path.name)
                continue

            # normalize beacons for stats keys
            job_beacons_norm = [norm_mac(r["mac"]) for r in rows]

            # optionally wait for all gateways online before starting the window
            if gw_ready_check_before_job:
                while True:
                    online, missing = _gateways_online(fp, gateway_macs, online_max_age_s)
                    if online == len(gateway_macs):
                        break
                    log(f"WAIT gateways online before job ({len(missing)} missing, seen={online}/{len(gateway_macs)}): {missing}")
                    time.sleep(1.0)

            log(f"TRAIN job START: {job_name} beacons={len(rows)}")

            start = time.time()
            deadline = start + window_seconds
            next_progress = start + progress_log_seconds

            # Collection window: readings stream in via the MQTT callback.
            while time.time() < deadline:
                time.sleep(0.5)
                if progress_log_seconds > 0 and time.time() >= next_progress:
                    st = fp.stats(job_beacons_norm, gateway_macs)
                    parts = []
                    for b in job_beacons_norm:
                        total = sum(st.counts[b].values())
                        gw_seen = sum(1 for g in gateway_macs if st.counts[b][g] > 0)
                        parts.append(f"{b.replace(':','')}: total={total} gw={gw_seen}/{len(gateway_macs)}")
                    elapsed = int(time.time() - start)
                    log(f"COLLECT progress: {elapsed}s/{int(window_seconds)}s " + " | ".join(parts))
                    next_progress = time.time() + progress_log_seconds

            out_rows: List[Dict[str, Any]] = []
            st = fp.stats(job_beacons_norm, gateway_macs)

            for r, b_norm in zip(rows, job_beacons_norm):
                feats = fp.feature_row(
                    beacon_mac=b_norm,
                    gateways=gateway_macs,
                    aggregate=aggregate,
                    rssi_min=rssi_min,
                    rssi_max=rssi_max,
                    min_samples_per_gateway=min_samples_per_gateway,
                    outlier_method=outlier_method,
                    mad_z=mad_z,
                    iqr_k=iqr_k,
                    max_stddev=max_stddev,
                )

                # NaN != NaN, so this counts the non-NaN features.
                non_nan = sum(1 for g in gateway_macs if feats.get(g) == feats.get(g))
                if non_nan < min_non_nan:
                    sample_info = []
                    for g in gateway_macs:
                        c = st.counts[b_norm][g]
                        if c > 0:
                            sample_info.append(f"{g} n={c} last={st.last[b_norm][g]}")
                    preview = ", ".join(sample_info[:8]) + (" ..." if len(sample_info) > 8 else "")
                    log(
                        f"WARNING: beacon {b_norm.replace(':','')} low features non_nan={non_nan} "
                        f"(seen_gw={sum(1 for g in gateway_macs if st.counts[b_norm][g]>0)}) [{preview}]"
                    )

                out_row: Dict[str, Any] = {
                    "mac": r["mac"],  # MAC always compact, without ':'
                    "x": float(r["x"]),
                    "y": float(r["y"]),
                    "z": float(r["z"]),
                }
                out_row.update(feats)
                out_rows.append(out_row)

            written = []
            for out_row in out_rows:
                # File name: Z_X_Y.csv (Z, X, Y taken from the job)
                zt = _coord_token(out_row.get("z"))
                xt = _coord_token(out_row.get("x"))
                yt = _coord_token(out_row.get("y"))
                base_name = f"{zt}_{xt}_{yt}.csv"
                out_path = samples_dir / base_name
                write_samples_csv(out_path, [out_row], gateway_macs, delimiter=csv_delimiter, rssi_decimals=rssi_decimals)
                written.append(out_path.name)

            job_path.rename(done_dir / job_path.name)
            if written:
                shown = ", ".join(written[:10])
                more = "" if len(written) <= 10 else f" (+{len(written)-10} altri)"
                log(f"TRAIN job DONE: wrote {len(written)} sample files to {samples_dir}: {shown}{more}")
            else:
                log("TRAIN job DONE: no output rows (empty job?)")

        except Exception as e:
            log(f"TRAIN job ERROR: {job_name or '?'} err={type(e).__name__}: {e}")
            if job_path is not None:
                try:
                    job_path.rename(error_dir / job_path.name)
                except Exception:
                    pass
            time.sleep(0.5)
-
-
def main() -> None:
    """Entry point: load settings, log build info, then dispatch on the configured mode."""
    settings = load_settings()
    cfg_file = settings.get("_config_file", "")
    keys = [k for k in settings.keys() if not str(k).startswith("_")]
    log(f"Settings loaded from {cfg_file}. Keys: {keys}")
    log(f"BUILD: {build_info()}")

    mode = str(settings.get("mode", "collect_train")).strip().lower()

    if mode == "collect_train":
        run_collect_train(settings)
    elif mode == "train":
        # Imported lazily so unused modes add no import cost.
        from .train_mode import run_train
        run_train(settings)
    elif mode == "infer":
        from .infer_mode import run_infer
        run_infer(settings)
    else:
        raise ValueError(f"unknown mode: {mode}")
-
-
- if __name__ == "__main__":
- main()
|