import os
from datetime import datetime

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import config

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def _make_output_dir():
    """Create and return ``<BASE_DIR>/results/pinn/<YYYYmmdd_HHMMSS>``."""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    out_dir = os.path.join(BASE_DIR, 'results', 'pinn', timestamp)
    # exist_ok: two calls within the same second share (and overwrite) one directory.
    os.makedirs(out_dir, exist_ok=True)
    return out_dir


def _nearest_indices(n_src, n_dst):
    """Return ``n_dst`` evenly spaced indices into an axis of length ``n_src``, rounded to nearest."""
    # Rounding instead of int truncation keeps every pick on the closest sample
    # rather than biasing interior indices toward 0; both endpoints stay exact.
    return np.rint(np.linspace(0, n_src - 1, n_dst)).astype(int)


def _resample_reference(T_fdm, nx, nt):
    """Select an ``(nx, nt)`` nearest-index subgrid from the 2-D reference field ``T_fdm``.

    NOTE(review): assumes the FDM and PINN grids span the same x and t
    intervals; confirm against how T_fdm is produced.

    Raises:
        ValueError: if ``T_fdm`` is not a non-empty 2-D array.
    """
    T_fdm = np.asarray(T_fdm)
    if T_fdm.ndim != 2 or 0 in T_fdm.shape:
        raise ValueError(
            f'T_fdm must be a non-empty 2-D array (NX_fdm, NT_fdm), got shape {T_fdm.shape}'
        )
    x_idx = _nearest_indices(T_fdm.shape[0], nx)
    t_idx = _nearest_indices(T_fdm.shape[1], nt)
    return T_fdm[np.ix_(x_idx, t_idx)]


def _heatmap_figure(T_pred, T_ref, x_vals, t_vals, zmin, zmax):
    """Build side-by-side PINN / FDM heatmaps of T(x,t) on a shared colour range."""
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=["PINN Prediction T(x,t)", "FDM Reference T(x,t)"],
        shared_yaxes=True,
    )
    # (field, subplot column, colorbar x-position)
    panels = ((T_pred, 1, 0.46), (T_ref, 2, 1.01))
    for field, col, cbar_x in panels:
        # Heatmap z is indexed [row = y = t, column = x], hence the transpose of the (x, t) field.
        fig.add_trace(go.Heatmap(
            z=field.T, x=x_vals, y=t_vals,
            colorscale='Hot_r', zmin=zmin, zmax=zmax,
            colorbar=dict(x=cbar_x, title='T [°C]'),
        ), row=1, col=col)
    fig.update_xaxes(title_text='x')
    fig.update_yaxes(title_text='t', row=1, col=1)
    fig.update_layout(title_text='Heat Equation PINN vs FDM', height=450)
    return fig


