5a6cb2d518
Sovrappone la curva T predetta ai punti di training per ogni sensore, salvando data_fit.html nella stessa cartella degli altri grafici. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
184 lines
7.6 KiB
Python
184 lines
7.6 KiB
Python
import os
|
||
import sys
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from datetime import datetime
|
||
import numpy as np
|
||
import torch
|
||
import plotly.graph_objects as go
|
||
from plotly.subplots import make_subplots
|
||
import config
|
||
|
||
# Project root: the parent of the directory that contains this module. Result folders
# such as results/inverse/<timestamp> are created beneath it.
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
|
||
|
||
|
||
def visualize_inverse(T_pred, x_vals, t_vals, T_fdm, identified_params: dict,
                      model=None, device=None, x_s=None, t_s=None, T_meas=None):
    """Save interactive HTML plots comparing the inverse-PINN solution with an FDM reference.

    All figures go into a fresh ``results/inverse/<YYYYmmdd_HHMMSS>`` folder under
    ``BASE_DIR``:

    * ``heatmap.html``    - side-by-side T(x, t) heatmaps (PINN vs FDM), shared colour range;
    * ``animation.html``  - animated T(x) profile, one frame per time sample;
    * ``comparison.html`` - T(t) at the first, middle and last x sample;
    * ``data_fit.html``   - only when ``model``, ``x_s``, ``t_s`` and ``T_meas`` are all
      given; produced by ``plot_data_fit``.

    Args:
        T_pred: PINN prediction indexed ``T_pred[x_index, t_index]``, i.e. shaped
            ``(len(x_vals), len(t_vals))``.
        x_vals: Spatial coordinates of ``T_pred``'s first axis.
        t_vals: Time coordinates of ``T_pred``'s second axis.
        T_fdm: FDM reference indexed ``[x, t]`` on its own grid. It is resampled to
            ``T_pred``'s size by picking evenly spaced, nearest grid indices. This assumes
            both grids are uniform over the same domain (TODO confirm).
        identified_params: Parameter name -> identified value, shown in plot subtitles.
        model, device, x_s, t_s, T_meas: Optional inputs forwarded to ``plot_data_fit``.

    Raises:
        ValueError: If ``x_vals`` or ``t_vals`` is empty.
    """
    if len(x_vals) == 0 or len(t_vals) == 0:
        raise ValueError("x_vals and t_vals must both be non-empty")

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_dir = os.path.join(BASE_DIR, 'results', 'inverse', timestamp)
    os.makedirs(out_dir, exist_ok=True)

    # Resample the FDM reference to the same (x, t) sample count as T_pred.
    x_idx = _nearest_grid_indices(T_fdm.shape[0], len(x_vals))
    t_idx = _nearest_grid_indices(T_fdm.shape[1], len(t_vals))
    T_fdm_ds = T_fdm[np.ix_(x_idx, t_idx)]

    param_str = ' | '.join(f'{k}={v:.4f}' for k, v in identified_params.items())
    subtitle = f'Parametri identificati: {param_str}'

    # One shared value range, so PINN and FDM plots are directly comparable.
    zmin = float(min(np.min(T_pred), np.min(T_fdm_ds)))
    zmax = float(max(np.max(T_pred), np.max(T_fdm_ds)))

    map_path = _save_heatmap(out_dir, T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax, subtitle)
    print(f"Heatmap saved → {map_path}")

    anim_path = _save_profile_animation(out_dir, T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax)
    print(f"Animation saved → {anim_path}")

    comparison_path = _save_time_series(out_dir, T_pred, T_fdm_ds, x_vals, t_vals, subtitle)
    print(f"Comparison saved → {comparison_path}")

    # The data-fit plot needs the model AND all three sensor arrays; with only some of
    # them present plot_data_fit cannot work, so it is skipped.
    if all(arg is not None for arg in (model, x_s, t_s, T_meas)):
        fit_path = plot_data_fit(model, device, x_s, t_s, T_meas, out_dir)
        print(f"Data fit saved → {fit_path}")


def _nearest_grid_indices(n_source, n_target):
    """Return ``n_target`` evenly spaced indices into a grid of ``n_source`` points."""
    # np.linspace(..., dtype=int) floors, shifting interior samples up to one cell toward
    # index 0; rounding picks the closest grid point instead. Endpoints stay 0 / n_source-1.
    return np.rint(np.linspace(0, n_source - 1, n_target)).astype(int)


