Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.
 
 
 
 

86 lignes
2.9 KiB

  1. # app/train_mode.py
  2. import os
  3. import glob
  4. import time
  5. import math
  6. import joblib
  7. import numpy as np
  8. import pandas as pd
  9. from datetime import datetime
  10. from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
  11. # Import assoluti garantiti
  12. from csv_config import load_gateway_features_csv
  13. from logger_utils import log_msg as log
  14. def run_train(settings, log_fn=None, target_files=None):
  15. """
  16. Esegue l'addestramento Hierarchical KNN su un set di file specifico (Campagna).
  17. """
  18. if log_fn is None:
  19. log_fn = log
  20. train_cfg = settings.get("train", {})
  21. knn_cfg = train_cfg.get("knn", {})
  22. # Parametri da config.yaml
  23. samples_dir = train_cfg.get("samples_dir", "/data/train/samples")
  24. gateways_csv = train_cfg.get("gateways_csv", "/data/config/gateway.csv")
  25. model_path = train_cfg.get("model_path", "/data/model/model.joblib")
  26. nan_fill = float(train_cfg.get("nan_fill", -110.0))
  27. k_val = int(knn_cfg.get("k", 5))
  28. weights = knn_cfg.get("weights", "distance")
  29. metric = knn_cfg.get("metric", "euclidean")
  30. log_fn(f"TRAIN: Caricamento gateway da {gateways_csv}")
  31. gws = load_gateway_features_csv(gateways_csv)
  32. gateways_order = [g.mac for g in gws]
  33. # Selezione file (Campagna specifica o globale)
  34. files = target_files if target_files else glob.glob(os.path.join(samples_dir, "*.csv"))
  35. if not files:
  36. raise RuntimeError("Nessun file CSV trovato per l'addestramento.")
  37. # Costruzione dataset
  38. X_list, y_z, y_xy = [], [], []
  39. for fp in files:
  40. try:
  41. df = pd.read_csv(fp, sep=";")
  42. if df.empty: continue
  43. row = df.iloc[0]
  44. X_list.append([float(row.get(gw, nan_fill)) for gw in gateways_order])
  45. y_z.append(int(round(float(row.get("z")))))
  46. y_xy.append([float(row.get("x")), float(row.get("y"))])
  47. except: continue
  48. X, Y_z, Y_xy = np.array(X_list), np.array(y_z), np.array(y_xy)
  49. # Step 1: Classificatore Piano (Z)
  50. log_fn(f"TRAIN: Fitting Piano Classifier (K={k_val})")
  51. floor_clf = KNeighborsClassifier(n_neighbors=k_val, weights=weights, metric=metric).fit(X, Y_z)
  52. # Step 2: Regressori X,Y per ogni piano trovato
  53. models_xy = {}
  54. for z in np.unique(Y_z):
  55. idx = np.where(Y_z == z)[0]
  56. log_fn(f"TRAIN: Fitting XY Regressor piano {z} ({len(idx)} campioni)")
  57. models_xy[int(z)] = KNeighborsRegressor(
  58. n_neighbors=min(k_val, len(idx)),
  59. weights=weights,
  60. metric=metric
  61. ).fit(X[idx], Y_xy[idx])
  62. # Salvataggio
  63. model_data = {
  64. "floor_clf": floor_clf,
  65. "xy_by_floor": models_xy,
  66. "gateways_order": gateways_order,
  67. "nan_fill": nan_fill,
  68. "created_at": datetime.now().isoformat()
  69. }
  70. os.makedirs(os.path.dirname(model_path), exist_ok=True)
  71. joblib.dump(model_data, model_path)
  72. log_fn(f"✅ TRAIN SUCCESS: Modello salvato in {model_path}")