Nevar pievienot vairāk nekā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domuzīmes ('-') un var būt līdz 35 simboliem gara.
 
 
 
 

113 rindas
4.0 KiB

  1. import os
  2. import json
  3. import time
  4. import joblib
  5. import pandas as pd
  6. import numpy as np
  7. from pathlib import Path
  8. from datetime import datetime
  9. from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
  10. # Import delle utility esistenti
  11. from .logger_utils import log_msg as log
  12. from .csv_config import load_gateway_features_csv
  13. def process_train_jobs():
  14. """Monitora ed esegue i job di addestramento salvando backup cronologici."""
  15. JOBS_DIR = Path("/data/train/train_jobs")
  16. JOBS_DIR.mkdir(parents=True, exist_ok=True)
  17. job_files = list(JOBS_DIR.glob("*.lock"))
  18. if not job_files:
  19. return
  20. for job_path in job_files:
  21. try:
  22. log(f"[TRAIN-CORE] Rilevato nuovo job: {job_path.name}")
  23. with open(job_path, "r") as f:
  24. job = json.load(f)
  25. campagna = job["campaign"]
  26. knn_cfg = job["knn"]
  27. nan_fill = job["nan_fill"]
  28. gw_csv = job["gateways_csv"]
  29. # --- GENERAZIONE NOME FILE CON TIMESTAMP (BACKUP) ---
  30. now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
  31. model_filename = f"model_camp_{campagna}_{now_str}.joblib"
  32. model_path = Path("/data/model") / model_filename
  33. # Caricamento Gateway
  34. gws = load_gateway_features_csv(gw_csv)
  35. gateways_order = [g.mac for g in gws]
  36. # Analisi file campioni
  37. samples_dir = Path("/data/train/samples")
  38. sample_files = list(samples_dir.glob(f"{campagna}_*.csv"))
  39. X_list, y_z, y_xy = [], [], []
  40. for fp in sample_files:
  41. try:
  42. df = pd.read_csv(fp, sep=";")
  43. if df.empty: continue
  44. row = df.iloc[0]
  45. # Mapping RSSI basato su gateway.csv (risolve errore 'mac')
  46. X_list.append([float(row.get(gw, nan_fill)) for gw in gateways_order])
  47. y_z.append(int(round(float(row.get("z")))))
  48. y_xy.append([float(row.get("x")), float(row.get("y"))])
  49. except: continue
  50. if not X_list:
  51. log(f"[TRAIN-CORE] ERRORE: Dati non validi per campagna {campagna}")
  52. job_path.unlink()
  53. continue
  54. X, Y_z, Y_xy = np.array(X_list), np.array(y_z), np.array(y_xy)
  55. # Fitting KNN
  56. log(f"[TRAIN-CORE] Fitting modello per {model_filename}...")
  57. floor_clf = KNeighborsClassifier(
  58. n_neighbors=int(knn_cfg.get('k', 5)),
  59. weights=knn_cfg.get('weights', 'distance'),
  60. metric=knn_cfg.get('metric', 'euclidean')
  61. ).fit(X, Y_z)
  62. models_xy = {}
  63. for z in np.unique(Y_z):
  64. idx = np.where(Y_z == z)[0]
  65. models_xy[int(z)] = KNeighborsRegressor(
  66. n_neighbors=min(int(knn_cfg.get('k', 5)), len(idx)),
  67. weights=knn_cfg.get('weights', 'distance'),
  68. metric=knn_cfg.get('metric', 'euclidean')
  69. ).fit(X[idx], Y_xy[idx])
  70. # Salvataggio Pacchetto
  71. model_pkg = {
  72. "floor_clf": floor_clf,
  73. "xy_by_floor": models_xy,
  74. "gateways_order": gateways_order,
  75. "nan_fill": nan_fill,
  76. "created_at": datetime.now().isoformat(),
  77. "campaign": campagna,
  78. "filename": model_filename
  79. }
  80. Path("/data/model").mkdir(parents=True, exist_ok=True)
  81. joblib.dump(model_pkg, model_path)
  82. log(f"[TRAIN-CORE] ✅ Addestramento COMPLETATO: {model_filename}")
  83. except Exception as e:
  84. log(f"[TRAIN-CORE] ❌ ERRORE CRITICO: {str(e)}")
  85. finally:
  86. if job_path.exists():
  87. job_path.unlink()
  88. def run_train_monitor():
  89. """Loop di monitoraggio per il Core Orchestrator."""
  90. while True:
  91. process_train_jobs()
  92. time.sleep(5)