feat: aggiunge grafico data-fit nella visualizzazione inversa (opzione 4)

Sovrappone la curva T predetta ai punti di training per ogni sensore,
salvando data_fit.html nella stessa cartella degli altri grafici.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-21 09:41:44 +02:00
parent b92c39ffb0
commit 5a6cb2d518
2 changed files with 76 additions and 2 deletions
+10 -1
View File
@@ -183,6 +183,7 @@ def evaluate_inverse():
print(f"Errore massimo assoluto: {max_err:.4f} °C")
def generate_visualization_inverse():
device = _get_device()
@@ -208,5 +209,13 @@ def generate_visualization_inverse():
identified_params = {'α': model.alpha.item(), 'k': model.k.item(), 'h': model.h_conv.item()}
x_s = t_s = T_meas = None
from inverse.data import load_measurements
try:
x_s, t_s, T_meas = load_measurements(device)
except FileNotFoundError:
print("[skip data-fit] measurements.csv non trovato — esegui prima 'Carica misure'.")
from inverse.visualizer import visualize_inverse
visualize_inverse(T_pred, x_vals.cpu().numpy(), t_vals.cpu().numpy(), T_fdm, identified_params)
visualize_inverse(T_pred, x_vals.cpu().numpy(), t_vals.cpu().numpy(), T_fdm, identified_params,
model=model, device=device, x_s=x_s, t_s=t_s, T_meas=T_meas)
+66 -1
View File
@@ -4,6 +4,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from datetime import datetime
import numpy as np
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import config
@@ -11,7 +12,8 @@ import config
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def visualize_inverse(T_pred, x_vals, t_vals, T_fdm, identified_params: dict):
def visualize_inverse(T_pred, x_vals, t_vals, T_fdm, identified_params: dict,
model=None, device=None, x_s=None, t_s=None, T_meas=None):
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
out_dir = os.path.join(BASE_DIR, 'results', 'inverse', timestamp)
os.makedirs(out_dir, exist_ok=True)
@@ -116,3 +118,66 @@ def visualize_inverse(T_pred, x_vals, t_vals, T_fdm, identified_params: dict):
comparison_path = os.path.join(out_dir, 'comparison.html')
fig_ts.write_html(comparison_path)
print(f"Comparison saved → {comparison_path}")
if model is not None and x_s is not None:
fit_path = plot_data_fit(model, device, x_s, t_s, T_meas, out_dir)
print(f"Data fit saved → {fit_path}")
def plot_data_fit(model, device, x_s, t_s, T_meas, out_dir):
"""Grafico data-fit: curva T_pred per ogni sensore + punti di training sovrapposti."""
x_s_np = x_s.cpu().numpy() if isinstance(x_s, torch.Tensor) else np.asarray(x_s)
t_s_np = t_s.cpu().numpy() if isinstance(t_s, torch.Tensor) else np.asarray(t_s)
T_meas_np = T_meas.cpu().numpy() if isinstance(T_meas, torch.Tensor) else np.asarray(T_meas)
sensors = sorted(np.unique(np.round(x_s_np, 6)))
t_dense = np.linspace(0, config.T_END, 200)
colors = [
'royalblue', 'firebrick', 'seagreen', 'darkorange',
'mediumpurple', 'deeppink', 'saddlebrown', 'teal',
]
param_str = f'α={model.alpha.item():.5f} k={model.k.item():.5f} h={model.h_conv.item():.4f}'
fig = go.Figure()
for i, x_val in enumerate(sensors):
color = colors[i % len(colors)]
# Curva predetta densa
x_in = torch.full((len(t_dense),), float(x_val), dtype=torch.float32, device=device)
t_in = torch.tensor(t_dense, dtype=torch.float32, device=device)
with torch.no_grad():
T_line = model(torch.stack([x_in, t_in], dim=1)).cpu().numpy().ravel()
label = f'x={x_val:.3f} m'
fig.add_trace(go.Scatter(
x=t_dense, y=T_line, mode='lines',
line=dict(color=color, width=2),
name=f'PINN {label}',
legendgroup=label,
))
# Punti di training per questo sensore
mask = np.abs(x_s_np - x_val) < 1e-5
fig.add_trace(go.Scatter(
x=t_s_np[mask], y=T_meas_np[mask], mode='markers',
marker=dict(color=color, size=6, symbol='circle-open', line=dict(width=1.5)),
name=f'Misure {label}',
legendgroup=label,
))
fig.add_vline(x=config.T_STEP, line=dict(color='gray', dash='dash', width=1),
annotation_text='Heat ON', annotation_position='top right')
fig.update_layout(
title=f'Data fit — T(x_sensor, t)<br><sup>Parametri identificati: {param_str}</sup>',
xaxis_title='t [s]',
yaxis_title='T [°C]',
legend=dict(groupclick='toggleitem'),
height=520,
)
path = os.path.join(out_dir, 'data_fit.html')
fig.write_html(path)
return path