選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。
 
 
 
 

148 行
6.4 KiB

  1. import os
  2. import time
  3. import shutil
  4. import ssl
  5. from pathlib import Path
  6. from typing import Dict, Any, List
  7. import pandas as pd
  8. import paho.mqtt.client as mqtt
  9. from .normalize import norm_mac, mac_to_compact
  10. from .mqtt_parser import parse_publish_out
  11. from .fingerprint import FingerprintWindow
  12. from .logger_utils import log_msg as log
  13. from .gateways import load_gateway_csv
  14. from .beacons import read_job_csv, write_samples_csv, _coord_token
  15. def extract_campaign_id(beacon_name: str) -> str:
  16. """Estrae la campagna dal nome beacon (es. BC-00-41 -> 00)."""
  17. if not beacon_name: return "default"
  18. parts = str(beacon_name).split('-')
  19. if len(parts) >= 2:
  20. return parts[1]
  21. return "default"
  22. def run_collect_train(settings: Dict[str, Any]) -> None:
  23. cfg = settings.get("collect_train", {})
  24. paths = settings.get("paths", {})
  25. mqtt_cfg = settings.get("mqtt", {})
  26. modes = [
  27. {"pending": Path("/data/train/jobs/pending"), "done": Path("/data/train/jobs/done"), "samples": Path("/data/train/samples"), "label": "TRAIN"},
  28. {"pending": Path("/data/train/testjobs/pending"), "done": Path("/data/train/testjobs/done"), "samples": Path("/data/train/testsamples"), "label": "TEST"}
  29. ]
  30. for m in modes:
  31. for d in [m["pending"], m["done"], m["samples"]]: d.mkdir(parents=True, exist_ok=True)
  32. gateway_csv_path = paths.get("gateways_csv", "/data/config/gateway.csv")
  33. csv_delim = paths.get("csv_delimiter", ";")
  34. try:
  35. gw_df = pd.read_csv(gateway_csv_path, sep=csv_delim)
  36. gw_df.columns = [c.strip().lower() for c in gw_df.columns]
  37. gateway_headers = gw_df['mac'].tolist()
  38. gateway_compacts = [mac_to_compact(m) for m in gateway_headers]
  39. total_gw = len(gateway_headers)
  40. except Exception as e:
  41. log(f"ERRORE caricamento gateway: {e}")
  42. return
  43. # --- LOGICA GATEWAY ONLINE ---
  44. online_tracker = {mac_to_compact(m): {"original": m, "online": False} for m in gateway_headers}
  45. def on_check_message(client, userdata, msg):
  46. events = parse_publish_out(msg.topic, msg.payload)
  47. for gw_c, _, _, _ in events:
  48. if gw_c in online_tracker and not online_tracker[gw_c]["online"]:
  49. online_tracker[gw_c]["online"] = True
  50. curr = sum(1 for v in online_tracker.values() if v["online"])
  51. log(f" [{curr}/{total_gw}] ONLINE: {online_tracker[gw_c]['original']}")
  52. check_client = mqtt.Client(protocol=mqtt.MQTTv311)
  53. if mqtt_cfg.get("username"): check_client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
  54. check_client.on_message = on_check_message
  55. check_client.connect(mqtt_cfg.get("host", "localhost"), mqtt_cfg.get("port", 1883))
  56. check_client.subscribe(mqtt_cfg.get("topic", "publish_out/#"))
  57. check_client.loop_start()
  58. while not all(v["online"] for v in online_tracker.values()): time.sleep(1)
  59. check_client.loop_stop()
  60. log("Collector pronto. Monitoraggio directory Job...")
  61. while True:
  62. job_found = False
  63. for m in modes:
  64. job_files = sorted(m["pending"].glob("*.csv"))
  65. if not job_files: continue
  66. job_found = True
  67. job_path = job_files[0]
  68. log(f"[{m['label']}] Elaborazione: {job_path.name}")
  69. job_rows = read_job_csv(job_path, delimiter=csv_delim)
  70. if not job_rows:
  71. job_path.rename(Path("/data/train/jobs/error") / job_path.name)
  72. continue
  73. job_beacon_keys = [mac_to_compact(r["mac"]) for r in job_rows]
  74. active_window = FingerprintWindow(
  75. beacon_keys=job_beacon_keys,
  76. gateway_keys=gateway_compacts,
  77. gateway_headers=gateway_headers,
  78. rssi_min=float(cfg.get("rssi_min", -110)),
  79. rssi_max=float(cfg.get("rssi_max", -25)),
  80. outlier_method=cfg.get("outlier_method", "mad"),
  81. min_samples_per_gateway=int(cfg.get("min_samples_per_gateway", 5))
  82. )
  83. mqtt_timestamps = []
  84. def on_job_message(c, u, msg):
  85. events = parse_publish_out(msg.topic, msg.payload)
  86. for gw_c, b_c, rssi, ts in events:
  87. if b_c in job_beacon_keys:
  88. active_window.add(gw_c, b_c, rssi)
  89. if ts: mqtt_timestamps.append(ts)
  90. client = mqtt.Client(protocol=mqtt.MQTTv311)
  91. if mqtt_cfg.get("username"): client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
  92. client.on_message = on_job_message
  93. client.connect(mqtt_cfg.get("host", "localhost"), mqtt_cfg.get("port", 1883))
  94. client.subscribe(mqtt_cfg.get("topic", "publish_out/#"))
  95. client.loop_start()
  96. time.sleep(int(cfg.get("window_seconds", 30)))
  97. client.loop_stop()
  98. client.disconnect()
  99. ts_start = min(mqtt_timestamps) if mqtt_timestamps else 0
  100. ts_end = max(mqtt_timestamps) if mqtt_timestamps else 0
  101. valid_count = 0
  102. for r in job_rows:
  103. b_mac = r["mac"]
  104. b_compact = mac_to_compact(b_mac)
  105. b_name = r.get("beaconname", b_compact)
  106. campaign = extract_campaign_id(b_name)
  107. feats = active_window.features_for(b_compact, aggregate=cfg.get("aggregate", "median"))
  108. if sum(1 for v in feats.values() if v == v and v is not None) >= int(cfg.get("min_non_nan", 3)):
  109. out_row = {
  110. "mac": b_mac, "x": r["x"], "y": r["y"], "z": r["z"],
  111. "ts_start": ts_start, "ts_end": ts_end
  112. }
  113. out_row.update(feats)
  114. zt, xt, yt = _coord_token(r["z"]), _coord_token(r["x"]), _coord_token(r["y"])
  115. # NUOVA NOMENCLATURA: $CAMPAGNA_$Z_$X_$Y.csv
  116. out_filename = f"{campaign}_{zt}_{xt}_{yt}.csv"
  117. write_samples_csv(m["samples"] / out_filename, [out_row], gateway_headers, delimiter=csv_delim, rssi_decimals=int(cfg.get("rssi_decimals", 0)))
  118. valid_count += 1
  119. shutil.move(str(job_path), str(m["done"] / job_path.name))
  120. log(f" Job completato. Campioni: {valid_count} (MQTT TS: {ts_start}-{ts_end})")
  121. break
  122. if not job_found: time.sleep(2)