def _save_heatmap(out_dir, T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax, subtitle):
    """Write side-by-side PINN/FDM T(x,t) heatmaps to heatmap.html; return its path."""
    colorscale = 'Hot_r'
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=["Inverse PINN T(x,t)", "FDM Reference T(x,t)"],
        shared_yaxes=True,
    )
    # Transposed so heatmap rows follow t (y axis) and columns follow x.
    fig.add_trace(go.Heatmap(
        z=T_pred.T, x=x_vals, y=t_vals,
        colorscale=colorscale, zmin=zmin, zmax=zmax,
        colorbar=dict(x=0.46, title='T [°C]'),
    ), row=1, col=1)
    fig.add_trace(go.Heatmap(
        z=T_fdm_ds.T, x=x_vals, y=t_vals,
        colorscale=colorscale, zmin=zmin, zmax=zmax,
        colorbar=dict(x=1.01, title='T [°C]'),
    ), row=1, col=2)
    fig.update_xaxes(title_text='x')
    fig.update_yaxes(title_text='t', row=1, col=1)
    fig.update_layout(title_text=f'Inverse PINN vs FDM<br><sup>{subtitle}</sup>', height=450)
    path = os.path.join(out_dir, 'heatmap.html')
    fig.write_html(path)
    return path


def _save_profile_animation(out_dir, T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax):
    """Write an animated T(x) profile (one frame per t sample) to animation.html; return path."""
    n_frames = len(t_vals)
    frames = [
        go.Frame(
            data=[
                go.Scatter(x=x_vals, y=T_pred[:, i], mode='lines',
                           line=dict(color='royalblue', width=2), name='Inverse PINN'),
                go.Scatter(x=x_vals, y=T_fdm_ds[:, i], mode='lines',
                           line=dict(color='firebrick', width=2, dash='dash'), name='FDM'),
            ],
            name=str(i),
            layout=go.Layout(title_text=f'Inverse PINN vs FDM | t = {t_vals[i]:.3f}'),
        )
        for i in range(n_frames)
    ]
    play_pause = dict(
        type='buttons', showactive=False, y=1.15, x=0.5, xanchor='center',
        buttons=[
            dict(label='▶ Play', method='animate',
                 args=[None, dict(frame=dict(duration=40, redraw=False),
                                  fromcurrent=True, mode='immediate')]),
            # [None] as the frame list is Plotly's idiom for stopping the animation.
            dict(label='⏸ Pause', method='animate',
                 args=[[None], dict(frame=dict(duration=0, redraw=False),
                                    mode='immediate')]),
        ],
    )
    slider = dict(
        steps=[dict(method='animate', args=[[str(i)],
                    dict(mode='immediate', frame=dict(duration=0, redraw=False))],
                    label=f'{t_vals[i]:.2f}') for i in range(n_frames)],
        transition=dict(duration=0), x=0.05, y=0, len=0.9,
        currentvalue=dict(prefix='t = ', font=dict(size=14)),
    )
    fig = go.Figure(
        data=frames[0].data,
        layout=go.Layout(
            title=f'Inverse PINN vs FDM | t = {t_vals[0]:.3f}',
            xaxis=dict(title='x [m]', range=[-0.02, config.L + 0.02]),
            yaxis=dict(title='T [°C]', range=[zmin - 1, zmax + 1]),
            legend=dict(x=0.75, y=0.95),
            updatemenus=[play_pause],
            sliders=[slider],
        ),
        frames=frames,
    )
    path = os.path.join(out_dir, 'animation.html')
    fig.write_html(path)
    return path


