"""Collector for training/test RSSI fingerprint samples received over MQTT."""
import os
import time
import shutil
import ssl
from pathlib import Path
from typing import Dict, Any, List

import pandas as pd
import paho.mqtt.client as mqtt

from .normalize import norm_mac, mac_to_compact
from .mqtt_parser import parse_publish_out
from .fingerprint import FingerprintWindow
from .logger_utils import log_msg as log
from .gateways import load_gateway_csv
from .beacons import read_job_csv, write_samples_csv, _coord_token

# Destination for job files that yield no rows.
# NOTE(review): both TRAIN and TEST jobs are moved here, as in the original code.
_ERROR_JOBS_DIR = Path("/data/train/jobs/error")


def extract_campaign_id(beacon_name: str) -> str:
    """Extract the campaign id from a beacon name (e.g. ``BC-00-41`` -> ``00``).

    Args:
        beacon_name: Beacon name whose second ``-``-separated field is the
            campaign id. Non-string values are converted with ``str()``.

    Returns:
        The campaign id, or ``"default"`` when the name is empty, has no
        second field, or the second field is empty.
    """
    if not beacon_name:
        return "default"
    parts = str(beacon_name).split("-")
    # An empty field (e.g. "BC--41") would otherwise produce filenames
    # starting with "_", so treat it like a missing campaign.
    if len(parts) >= 2 and parts[1]:
        return parts[1]
    return "default"


def _start_mqtt_client(mqtt_cfg: Dict[str, Any], on_message) -> mqtt.Client:
    """Create, connect and subscribe an MQTT v3.1.1 client, then start its network loop."""
    client = mqtt.Client(protocol=mqtt.MQTTv311)
    if mqtt_cfg.get("username"):
        client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
    client.on_message = on_message
    # Config files often carry the port as a string; connect() needs an int.
    client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)))
    client.subscribe(mqtt_cfg.get("topic", "publish_out/#"))
    client.loop_start()
    return client


def _stop_mqtt_client(client: mqtt.Client) -> None:
    """Disconnect *client* from the broker and stop its background network loop."""
    client.disconnect()
    client.loop_stop()


def _wait_for_gateways_online(mqtt_cfg: Dict[str, Any], gateway_headers: List[Any]) -> None:
    """Block until every gateway in *gateway_headers* appears in at least one MQTT event.

    Each newly seen gateway is logged with a running ``[seen/total]`` counter.
    There is no timeout: an offline gateway keeps this function waiting.
    """
    total_gw = len(gateway_headers)
    online_tracker = {
        mac_to_compact(m): {"original": m, "online": False} for m in gateway_headers
    }

    def on_check_message(client, userdata, msg):
        events = parse_publish_out(msg.topic, msg.payload)
        for gw_c, _, _, _ in events:
            entry = online_tracker.get(gw_c)
            if entry is not None and not entry["online"]:
                entry["online"] = True
                curr = sum(1 for v in online_tracker.values() if v["online"])
                log(f" [{curr}/{total_gw}] ONLINE: {entry['original']}")

    check_client = _start_mqtt_client(mqtt_cfg, on_check_message)
    try:
        while not all(v["online"] for v in online_tracker.values()):
            time.sleep(1)
    finally:
        # Previously the connection was never closed after the check phase.
        _stop_mqtt_client(check_client)


