You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

147 line
6.2 KiB

  1. import streamlit as st
  2. import pandas as pd
  3. import os
  4. import joblib
  5. import numpy as np
  6. import math
  7. import folium
  8. import json
  9. import base64
  10. from io import BytesIO
  11. from pathlib import Path
  12. from PIL import Image
  13. from streamlit_folium import st_folium
  14. # --- UTILS ---
  15. @st.cache_data
  16. def get_image_base64(img_path):
  17. img = Image.open(img_path).convert("RGBA")
  18. w, h = img.size
  19. buffered = BytesIO()
  20. img.save(buffered, format="PNG")
  21. img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
  22. return f"data:image/png;base64,{img_str}", w, h
  23. def calculate_error(z_real, x_real, y_real, z_pred, x_pred, y_pred):
  24. z_err = abs(z_real - z_pred)
  25. dist_err = math.sqrt((x_real - x_pred)**2 + (y_real - y_pred)**2)
  26. return z_err, dist_err
  27. def show_test_inference(cfg):
  28. st.subheader("🧪 Test Inferenza Offline")
  29. MODEL_DIR = Path("/data/model")
  30. TEST_SAMPLES_DIR = Path("/data/train/testsamples")
  31. MAPS_DIR = Path("/data/maps")
  32. available_models = sorted([m.name for m in MODEL_DIR.glob("model_camp_*.joblib")], reverse=True)
  33. test_files = sorted([f.name for f in TEST_SAMPLES_DIR.glob("*.csv")])
  34. if not available_models or not test_files:
  35. st.warning("Verifica la presenza di modelli e campioni di test.")
  36. return
  37. col1, col2 = st.columns(2)
  38. selected_model = col1.selectbox("🎯 Seleziona Modello:", available_models)
  39. selected_test = col2.selectbox("📄 Seleziona Fingerprint:", test_files)
  40. if "test_results" not in st.session_state:
  41. st.session_state.test_results = None
  42. if st.button("🚀 ESEGUI TEST DI PRECISIONE", type="primary", use_container_width=True):
  43. try:
  44. m_pkg = joblib.load(MODEL_DIR / selected_model)
  45. delim = cfg.get('paths', {}).get('csv_delimiter', ';')
  46. df_test = pd.read_csv(TEST_SAMPLES_DIR / selected_test, sep=delim)
  47. row = df_test.iloc[0]
  48. z_real, x_real, y_real = int(round(float(row['z']))), float(row['x']), float(row['y'])
  49. gws = m_pkg['gateways_order']
  50. fill = m_pkg.get('nan_fill', -110.0)
  51. X_in = np.array([[float(row.get(gw, fill)) for gw in gws]])
  52. z_pred = int(m_pkg['floor_clf'].predict(X_in)[0])
  53. x_pred, y_pred = -1.0, -1.0
  54. if z_pred in m_pkg['xy_by_floor']:
  55. xy = m_pkg['xy_by_floor'][z_pred].predict(X_in)[0]
  56. x_pred, y_pred = float(xy[0]), float(xy[1])
  57. z_err, dist_err = calculate_error(z_real, x_real, y_real, z_pred, x_pred, y_pred)
  58. st.session_state.test_results = {
  59. "z_real": z_real, "x_real": x_real, "y_real": y_real,
  60. "z_pred": z_pred, "x_pred": x_pred, "y_pred": y_pred,
  61. "z_err": z_err, "dist_err": dist_err,
  62. "model_used": selected_model
  63. }
  64. except Exception as e:
  65. st.error(f"Errore: {e}")
  66. if st.session_state.test_results:
  67. res = st.session_state.test_results
  68. # --- RIPRISTINO INFORMAZIONI COORDINATE ---
  69. c_info1, c_info2 = st.columns(2)
  70. with c_info1:
  71. st.info(f"📍 **Test Reale** | Piano: {res['z_real']} | X: **{res['x_real']}** | Y: **{res['y_real']}**")
  72. with c_info2:
  73. st.success(f"🔮 **Predizione** | Piano: {res['z_pred']} | X: **{round(res['x_pred'],1)}** | Y: **{round(res['y_pred'],1)}**")
  74. # --- LEGENDA AGGIORNATA CON NUOVI COLORI ---
  75. st.markdown(f"""
  76. <div style="background-color: #f8f9fa; padding: 12px; border-radius: 8px; border-left: 5px solid #00bcd4; margin: 10px 0;">
  77. <h5 style="margin:0 0 8px 0;">📍 Legenda Mappa</h5>
  78. <span style="color: #00bcd4; font-weight: bold;">● PUNTO DI TEST:</span> Posizione reale del rilievo.<br>
  79. <span style="color: #ff9800; font-weight: bold;">● PUNTO PREDETTO:</span> Posizione calcolata dal modello AI.<br>
  80. <span style="color: #ffeb3b; font-weight: bold;">--- LINEA GIALLA:</span> Scostamento di <b>{round(res['dist_err'], 2)} px</b>.
  81. </div>
  82. """, unsafe_allow_html=True)
  83. img_filename = f"floor_{res['z_pred']}.png"
  84. meta_filename = f"meta_{res['z_pred']}.json"
  85. img_p = MAPS_DIR / img_filename
  86. meta_p = MAPS_DIR / meta_filename
  87. if img_p.exists() and meta_p.exists():
  88. with open(meta_p, "r") as f: meta = json.load(f)
  89. img_data, w, h = get_image_base64(img_p)
  90. bounds = [[0, 0], [h, w]]
  91. m = folium.Map(location=[h/2, w/2], crs="Simple", tiles=None, zoom_start=0, attribution_control=False)
  92. m.fit_bounds(bounds)
  93. m.options.update({"minZoom": -5, "maxZoom": 5, "maxBounds": bounds, "zoomSnap": 0.25})
  94. folium.raster_layers.ImageOverlay(image=img_data, bounds=bounds).add_to(m)
  95. if meta.get("calibrated"):
  96. def to_px(mx, my):
  97. px = (mx * meta["pixel_ratio"]) + meta["origin"][0]
  98. py = meta["origin"][1] - (my * meta["pixel_ratio"])
  99. return [py, px]
  100. p_real = to_px(res['x_real'], res['y_real'])
  101. p_pred = to_px(res['x_pred'], res['y_pred'])
  102. # 🔵 PUNTO DI TEST (Celeste)
  103. folium.CircleMarker(
  104. location=p_real, radius=11, color="#00838f", fill=True,
  105. fill_color="#00bcd4", fill_opacity=0.85,
  106. tooltip="PUNTO DI TEST (REALE)"
  107. ).add_to(m)
  108. # 🟠 PUNTO PREDETTO (Arancio)
  109. if res['x_pred'] != -1.0:
  110. folium.CircleMarker(
  111. location=p_pred, radius=11, color="#e65100", fill=True,
  112. fill_color="#ff9800", fill_opacity=0.85,
  113. tooltip=f"PREDIZIONE (Modello: {res['model_used']})"
  114. ).add_to(m)
  115. # Linea Errore
  116. folium.PolyLine(
  117. locations=[p_real, p_pred], color="#ffeb3b",
  118. weight=4, opacity=0.7, dash_array='8'
  119. ).add_to(m)
  120. st_folium(m, height=700, width=None, key=f"test_map_final", use_container_width=True)