def _save_time_series(out_dir, T_pred, T_fdm_ds, x_vals, t_vals, subtitle):
    """Write T(t) at the first/middle/last x sample to comparison.html; return its path."""
    nx = len(x_vals)
    # Labels assume x_vals is a uniform grid over [0, L] (TODO confirm); the middle
    # sample nx // 2 is exactly L/2 only when nx is odd.
    points = [(0, 'x=0', 'blue'), (nx // 2, 'x=L/2', 'green'), (nx - 1, 'x=L', 'red')]
    fig = go.Figure()
    for idx, label, color in points:
        fig.add_trace(go.Scatter(x=t_vals, y=T_pred[idx, :], mode='lines',
                                 line=dict(color=color, width=2), name=f'Inv.PINN {label}'))
        fig.add_trace(go.Scatter(x=t_vals, y=T_fdm_ds[idx, :], mode='lines',
                                 line=dict(color=color, width=2, dash='dash'), name=f'FDM {label}'))
    # Vertical marker at config.T_STEP, labelled 'Heat ON'.
    fig.add_vline(x=config.T_STEP, line=dict(color='red', dash='dash', width=1.5),
                  annotation_text='Heat ON', annotation_position='top right')
    fig.update_layout(
        title=f'Inverse PINN vs FDM — Time Series<br><sup>{subtitle}</sup>',
        xaxis_title='t', yaxis_title='T(x,t)',
        legend=dict(x=0.01, y=0.99), height=500,
    )
    path = os.path.join(out_dir, 'comparison.html')
    fig.write_html(path)
    return path
|
||
|
||
|
||
def _as_flat_numpy(values):
    """Return a tensor or array-like as a flattened 1-D NumPy array."""
    if isinstance(values, torch.Tensor):
        # Tensor.numpy() raises on tensors that require grad, so detach first.
        values = values.detach().cpu().numpy()
    return np.asarray(values).ravel()


def plot_data_fit(model, device, x_s, t_s, T_meas, out_dir, *, n_dense=200):
    """Plot the model's predicted T(t) at each sensor position over its training points.

    For every distinct sensor position in ``x_s`` (positions equal after rounding to 6
    decimals are merged), the model is evaluated at ``n_dense`` evenly spaced times in
    ``[0, config.T_END]`` and drawn as a line. The matching ``(t_s, T_meas)`` training
    points are overlaid as open markers in the same colour. The figure is written to
    ``data_fit.html`` inside ``out_dir``.

    Args:
        model: Callable taking an ``(N, 2)`` float32 tensor of ``(x, t)`` rows and returning
            N temperatures. It must expose ``alpha``, ``k`` and ``h_conv`` supporting
            ``.item()``; these are shown in the subtitle.
        device: Device the model inputs are created on (None -> PyTorch default device).
        x_s, t_s, T_meas: Sensor positions, times and measured temperatures, as tensors
            or array-likes with the same number of elements (any shape; they are flattened).
        out_dir: Directory for the output file; it must already exist.
        n_dense: Number of time samples per predicted curve (keyword-only, default 200).

    Returns:
        Path of the written ``data_fit.html``.
    """
    x_np = _as_flat_numpy(x_s)
    t_np = _as_flat_numpy(t_s)
    T_np = _as_flat_numpy(T_meas)

    # One group per sensor position; rounding keeps float noise from splitting a sensor.
    # point_sensor[j] is the group index of measurement j, so every measurement belongs
    # to exactly one sensor. np.unique returns the positions already sorted.
    sensors, point_sensor = np.unique(np.round(x_np, 6), return_inverse=True)

    t_dense = np.linspace(0, config.T_END, n_dense)
    # The time input is the same for every sensor, so it is built once.
    t_in = torch.tensor(t_dense, dtype=torch.float32, device=device)

    colors = [
        'royalblue', 'firebrick', 'seagreen', 'darkorange',
        'mediumpurple', 'deeppink', 'saddlebrown', 'teal',
    ]

    param_str = f'α={model.alpha.item():.5f} k={model.k.item():.5f} h={model.h_conv.item():.4f}'
    fig = go.Figure()

    for i, x_val in enumerate(sensors):
        color = colors[i % len(colors)]
        label = f'x={x_val:.3f} m'

        # Dense predicted curve T(x_val, t).
        x_in = torch.full((t_in.shape[0],), float(x_val), dtype=torch.float32, device=device)
        with torch.no_grad():
            T_line = model(torch.stack([x_in, t_in], dim=1)).cpu().numpy().ravel()
        fig.add_trace(go.Scatter(
            x=t_dense, y=T_line, mode='lines',
            line=dict(color=color, width=2),
            name=f'PINN {label}',
            legendgroup=label,
        ))

        # Training measurements recorded at this sensor.
        mask = point_sensor == i
        fig.add_trace(go.Scatter(
            x=t_np[mask], y=T_np[mask], mode='markers',
            marker=dict(color=color, size=6, symbol='circle-open', line=dict(width=1.5)),
            name=f'Misure {label}',
            legendgroup=label,
        ))

    # Vertical marker at config.T_STEP, labelled 'Heat ON'.
    fig.add_vline(x=config.T_STEP, line=dict(color='gray', dash='dash', width=1),
                  annotation_text='Heat ON', annotation_position='top right')
    fig.update_layout(
        title=f'Data fit — T(x_sensor, t)<br><sup>Parametri identificati: {param_str}</sup>',
        xaxis_title='t [s]',
        yaxis_title='T [°C]',
        legend=dict(groupclick='toggleitem'),
        height=520,
    )

    path = os.path.join(out_dir, 'data_fit.html')
    fig.write_html(path)
    return path
|
||
|