def _process_job(
    job_path: Path,
    mode: Dict[str, Any],
    cfg: Dict[str, Any],
    mqtt_cfg: Dict[str, Any],
    gateway_headers: List[Any],
    gateway_compacts: List[Any],
    csv_delim: str,
) -> bool:
    """Collect one MQTT window for a job file and write one sample CSV per valid beacon.

    Returns:
        ``True`` when the job was processed and moved to ``mode["done"]``;
        ``False`` when it had no rows and was moved to the error directory.
    """
    log(f"[{mode['label']}] Elaborazione: {job_path.name}")
    job_rows = read_job_csv(job_path, delimiter=csv_delim)
    if not job_rows:
        # shutil.move (unlike Path.rename) also works across filesystems.
        shutil.move(str(job_path), str(_ERROR_JOBS_DIR / job_path.name))
        return False

    job_beacon_keys = [mac_to_compact(r["mac"]) for r in job_rows]
    beacon_key_set = set(job_beacon_keys)  # O(1) lookup in the MQTT callback
    active_window = FingerprintWindow(
        beacon_keys=job_beacon_keys,
        gateway_keys=gateway_compacts,
        gateway_headers=gateway_headers,
        rssi_min=float(cfg.get("rssi_min", -110)),
        rssi_max=float(cfg.get("rssi_max", -25)),
        outlier_method=cfg.get("outlier_method", "mad"),
        min_samples_per_gateway=int(cfg.get("min_samples_per_gateway", 5)),
    )
    mqtt_timestamps: List[Any] = []

    def on_job_message(c, u, msg):
        for gw_c, b_c, rssi, ts in parse_publish_out(msg.topic, msg.payload):
            if b_c in beacon_key_set:
                active_window.add(gw_c, b_c, rssi)
                if ts:
                    mqtt_timestamps.append(ts)

    client = _start_mqtt_client(mqtt_cfg, on_job_message)
    try:
        time.sleep(float(cfg.get("window_seconds", 30)))
    finally:
        _stop_mqtt_client(client)

    # The network loop is stopped, so mqtt_timestamps is no longer mutated.
    ts_start = min(mqtt_timestamps) if mqtt_timestamps else 0
    ts_end = max(mqtt_timestamps) if mqtt_timestamps else 0

    aggregate = cfg.get("aggregate", "median")
    min_non_nan = int(cfg.get("min_non_nan", 3))
    rssi_decimals = int(cfg.get("rssi_decimals", 0))

    valid_count = 0
    for r in job_rows:
        b_mac = r["mac"]
        b_compact = mac_to_compact(b_mac)
        b_name = r.get("beaconname", b_compact)
        campaign = extract_campaign_id(b_name)
        feats = active_window.features_for(b_compact, aggregate=aggregate)
        # `v == v` is False only for NaN, so this counts real values.
        non_nan = sum(1 for v in feats.values() if v is not None and v == v)
        if non_nan < min_non_nan:
            continue

        out_row = {
            "mac": b_mac,
            "x": r["x"],
            "y": r["y"],
            "z": r["z"],
            "ts_start": ts_start,
            "ts_end": ts_end,
        }
        out_row.update(feats)
        zt, xt, yt = _coord_token(r["z"]), _coord_token(r["x"]), _coord_token(r["y"])
        # Output naming: <campaign>_<z>_<x>_<y>.csv
        out_filename = f"{campaign}_{zt}_{xt}_{yt}.csv"
        write_samples_csv(
            mode["samples"] / out_filename,
            [out_row],
            gateway_headers,
            delimiter=csv_delim,
            rssi_decimals=rssi_decimals,
        )
        valid_count += 1

    shutil.move(str(job_path), str(mode["done"] / job_path.name))
    log(f" Job completato. Campioni: {valid_count} (MQTT TS: {ts_start}-{ts_end})")
    return True


def run_collect_train(settings: Dict[str, Any]) -> None:
    """Run the TRAIN/TEST fingerprint collector loop forever.

    Loads the gateway list, waits until every gateway is seen on MQTT, then
    polls the TRAIN and TEST pending-job directories. Each pass processes at
    most one successful job and restarts from TRAIN. Sleeps 2 s when no job is
    pending.

    Args:
        settings: Mapping with optional ``collect_train``, ``paths`` and
            ``mqtt`` sub-dictionaries.

    Returns:
        Only when the gateway CSV cannot be loaded; otherwise it loops forever.
    """
    cfg = settings.get("collect_train", {})
    paths = settings.get("paths", {})
    mqtt_cfg = settings.get("mqtt", {})

    modes = [
        {
            "pending": Path("/data/train/jobs/pending"),
            "done": Path("/data/train/jobs/done"),
            "samples": Path("/data/train/samples"),
            "label": "TRAIN",
        },
        {
            "pending": Path("/data/train/testjobs/pending"),
            "done": Path("/data/train/testjobs/done"),
            "samples": Path("/data/train/testsamples"),
            "label": "TEST",
        },
    ]
    for m in modes:
        for d in (m["pending"], m["done"], m["samples"]):
            d.mkdir(parents=True, exist_ok=True)
    # Previously never created, so moving an empty job there raised
    # FileNotFoundError and terminated the collector.
    _ERROR_JOBS_DIR.mkdir(parents=True, exist_ok=True)

    gateway_csv_path = paths.get("gateways_csv", "/data/config/gateway.csv")
    csv_delim = paths.get("csv_delimiter", ";")
    try:
        gw_df = pd.read_csv(gateway_csv_path, sep=csv_delim)
        gw_df.columns = [c.strip().lower() for c in gw_df.columns]
        # Blank rows in the CSV would otherwise become NaN gateway keys/headers.
        gateway_headers = gw_df["mac"].dropna().tolist()
        gateway_compacts = [mac_to_compact(m) for m in gateway_headers]
    except Exception as e:
        log(f"ERRORE caricamento gateway: {e}")
        return

    _wait_for_gateways_online(mqtt_cfg, gateway_headers)
    log("Collector pronto. Monitoraggio directory Job...")

    while True:
        job_found = False
        for m in modes:
            job_files = sorted(m["pending"].glob("*.csv"))
            if not job_files:
                continue
            job_found = True
            processed = _process_job(
                job_files[0], m, cfg, mqtt_cfg, gateway_headers, gateway_compacts, csv_delim
            )
            if processed:
                # Restart the scan from the first mode after each completed job.
                break
        if not job_found:
            time.sleep(2)