Files
pinn/inverse/visualizer.py
T
davide 5a6cb2d518 feat: aggiunge grafico data-fit nella visualizzazione inversa (opzione 4)
Sovrappone la curva T predetta ai punti di training per ogni sensore,
salvando data_fit.html nella stessa cartella degli altri grafici.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 09:41:44 +02:00

184 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from datetime import datetime
import numpy as np
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import config
# Package root: the parent of the directory holding this module (for
# pinn/inverse/visualizer.py that is pinn/). Result folders are built under it.
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
def visualize_inverse(T_pred, x_vals, t_vals, T_fdm, identified_params: dict,
                      model=None, device=None, x_s=None, t_s=None, T_meas=None):
    """Save interactive HTML plots comparing an inverse-PINN solution with an FDM reference.

    A fresh folder ``<BASE_DIR>/results/inverse/<YYYYmmdd_HHMMSS>/`` is created and filled with:

    * ``heatmap.html``    - side-by-side T(x, t) heatmaps (PINN | FDM) on a shared colour scale;
    * ``animation.html``  - animated T(x) profiles, one frame per entry of ``t_vals``;
    * ``comparison.html`` - T(t) curves at the first, middle and last x sample;
    * ``data_fit.html``   - only when ``model``, ``x_s``, ``t_s`` and ``T_meas`` are all
      given (produced by ``plot_data_fit``).

    Args:
        T_pred: PINN prediction indexed ``[x_index, t_index]``; presumably of shape
            ``(len(x_vals), len(t_vals))`` - TODO confirm against caller.
        x_vals: 1-D spatial sample points of ``T_pred``.
        t_vals: 1-D time sample points of ``T_pred``.
        T_fdm: FDM reference indexed ``[x_index, t_index]`` on its own (possibly finer) grid.
            It is reduced to ``T_pred``'s grid by nearest index, which assumes both grids are
            uniform and span the same x/t ranges.
        identified_params: name -> value mapping, printed in the plot subtitles with 4 decimals.
        model, device, x_s, t_s, T_meas: forwarded unchanged to ``plot_data_fit``.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_dir = os.path.join(BASE_DIR, 'results', 'inverse', timestamp)
    os.makedirs(out_dir, exist_ok=True)

    # Accept any array-like; a no-op for ndarrays.
    T_pred = np.asarray(T_pred)
    T_fdm_ds = _downsample_to_grid(np.asarray(T_fdm), len(x_vals), len(t_vals))

    param_str = ' | '.join(f'{k}={v:.4f}' for k, v in identified_params.items())
    subtitle = f'Parametri identificati: {param_str}'
    # One value range for both fields so the PINN and FDM plots are directly comparable.
    zmin = float(min(np.min(T_pred), np.min(T_fdm_ds)))
    zmax = float(max(np.max(T_pred), np.max(T_fdm_ds)))

    map_path = _save_heatmap(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax, subtitle, out_dir)
    print(f"Heatmap saved → {map_path}")

    anim_path = _save_profile_animation(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax, out_dir)
    print(f"Animation saved → {anim_path}")

    comparison_path = _save_time_series(T_pred, T_fdm_ds, x_vals, t_vals, subtitle, out_dir)
    print(f"Comparison saved → {comparison_path}")

    # plot_data_fit needs every one of these; a partial set would crash inside it.
    if all(arg is not None for arg in (model, x_s, t_s, T_meas)):
        fit_path = plot_data_fit(model, device, x_s, t_s, T_meas, out_dir)
        print(f"Data fit saved → {fit_path}")


def _downsample_to_grid(T_ref, nx, nt):
    """Return an ``(nx, nt)`` nearest-index subsample of ``T_ref`` (indexed ``[x, t]``).

    Along an axis with ``n`` reference points, output sample ``j`` sits at fractional index
    ``j * (n - 1) / (n_out - 1)``. Rounding (instead of truncating) picks the closest
    reference point and is immune to round-off such as ``2.9999999 -> 2``.
    """
    x_idx = np.rint(np.linspace(0, T_ref.shape[0] - 1, nx)).astype(int)
    t_idx = np.rint(np.linspace(0, T_ref.shape[1] - 1, nt)).astype(int)
    return T_ref[np.ix_(x_idx, t_idx)]


def _save_heatmap(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax, subtitle, out_dir):
    """Write side-by-side PINN / FDM heatmaps to ``heatmap.html`` and return its path."""
    colorscale = 'Hot_r'
    fig_map = make_subplots(
        rows=1, cols=2,
        subplot_titles=["Inverse PINN T(x,t)", "FDM Reference T(x,t)"],
        shared_yaxes=True,
    )
    # Fields are indexed [x, t]; Plotly's z is [row=y, col=x], hence the transpose.
    fig_map.add_trace(go.Heatmap(
        z=T_pred.T, x=x_vals, y=t_vals,
        colorscale=colorscale, zmin=zmin, zmax=zmax,
        colorbar=dict(x=0.46, title='T [°C]'),
    ), row=1, col=1)
    fig_map.add_trace(go.Heatmap(
        z=T_fdm_ds.T, x=x_vals, y=t_vals,
        colorscale=colorscale, zmin=zmin, zmax=zmax,
        colorbar=dict(x=1.01, title='T [°C]'),
    ), row=1, col=2)
    fig_map.update_xaxes(title_text='x')
    fig_map.update_yaxes(title_text='t', row=1, col=1)
    fig_map.update_layout(title_text=f'Inverse PINN vs FDM<br><sup>{subtitle}</sup>', height=450)
    map_path = os.path.join(out_dir, 'heatmap.html')
    fig_map.write_html(map_path)
    return map_path


def _save_profile_animation(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax, out_dir):
    """Write an animated T(x) comparison (one frame per time sample) to ``animation.html``."""
    n_frames = len(t_vals)
    frames = [
        go.Frame(
            data=[
                go.Scatter(x=x_vals, y=T_pred[:, i], mode='lines',
                           line=dict(color='royalblue', width=2), name='Inverse PINN'),
                go.Scatter(x=x_vals, y=T_fdm_ds[:, i], mode='lines',
                           line=dict(color='firebrick', width=2, dash='dash'), name='FDM'),
            ],
            name=str(i),
            layout=go.Layout(title_text=f'Inverse PINN vs FDM | t = {t_vals[i]:.3f}'),
        )
        for i in range(n_frames)
    ]
    play_button = dict(label='▶ Play', method='animate',
                       args=[None, dict(frame=dict(duration=40, redraw=False),
                                        fromcurrent=True, mode='immediate')])
    pause_button = dict(label='⏸ Pause', method='animate',
                        args=[[None], dict(frame=dict(duration=0, redraw=False),
                                           mode='immediate')])
    # Each slider step jumps to the frame whose name is its index.
    slider_steps = [
        dict(method='animate',
             args=[[str(i)], dict(mode='immediate', frame=dict(duration=0, redraw=False))],
             label=f'{t_vals[i]:.2f}')
        for i in range(n_frames)
    ]
    fig_anim = go.Figure(
        data=frames[0].data,
        layout=go.Layout(
            title=f'Inverse PINN vs FDM | t = {t_vals[0]:.3f}',
            # Fixed axes (small margins) so they do not rescale between frames.
            xaxis=dict(title='x [m]', range=[-0.02, config.L + 0.02]),
            yaxis=dict(title='T [°C]', range=[zmin - 1, zmax + 1]),
            legend=dict(x=0.75, y=0.95),
            updatemenus=[dict(
                type='buttons', showactive=False, y=1.15, x=0.5, xanchor='center',
                buttons=[play_button, pause_button],
            )],
            sliders=[dict(
                steps=slider_steps,
                transition=dict(duration=0), x=0.05, y=0, len=0.9,
                currentvalue=dict(prefix='t = ', font=dict(size=14)),
            )],
        ),
        frames=frames,
    )
    anim_path = os.path.join(out_dir, 'animation.html')
    fig_anim.write_html(anim_path)
    return anim_path


def _save_time_series(T_pred, T_fdm_ds, x_vals, t_vals, subtitle, out_dir):
    """Write T(t) curves at the first, middle and last x sample to ``comparison.html``."""
    nx = len(x_vals)
    # 'x=L/2' is exact only for an odd number of x samples.
    points = [(0, 'x=0', 'blue'), (nx // 2, 'x=L/2', 'green'), (nx - 1, 'x=L', 'red')]
    fig_ts = go.Figure()
    for idx, label, color in points:
        fig_ts.add_trace(go.Scatter(x=t_vals, y=T_pred[idx, :], mode='lines',
                                    line=dict(color=color, width=2), name=f'Inv.PINN {label}'))
        fig_ts.add_trace(go.Scatter(x=t_vals, y=T_fdm_ds[idx, :], mode='lines',
                                    line=dict(color=color, width=2, dash='dash'),
                                    name=f'FDM {label}'))
    fig_ts.add_vline(x=config.T_STEP, line=dict(color='red', dash='dash', width=1.5),
                     annotation_text='Heat ON', annotation_position='top right')
    fig_ts.update_layout(
        title=f'Inverse PINN vs FDM — Time Series<br><sup>{subtitle}</sup>',
        xaxis_title='t', yaxis_title='T(x,t)',
        legend=dict(x=0.01, y=0.99), height=500,
    )
    comparison_path = os.path.join(out_dir, 'comparison.html')
    fig_ts.write_html(comparison_path)
    return comparison_path
def plot_data_fit(model, device, x_s, t_s, T_meas, out_dir, n_points=200):
    """Plot the PINN prediction at each sensor position over its training measurements.

    For every distinct sensor position in ``x_s``, the model is evaluated on the dense time
    grid ``linspace(0, config.T_END, n_points)`` and drawn as a line. That sensor's training
    samples are overlaid as open markers in the same colour. The figure is written to
    ``<out_dir>/data_fit.html``.

    Args:
        model: trained PINN. It is called on a ``(N, 2)`` float32 tensor of ``[x, t]`` rows
            and must expose scalar tensors ``alpha``, ``k`` and ``h_conv`` (shown in the
            subtitle).
        device: torch device for the evaluation inputs. If ``None``, the device of the
            model's first parameter is used; if that is unavailable, torch's default device.
        x_s, t_s, T_meas: sensor positions, times and measured temperatures. Tensors
            (detached, moved to CPU) or array-likes with the same number of elements; they
            are flattened, so column vectors are accepted.
        out_dir: existing directory to write the HTML file into.
        n_points: number of time samples per predicted curve (default 200).

    Returns:
        Path of the written ``data_fit.html``.
    """
    x_s_np = _as_flat_array(x_s)
    t_s_np = _as_flat_array(t_s)
    T_meas_np = _as_flat_array(T_meas)

    if device is None:
        # Evaluate on the model's own device so a GPU model is not fed CPU tensors.
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = None  # fall back to torch's default device

    # Group samples by sensor position. Rounding absorbs float noise; the inverse map
    # assigns every sample to exactly one (ascending-sorted) sensor, so nothing is
    # double-counted even for closely spaced sensors.
    sensors, sensor_ids = np.unique(np.round(x_s_np, 6), return_inverse=True)

    t_dense = np.linspace(0, config.T_END, n_points)
    # The time axis is the same for every sensor, so build it once.
    t_in = torch.tensor(t_dense, dtype=torch.float32, device=device)
    colors = [
        'royalblue', 'firebrick', 'seagreen', 'darkorange',
        'mediumpurple', 'deeppink', 'saddlebrown', 'teal',
    ]
    param_str = f'α={model.alpha.item():.5f} k={model.k.item():.5f} h={model.h_conv.item():.4f}'
    fig = go.Figure()
    for i, x_val in enumerate(sensors):
        color = colors[i % len(colors)]
        label = f'x={x_val:.3f} m'
        # Dense predicted curve T(x_val, t).
        x_in = torch.full_like(t_in, float(x_val))
        with torch.no_grad():
            T_line = model(torch.stack([x_in, t_in], dim=1)).cpu().numpy().ravel()
        fig.add_trace(go.Scatter(
            x=t_dense, y=T_line, mode='lines',
            line=dict(color=color, width=2),
            name=f'PINN {label}',
            legendgroup=label,
        ))
        # Training points belonging to this sensor.
        mask = sensor_ids == i
        fig.add_trace(go.Scatter(
            x=t_s_np[mask], y=T_meas_np[mask], mode='markers',
            marker=dict(color=color, size=6, symbol='circle-open', line=dict(width=1.5)),
            name=f'Misure {label}',
            legendgroup=label,
        ))
    fig.add_vline(x=config.T_STEP, line=dict(color='gray', dash='dash', width=1),
                  annotation_text='Heat ON', annotation_position='top right')
    fig.update_layout(
        title=f'Data fit — T(x_sensor, t)<br><sup>Parametri identificati: {param_str}</sup>',
        xaxis_title='t [s]',
        yaxis_title='T [°C]',
        legend=dict(groupclick='toggleitem'),
        height=520,
    )
    path = os.path.join(out_dir, 'data_fit.html')
    fig.write_html(path)
    return path


def _as_flat_array(values):
    """Return ``values`` as a 1-D NumPy array; tensors are detached and moved to CPU first."""
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values).ravel()