No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.
 
 
 
 

420 líneas
14 KiB

  1. # app/train_mode.py
  2. # Training mode: build hierarchical KNN model (floor classifier + per-floor X/Y regressors)
  3. # Adds verbose dataset statistics useful for large training runs.
  4. from __future__ import annotations
  5. import glob
  6. import os
  7. import time
  8. import math
  9. from dataclasses import dataclass
  10. from typing import Any, Callable, Dict, List, Optional, Tuple
  11. import joblib
  12. from datetime import datetime
  13. import numpy as np
  14. import pandas as pd
  15. from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
  16. import sklearn
  17. # NOTE: these are already present in the project
  18. from .csv_config import load_gateway_features_csv
  19. from .logger_utils import log_msg as log
  20. @dataclass
  21. class GatewayStats:
  22. mac: str
  23. total_samples: int = 0 # total rows processed (per sample point)
  24. non_missing: int = 0 # non-missing rssi count
  25. missing: int = 0 # missing (nan) count
  26. sum_: float = 0.0
  27. sumsq: float = 0.0
  28. min_: float = float("inf")
  29. max_: float = float("-inf")
  30. def add(self, v: float, is_missing: bool) -> None:
  31. self.total_samples += 1
  32. if is_missing:
  33. self.missing += 1
  34. return
  35. self.non_missing += 1
  36. self.sum_ += v
  37. self.sumsq += v * v
  38. if v < self.min_:
  39. self.min_ = v
  40. if v > self.max_:
  41. self.max_ = v
  42. def mean(self) -> float:
  43. return self.sum_ / self.non_missing if self.non_missing else float("nan")
  44. def std(self) -> float:
  45. if self.non_missing <= 1:
  46. return float("nan")
  47. mu = self.mean()
  48. var = max(0.0, (self.sumsq / self.non_missing) - (mu * mu))
  49. return math.sqrt(var)
  50. def missing_pct(self) -> float:
  51. return (self.missing / self.total_samples) * 100.0 if self.total_samples else 0.0
  52. def _get(d: Dict[str, Any], key: str, default: Any = None) -> Any:
  53. return d.get(key, default) if isinstance(d, dict) else default
  54. def _as_bool(v: Any, default: bool = False) -> bool:
  55. if v is None:
  56. return default
  57. if isinstance(v, bool):
  58. return v
  59. if isinstance(v, (int, float)):
  60. return bool(v)
  61. s = str(v).strip().lower()
  62. return s in ("1", "true", "yes", "y", "on")
  63. def _safe_float(v: Any) -> Optional[float]:
  64. try:
  65. if v is None:
  66. return None
  67. if isinstance(v, float) and math.isnan(v):
  68. return None
  69. return float(v)
  70. except Exception:
  71. return None
  72. def _collect_dataset(
  73. sample_files: List[str],
  74. gateways_order: List[str],
  75. nan_fill: float,
  76. log: Callable[[str], None],
  77. verbose: bool,
  78. ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, GatewayStats], Dict[str, Any]]:
  79. """
  80. Build dataset from per-point sample csv files.
  81. Each sample file is expected to contain:
  82. header: mac;x;y;z;<GW1>;<GW2>...
  83. 1 row: beacon_mac; x; y; z; rssi_gw1; rssi_gw2; ...
  84. Returns:
  85. X (N, G), y_floor (N,), y_xy (N,2), meta_xy (N,2),
  86. gw_stats, global_stats
  87. """
  88. X_rows: List[List[float]] = []
  89. y_floor: List[int] = []
  90. y_xy: List[List[float]] = []
  91. meta_xy: List[List[float]] = []
  92. gw_stats: Dict[str, GatewayStats] = {gw: GatewayStats(mac=gw) for gw in gateways_order}
  93. floors_counter: Dict[int, int] = {}
  94. bad_files: int = 0
  95. missing_cols_files: int = 0
  96. expected_cols: Optional[List[str]] = None
  97. for fp in sample_files:
  98. try:
  99. df = pd.read_csv(fp, sep=";", dtype=str)
  100. except Exception as e:
  101. bad_files += 1
  102. if verbose:
  103. log(f"TRAIN WARN: cannot read sample file {fp}: {type(e).__name__}: {e}")
  104. continue
  105. if df.shape[0] < 1:
  106. bad_files += 1
  107. if verbose:
  108. log(f"TRAIN WARN: empty sample file {fp}")
  109. continue
  110. row = df.iloc[0].to_dict()
  111. if verbose:
  112. cols = list(df.columns)
  113. if expected_cols is None:
  114. expected_cols = cols
  115. elif cols != expected_cols:
  116. missing_cols_files += 1
  117. if missing_cols_files <= 5:
  118. log(f"TRAIN WARN: columns mismatch in {os.path.basename(fp)} (expected {len(expected_cols)} cols, got {len(cols)})")
  119. x = _safe_float(row.get("x"))
  120. y = _safe_float(row.get("y"))
  121. z = _safe_float(row.get("z"))
  122. if x is None or y is None or z is None:
  123. bad_files += 1
  124. if verbose:
  125. log(f"TRAIN WARN: missing x/y/z in {fp}")
  126. continue
  127. z_i = int(round(z))
  128. floors_counter[z_i] = floors_counter.get(z_i, 0) + 1
  129. feats: List[float] = []
  130. for gw in gateways_order:
  131. v = row.get(gw)
  132. fv = _safe_float(v)
  133. if fv is None:
  134. feats.append(nan_fill)
  135. gw_stats[gw].add(nan_fill, is_missing=True)
  136. else:
  137. feats.append(fv)
  138. gw_stats[gw].add(fv, is_missing=False)
  139. X_rows.append(feats)
  140. y_floor.append(z_i)
  141. y_xy.append([x, y])
  142. meta_xy.append([x, y])
  143. if not X_rows:
  144. raise RuntimeError("No valid samples found in samples_dir (dataset empty).")
  145. X = np.asarray(X_rows, dtype=np.float32)
  146. y_floor_arr = np.asarray(y_floor, dtype=np.int32)
  147. y_xy_arr = np.asarray(y_xy, dtype=np.float32)
  148. meta_xy_arr = np.asarray(meta_xy, dtype=np.float32)
  149. global_stats: Dict[str, Any] = {
  150. "samples_total_files": len(sample_files),
  151. "samples_used": int(X.shape[0]),
  152. "samples_bad": int(bad_files),
  153. "floors_counts": dict(sorted(floors_counter.items(), key=lambda kv: kv[0])),
  154. "missing_cols_files": int(missing_cols_files),
  155. "gateways": int(len(gateways_order)),
  156. "nan_fill": float(nan_fill),
  157. }
  158. return X, y_floor_arr, y_xy_arr, meta_xy_arr, gw_stats, global_stats
  159. def _log_train_stats(
  160. log: Callable[[str], None],
  161. X: np.ndarray,
  162. y_floor: np.ndarray,
  163. y_xy: np.ndarray,
  164. gateways_order: List[str],
  165. nan_fill: float,
  166. gw_stats: Dict[str, GatewayStats],
  167. global_stats: Dict[str, Any],
  168. top_k: int = 8,
  169. ) -> None:
  170. """Human-friendly statistics for training runs."""
  171. log(
  172. "TRAIN stats: "
  173. f"samples_used={global_stats.get('samples_used')} "
  174. f"samples_bad={global_stats.get('samples_bad')} "
  175. f"files_total={global_stats.get('samples_total_files')} "
  176. f"gateways={len(gateways_order)} "
  177. f"floors={list(global_stats.get('floors_counts', {}).keys())}"
  178. )
  179. if global_stats.get("missing_cols_files", 0):
  180. log(f"TRAIN stats: files_with_column_mismatch={global_stats['missing_cols_files']} (see earlier WARN lines)")
  181. xs = y_xy[:, 0]
  182. ys = y_xy[:, 1]
  183. log(
  184. "TRAIN stats: XY range "
  185. f"X[min,max]=[{float(np.min(xs)):.2f},{float(np.max(xs)):.2f}] "
  186. f"Y[min,max]=[{float(np.min(ys)):.2f},{float(np.max(ys)):.2f}]"
  187. )
  188. miss = int((X == nan_fill).sum())
  189. total = int(X.size)
  190. miss_pct = (miss / total) * 100.0 if total else 0.0
  191. log(f"TRAIN stats: feature sparsity missing={miss}/{total} ({miss_pct:.1f}%) using nan_fill={nan_fill}")
  192. gw_list = list(gw_stats.values())
  193. gw_list_sorted = sorted(gw_list, key=lambda s: (s.missing_pct(), -s.non_missing), reverse=True)
  194. worst = gw_list_sorted[: max(1, min(top_k, len(gw_list_sorted)))]
  195. worst_str = " | ".join(
  196. f"{g.mac}: miss={g.missing_pct():.1f}% (seen={g.non_missing}) mean={g.mean():.1f} std={g.std():.1f}"
  197. for g in worst
  198. )
  199. log(f"TRAIN stats: gateways with highest missing%: {worst_str}")
  200. best = list(reversed(gw_list_sorted))[: max(1, min(top_k, len(gw_list_sorted)))]
  201. best_str = " | ".join(
  202. f"{g.mac}: miss={g.missing_pct():.1f}% (seen={g.non_missing}) mean={g.mean():.1f} std={g.std():.1f}"
  203. for g in best
  204. )
  205. log(f"TRAIN stats: gateways with lowest missing%: {best_str}")
  206. floors = global_stats.get("floors_counts", {})
  207. if floors:
  208. floor_str = ", ".join(f"z={k}:{v}" for k, v in floors.items())
  209. log(f"TRAIN stats: floor distribution: {floor_str}")
  210. def run_train(settings: Dict[str, Any], log: Optional[Callable[[str], None]] = None) -> None:
  211. """
  212. Train hierarchical KNN:
  213. - KNeighborsClassifier for floor (Z)
  214. - For each floor, a KNeighborsRegressor for (X,Y) as multioutput
  215. Model saved with joblib to paths.model (or train.model_path).
  216. """
  217. if log is None:
  218. def log(msg: str) -> None:
  219. print(msg, flush=True)
  220. # Build stamp for this module (helps verifying which file is running)
  221. try:
  222. import hashlib
  223. from pathlib import Path
  224. _b = Path(__file__).read_bytes()
  225. log(f"TRAIN_MODE build sha256={hashlib.sha256(_b).hexdigest()[:12]} size={len(_b)}")
  226. except Exception:
  227. pass
  228. train_cfg = _get(settings, "train", {})
  229. paths = _get(settings, "paths", {})
  230. debug = _get(settings, "debug", {})
  231. samples_dir = _get(train_cfg, "samples_dir", _get(paths, "samples_dir", "/data/train/samples"))
  232. gateways_csv = _get(train_cfg, "gateways_csv", _get(paths, "gateways_csv", "/data/config/gateway.csv"))
  233. model_path = _get(train_cfg, "model_path", _get(paths, "model", "/data/model/model.joblib"))
  234. nan_fill = float(_get(train_cfg, "nan_fill", -110.0))
  235. k_floor = int(_get(train_cfg, "k_floor", _get(_get(settings, "ml", {}), "k", 7)))
  236. k_xy = int(_get(train_cfg, "k_xy", _get(_get(settings, "ml", {}), "k", 7)))
  237. weights = str(_get(train_cfg, "weights", _get(_get(settings, "ml", {}), "weights", "distance")))
  238. metric = str(_get(train_cfg, "metric", _get(_get(settings, "ml", {}), "metric", "euclidean")))
  239. verbose = _as_bool(_get(debug, "train_verbose", True), True)
  240. top_k = int(_get(debug, "train_stats_top_k", 8))
  241. backup_existing_model = _as_bool(_get(train_cfg, "backup_existing_model", True), True)
  242. log(
  243. "TRAIN config: "
  244. f"samples_dir={samples_dir} "
  245. f"gateways_csv={gateways_csv} "
  246. f"model_path={model_path} "
  247. f"nan_fill={nan_fill} "
  248. f"k_floor={k_floor} k_xy={k_xy} "
  249. f"weights={weights} metric={metric} "
  250. f"train_verbose={verbose} backup_existing_model={backup_existing_model}"
  251. )
  252. # 1) Load gateways definition to know feature order
  253. gws = load_gateway_features_csv(str(gateways_csv))
  254. gateways_order = [g.mac for g in gws]
  255. if not gateways_order:
  256. raise RuntimeError("No gateways found in gateways_csv (feature-set empty).")
  257. if verbose:
  258. preview = ", ".join(gateways_order[: min(6, len(gateways_order))])
  259. log(f"TRAIN: gateways(feature-order)={len(gateways_order)} first=[{preview}{'...' if len(gateways_order) > 6 else ''}]")
  260. # 2) Collect sample files
  261. sample_files = sorted(glob.glob(os.path.join(samples_dir, "*.csv")))
  262. if not sample_files:
  263. raise RuntimeError(f"No sample files found in samples_dir={samples_dir}")
  264. X, y_floor, y_xy, meta_xy, gw_stats, global_stats = _collect_dataset(
  265. sample_files=sample_files,
  266. gateways_order=gateways_order,
  267. nan_fill=nan_fill,
  268. log=log,
  269. verbose=verbose,
  270. )
  271. if verbose:
  272. _log_train_stats(
  273. log=log,
  274. X=X,
  275. y_floor=y_floor,
  276. y_xy=meta_xy,
  277. gateways_order=gateways_order,
  278. nan_fill=nan_fill,
  279. gw_stats=gw_stats,
  280. global_stats=global_stats,
  281. top_k=top_k,
  282. )
  283. # 3) Fit floor classifier
  284. floor_clf = KNeighborsClassifier(
  285. n_neighbors=k_floor,
  286. weights=weights,
  287. metric=metric,
  288. )
  289. floor_clf.fit(X, y_floor)
  290. # 4) Fit per-floor XY regressors (multioutput)
  291. models_xy: Dict[int, Any] = {}
  292. floors = sorted(set(int(z) for z in y_floor.tolist()))
  293. for z in floors:
  294. idx = np.where(y_floor == z)[0]
  295. Xz = X[idx, :]
  296. yz = y_xy[idx, :] # (N,2)
  297. reg = KNeighborsRegressor(
  298. n_neighbors=k_xy,
  299. weights=weights,
  300. metric=metric,
  301. )
  302. reg.fit(Xz, yz)
  303. models_xy[int(z)] = reg
  304. if verbose:
  305. xs = yz[:, 0]
  306. ys = yz[:, 1]
  307. log(
  308. f"TRAIN: floor z={z} samples={int(len(idx))} "
  309. f"Xrange=[{float(np.min(xs)):.1f},{float(np.max(xs)):.1f}] "
  310. f"Yrange=[{float(np.min(ys)):.1f},{float(np.max(ys)):.1f}]"
  311. )
  312. model = {
  313. "type": "hier_knn_floor_xy",
  314. "gateways_order": gateways_order,
  315. "nan_fill": nan_fill,
  316. "k_floor": k_floor,
  317. "k_xy": k_xy,
  318. "weights": weights,
  319. "metric": metric,
  320. "floor_clf": floor_clf,
  321. "xy_by_floor": models_xy,
  322. "floors": floors,
  323. }
  324. os.makedirs(os.path.dirname(model_path), exist_ok=True)
  325. # Backup previous model (così inferenza può continuare ad usare una versione nota)
  326. backup_path = None
  327. if backup_existing_model and os.path.exists(model_path):
  328. root, ext = os.path.splitext(model_path)
  329. ts = int(time.time())
  330. # evita collisioni se lanci due train nello stesso secondo
  331. for bump in range(0, 1000):
  332. cand = f"{root}_{ts + bump}{ext}"
  333. if not os.path.exists(cand):
  334. backup_path = cand
  335. break
  336. try:
  337. if backup_path:
  338. os.replace(model_path, backup_path)
  339. log(f"TRAIN: previous model moved to {backup_path}")
  340. except Exception as e:
  341. log(f"TRAIN WARNING: cannot backup previous model {model_path}: {type(e).__name__}: {e}")
  342. # Metadata utile (tipo 'modinfo' minimale)
  343. model["created_at_utc"] = datetime.utcnow().replace(microsecond=0).isoformat() + "Z"
  344. model["sklearn_version"] = getattr(sklearn, "__version__", "unknown")
  345. model["numpy_version"] = getattr(np, "__version__", "unknown")
  346. joblib.dump(model, model_path)
  347. log(
  348. f"TRAIN DONE: model saved to {model_path} "
  349. f"(samples={int(X.shape[0])}, gateways={len(gateways_order)}, floors={len(floors)})"
  350. )