You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

163 rivejä
7.0 KiB

  1. import streamlit as st
  2. import pandas as pd
  3. import os
  4. import joblib
  5. import numpy as np
  6. import math
  7. import folium
  8. import json
  9. import base64
  10. from io import BytesIO
  11. from pathlib import Path
  12. from PIL import Image
  13. from streamlit_folium import st_folium
  14. # --- UTILS ---
  @st.cache_data
  def get_image_base64(img_path):
      """Encode an image file as a PNG data URI, e.g. for a map image overlay.

      The image is converted to RGBA and re-encoded as PNG, whatever its
      original format. The result is memoised by ``st.cache_data``, which keys
      on the argument value. A file replaced on disk at the same path is
      therefore presumably not re-read until the cache is cleared.

      Args:
          img_path: Path (``str`` or ``pathlib.Path``) of the image to load.

      Returns:
          A tuple ``(data_uri, width, height)``:
          - ``data_uri`` is a ``data:image/png;base64,...`` string.
          - ``width`` and ``height`` are the image size in pixels.
      """
      # Image.open is lazy and holds the file handle open. The context manager
      # releases it once convert() has loaded the pixels into a new image.
      with Image.open(img_path) as src:
          img = src.convert("RGBA")
      w, h = img.size
      buffered = BytesIO()
      img.save(buffered, format="PNG")
      img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
      return f"data:image/png;base64,{img_str}", w, h
  def calculate_error(z_real, x_real, y_real, z_pred, x_pred, y_pred):
      """Return the floor error and the planar distance error of a prediction.

      Args:
          z_real: Ground-truth floor index.
          x_real: Ground-truth x coordinate.
          y_real: Ground-truth y coordinate.
          z_pred: Predicted floor index.
          x_pred: Predicted x coordinate.
          y_pred: Predicted y coordinate.

      Returns:
          A tuple ``(z_err, dist_err)``:
          - ``z_err`` is the absolute floor difference.
          - ``dist_err`` is the Euclidean distance between the real and
            predicted (x, y) points, in the same units as the inputs.
      """
      z_err = abs(z_real - z_pred)
      # math.hypot avoids the overflow/underflow of squaring each delta by hand
      # (e.g. (2e200) ** 2 raises OverflowError, while hypot returns 2e200).
      dist_err = math.hypot(x_real - x_pred, y_real - y_pred)
      return z_err, dist_err
  27. def show_test_inference(cfg):
  28. st.subheader("🧪 Test Inferenza Offline")
  29. MODEL_DIR = Path("/data/model")
  30. TEST_SAMPLES_DIR = Path("/data/train/testsamples")
  31. MAPS_DIR = Path("/data/maps")
  32. available_models = sorted([m.name for m in MODEL_DIR.glob("model_camp_*.joblib")], reverse=True)
  33. test_files = sorted([f.name for f in TEST_SAMPLES_DIR.glob("*.csv")])
  34. if not available_models or not test_files:
  35. st.warning("Verifica la presenza di modelli e campioni di test.")
  36. return
  37. col1, col2 = st.columns(2)
  38. selected_model = col1.selectbox("🎯 Seleziona Modello:", available_models)
  39. selected_test = col2.selectbox("📄 Seleziona Fingerprint:", test_files)
  40. if "test_results" not in st.session_state:
  41. st.session_state.test_results = None
  42. if st.button("🚀 ESEGUI TEST DI PRECISIONE", type="primary", use_container_width=True):
  43. try:
  44. m_pkg = joblib.load(MODEL_DIR / selected_model)
  45. delim = cfg.get('paths', {}).get('csv_delimiter', ';')
  46. df_test = pd.read_csv(TEST_SAMPLES_DIR / selected_test, sep=delim)
  47. row = df_test.iloc[0]
  48. z_real, x_real, y_real = int(round(float(row['z']))), float(row['x']), float(row['y'])
  49. gws = m_pkg['gateways_order']
  50. fill = m_pkg.get('nan_fill', -110.0)
  51. # --- DEBUG VETTORE INPUT ---
  52. raw_vals = []
  53. for gw in gws:
  54. val = row.get(gw)
  55. raw_vals.append(float(val) if val is not None and not pd.isna(val) else fill)
  56. X_in = np.array([raw_vals])
  57. # Visualizzazione Debug in UI per l'operatore
  58. with st.expander("🔍 Analisi Vettore di Input (Fingerprint vs Modello)"):
  59. debug_df = pd.DataFrame({
  60. "Gateway MAC": gws,
  61. "RSSI Letto": [row.get(gw, "NON TROVATO") for gw in gws],
  62. "RSSI Finale (Input AI)": raw_vals
  63. })
  64. st.dataframe(debug_df)
  65. match_count = np.sum(X_in[0] > fill)
  66. st.write(f"**Gateway Corrispondenti:** {match_count} su {len(gws)}")
  67. if match_count == 0:
  68. st.error("ERRORE: Il file di test non contiene nessuno dei Gateway usati per l'addestramento!")
  69. # Predizione
  70. z_pred = int(m_pkg['floor_clf'].predict(X_in)[0])
  71. x_pred, y_pred = -1.0, -1.0
  72. if z_pred in m_pkg['xy_by_floor']:
  73. xy = m_pkg['xy_by_floor'][z_pred].predict(X_in)[0]
  74. x_pred, y_pred = float(xy[0]), float(xy[1])
  75. z_err, dist_err = calculate_error(z_real, x_real, y_real, z_pred, x_pred, y_pred)
  76. st.session_state.test_results = {
  77. "z_real": z_real, "x_real": x_real, "y_real": y_real,
  78. "z_pred": z_pred, "x_pred": x_pred, "y_pred": y_pred,
  79. "z_err": z_err, "dist_err": dist_err, "model_used": selected_model
  80. }
  81. except Exception as e:
  82. st.error(f"Errore durante l'inferenza: {e}")
  83. # --- VISUALIZZAZIONE GRAFICA DEI RISULTATI ---
  84. if st.session_state.test_results:
  85. res = st.session_state.test_results
  86. c_info1, c_info2 = st.columns(2)
  87. with c_info1:
  88. st.info(f"📍 **Test Reale** | Piano: {res['z_real']} | X: **{res['x_real']}** | Y: **{res['y_real']}**")
  89. with c_info2:
  90. st.success(f"🔮 **Predizione** | Piano: {res['z_pred']} | X: **{round(res['x_pred'],1)}** | Y: **{round(res['y_pred'],1)}**")
  91. st.markdown(f"""
  92. <div style="background-color: #f8f9fa; padding: 12px; border-radius: 8px; border-left: 5px solid #00bcd4; margin: 10px 0;">
  93. <h5 style="margin:0 0 8px 0;">📍 Legenda Mappa</h5>
  94. <span style="color: #00bcd4; font-weight: bold;">● PUNTO DI TEST:</span> Posizione reale del rilievo.<br>
  95. <span style="color: #ff9800; font-weight: bold;">● PUNTO PREDETTO:</span> Posizione calcolata dal modello AI.<br>
  96. <span style="color: #ffeb3b; font-weight: bold;">--- LINEA GIALLA:</span> Scostamento di <b>{round(res['dist_err'], 2)} px</b>.
  97. </div>
  98. """, unsafe_allow_html=True)
  99. img_filename = f"floor_{res['z_pred']}.png"
  100. meta_filename = f"meta_{res['z_pred']}.json"
  101. img_p = MAPS_DIR / img_filename
  102. meta_p = MAPS_DIR / meta_filename
  103. if img_p.exists() and meta_p.exists():
  104. with open(meta_p, "r") as f: meta = json.load(f)
  105. img_data, w, h = get_image_base64(img_p)
  106. bounds = [[0, 0], [h, w]]
  107. m = folium.Map(location=[h/2, w/2], crs="Simple", tiles=None, zoom_start=0, attribution_control=False)
  108. m.fit_bounds(bounds)
  109. m.options.update({"minZoom": -5, "maxZoom": 5, "maxBounds": bounds, "zoomSnap": 0.25})
  110. folium.raster_layers.ImageOverlay(image=img_data, bounds=bounds).add_to(m)
  111. if meta.get("calibrated"):
  112. def to_px(mx, my):
  113. px = (mx * meta["pixel_ratio"]) + meta["origin"][0]
  114. py = meta["origin"][1] - (my * meta["pixel_ratio"])
  115. return [py, px]
  116. p_real = to_px(res['x_real'], res['y_real'])
  117. p_pred = to_px(res['x_pred'], res['y_pred'])
  118. # Marker Posizione Reale
  119. folium.CircleMarker(
  120. location=p_real, radius=11, color="#00838f", fill=True,
  121. fill_color="#00bcd4", fill_opacity=0.85, tooltip="PUNTO DI TEST (REALE)"
  122. ).add_to(m)
  123. # Marker Predizione
  124. if res['x_pred'] != -1.0:
  125. folium.CircleMarker(
  126. location=p_pred, radius=11, color="#e65100", fill=True,
  127. fill_color="#ff9800", fill_opacity=0.85,
  128. tooltip=f"PREDIZIONE (Modello: {res['model_used']})"
  129. ).add_to(m)
  130. # Linea di errore (Gialla tratteggiata)
  131. folium.PolyLine(
  132. locations=[p_real, p_pred], color="#ffeb3b",
  133. weight=4, opacity=0.7, dash_array='8'
  134. ).add_to(m)
  135. st_folium(m, height=700, width=None, key=f"test_map_final", use_container_width=True)