def _profile_animation_figure(T_pred, T_ref, x_vals, t_vals, zmin, zmax):
    """Build an animated T(x) profile plot (PINN solid, FDM dashed), one frame per t sample."""
    def title_at(i):
        return f'Heat Equation PINN vs FDM | t = {t_vals[i]:.3f}'

    n_frames = len(t_vals)
    frames = [
        go.Frame(
            data=[
                go.Scatter(x=x_vals, y=T_pred[:, i], mode='lines',
                           line=dict(color='royalblue', width=2), name='PINN'),
                go.Scatter(x=x_vals, y=T_ref[:, i], mode='lines',
                           line=dict(color='firebrick', width=2, dash='dash'), name='FDM'),
            ],
            name=str(i),
            layout=go.Layout(title_text=title_at(i)),
        )
        for i in range(n_frames)
    ]
    play = dict(label='▶ Play', method='animate',
                args=[None, dict(frame=dict(duration=40, redraw=False),
                                 fromcurrent=True, mode='immediate')])
    # A list holding None ([None]) is Plotly's idiom for stopping a running animation.
    pause = dict(label='⏸ Pause', method='animate',
                 args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')])
    slider_steps = [
        dict(method='animate',
             args=[[str(i)], dict(mode='immediate', frame=dict(duration=0, redraw=False))],
             label=f'{t_vals[i]:.2f}')
        for i in range(n_frames)
    ]
    return go.Figure(
        data=frames[0].data,
        layout=go.Layout(
            title=title_at(0),
            xaxis=dict(title='x [m]', range=[-0.02, config.L + 0.02]),
            yaxis=dict(title='T [°C]', range=[zmin - 1, zmax + 1]),
            legend=dict(x=0.75, y=0.95),
            updatemenus=[dict(
                type='buttons', showactive=False, y=1.15, x=0.5, xanchor='center',
                buttons=[play, pause],
            )],
            sliders=[dict(
                steps=slider_steps,
                transition=dict(duration=0),
                x=0.05, y=0, len=0.9,
                currentvalue=dict(prefix='t = ', font=dict(size=14)),
            )],
        ),
        frames=frames,
    )


def _timeseries_figure(T_pred, T_ref, t_vals, nx):
    """Build T(t) curves at the first, middle and last x samples (PINN solid, FDM dashed)."""
    # Probes are grid indices; for even nx, nx // 2 is the sample just right of centre.
    points = [
        (0, 'x=0', 'blue'),
        (nx // 2, 'x=L/2', 'green'),
        (nx - 1, 'x=L', 'red'),
    ]
    fig = go.Figure()
    for idx, label, color in points:
        for field, source, dash in ((T_pred, 'PINN', 'solid'), (T_ref, 'FDM', 'dash')):
            fig.add_trace(go.Scatter(
                x=t_vals, y=field[idx, :], mode='lines',
                line=dict(color=color, width=2, dash=dash),
                name=f'{source} {label}',
            ))
    # Vertical dashed marker at config.T_STEP, labelled "Heat ON".
    fig.add_vline(
        x=config.T_STEP,
        line=dict(color='red', dash='dash', width=1.5),
        annotation_text='Heat ON',
        annotation_position='top right',
    )
    fig.update_layout(
        title='Heat Equation PINN vs FDM — Time Series at Fixed Points',
        xaxis_title='t',
        yaxis_title='T(x,t)',
        legend=dict(x=0.01, y=0.99),
        height=500,
    )
    return fig


def visualize_heat_field(T_pred, x_vals, t_vals, T_fdm):
    """Write interactive HTML comparisons of a PINN temperature field against an FDM reference.

    Three files are written to ``<BASE_DIR>/results/pinn/<YYYYmmdd_HHMMSS>/``:
    ``heatmap.html`` (side-by-side T(x,t) heatmaps), ``animation.html``
    (T(x) profiles animated over t) and ``comparison.html`` (T(t) at the
    first, middle and last x samples). Each saved path is printed.

    Args:
        T_pred: PINN prediction, shape ``(len(x_vals), len(t_vals))``.
        x_vals: spatial grid of the prediction (non-empty).
        t_vals: time grid of the prediction (non-empty).
        T_fdm: FDM reference, 2-D ``(NX_fdm, NT_fdm)``; resampled onto the
            prediction grid by nearest evenly spaced index.

    Raises:
        ValueError: if the grids are empty or the field shapes are inconsistent.
    """
    T_pred = np.asarray(T_pred)
    nx, nt = len(x_vals), len(t_vals)
    # Validate before creating the output directory so bad input leaves no empty folder behind.
    if nx == 0 or nt == 0:
        raise ValueError('x_vals and t_vals must be non-empty')
    if T_pred.shape != (nx, nt):
        raise ValueError(
            f'T_pred must have shape (len(x_vals), len(t_vals)) = {(nx, nt)}, got {T_pred.shape}'
        )
    T_ref = _resample_reference(T_fdm, nx, nt)
    out_dir = _make_output_dir()

    # One value range for both fields so PINN and FDM are directly comparable.
    zmin = float(min(np.min(T_pred), np.min(T_ref)))
    zmax = float(max(np.max(T_pred), np.max(T_ref)))

    def save(fig, filename, label):
        path = os.path.join(out_dir, filename)
        fig.write_html(path)
        print(f"{label} saved → {path}")

    # Build and write one figure at a time (same order as before) to keep peak memory low.
    save(_heatmap_figure(T_pred, T_ref, x_vals, t_vals, zmin, zmax), 'heatmap.html', 'Heatmap')
    save(_profile_animation_figure(T_pred, T_ref, x_vals, t_vals, zmin, zmax),
         'animation.html', 'Animation')
    save(_timeseries_figure(T_pred, T_ref, t_vals, nx), 'comparison.html', 'Comparison')