import os

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import config

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ANIMATIONS_DIR = os.path.join(BASE_DIR, 'animations')


def visualize_heat_field(T_pred, x_vals, t_vals, T_fdm):
    """Write PINN-vs-FDM comparison plots for the heat equation as HTML files.

    Three figures are written to ``ANIMATIONS_DIR`` (created if missing), each
    under a fresh numbered filename chosen by ``_next_path``:

    * a side-by-side heatmap of the PINN prediction and the FDM reference,
    * an animation of the profile T(x) stepping through ``t_vals``,
    * time series of T at x=0, x=L/2 (nearest grid point) and x=L.

    Args:
        T_pred: PINN prediction indexed ``[x, t]``, shape
            ``(len(x_vals), len(t_vals))``.
        x_vals: Spatial grid points (flattened to 1-D).
        t_vals: Time points of the PINN prediction (flattened to 1-D,
            must be non-empty).
        T_fdm: FDM reference indexed ``[x, t]``, shape ``(len(x_vals), NT)``.
            Its time axis is resampled to ``len(t_vals)`` evenly spaced
            columns. NOTE(review): this assumes the FDM time grid spans the
            same interval as ``t_vals`` -- confirm against the caller.

    Raises:
        ValueError: if ``t_vals`` is empty or the array shapes disagree.
    """
    T_pred = np.asarray(T_pred)
    T_fdm = np.asarray(T_fdm)
    x_vals = np.asarray(x_vals).ravel()
    t_vals = np.asarray(t_vals).ravel()

    nx, nt_pred = len(x_vals), len(t_vals)
    # Fail early with a clear message instead of an IndexError deep inside
    # the animation code (frames[0]) or a silently mis-shaped plot.
    if nt_pred == 0:
        raise ValueError("t_vals must contain at least one time point")
    if T_pred.shape != (nx, nt_pred):
        raise ValueError(
            f"T_pred has shape {T_pred.shape}, expected {(nx, nt_pred)}"
        )
    if T_fdm.ndim != 2 or T_fdm.shape[0] != nx or T_fdm.shape[1] == 0:
        raise ValueError(
            f"T_fdm has shape {T_fdm.shape}, expected ({nx}, NT) with NT > 0"
        )

    os.makedirs(ANIMATIONS_DIR, exist_ok=True)

    # Bring T_fdm from (NX, NT) to (NX, nt_pred) so it lines up with T_pred.
    T_fdm_ds = _resample_time(T_fdm, nt_pred)

    # Common colour / y-axis limits so PINN and FDM are directly comparable.
    zmin = float(min(np.min(T_pred), np.min(T_fdm_ds)))
    zmax = float(max(np.max(T_pred), np.max(T_fdm_ds)))

    _save_heatmap(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax)
    _save_animation(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax)
    _save_time_series(T_pred, T_fdm_ds, x_vals, t_vals)


def _resample_time(T_fdm, nt):
    """Return ``nt`` evenly spaced time columns of ``T_fdm`` (nearest column)."""
    # Rounding (instead of linspace(..., dtype=int), which truncates) keeps
    # each selected column within half an FDM step of its ideal position.
    t_indices = np.rint(np.linspace(0, T_fdm.shape[1] - 1, nt)).astype(int)
    return T_fdm[:, t_indices]


def _save_heatmap(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax):
    """Write the static side-by-side PINN/FDM heatmap."""
    fig_map = make_subplots(
        rows=1, cols=2,
        subplot_titles=["PINN Prediction T(x,t)", "FDM Reference T(x,t)"],
        shared_yaxes=True,
    )
    colorscale = 'Hot_r'
    panels = [(T_pred, 0.46), (T_fdm_ds, 1.01)]  # (field, colorbar x-position)
    for col, (field, cbar_x) in enumerate(panels, start=1):
        fig_map.add_trace(go.Heatmap(
            # Heatmap rows are the y-axis (t), so transpose [x, t] -> [t, x].
            z=field.T,
            x=x_vals,
            y=t_vals,
            colorscale=colorscale,
            zmin=zmin,
            zmax=zmax,
            colorbar=dict(x=cbar_x, title='T [°C]'),
        ), row=1, col=col)
    fig_map.update_xaxes(title_text='x')
    fig_map.update_yaxes(title_text='t', row=1, col=1)
    fig_map.update_layout(title_text='Heat Equation PINN vs FDM', height=450)

    map_path = _next_path('heatmap', '.html')
    fig_map.write_html(map_path)
    print(f"Heatmap saved → {map_path}")


def _save_animation(T_pred, T_fdm_ds, x_vals, t_vals, zmin, zmax):
    """Write the animated T(x) profile, one frame per entry of ``t_vals``."""

    def profile_traces(i):
        # PINN (solid) and FDM (dashed) profiles at time index i.
        return [
            go.Scatter(x=x_vals, y=T_pred[:, i], mode='lines',
                       line=dict(color='royalblue', width=2), name='PINN'),
            go.Scatter(x=x_vals, y=T_fdm_ds[:, i], mode='lines',
                       line=dict(color='firebrick', width=2, dash='dash'),
                       name='FDM'),
        ]

    def frame_title(i):
        return f'Heat Equation PINN vs FDM | t = {t_vals[i]:.3f}'

    n_frames = len(t_vals)
    frames = [
        go.Frame(
            data=profile_traces(i),
            name=str(i),  # slider steps refer to frames by this name
            layout=go.Layout(title_text=frame_title(i)),
        )
        for i in range(n_frames)
    ]

    play_pause = dict(
        type='buttons', showactive=False,
        y=1.15, x=0.5, xanchor='center',
        buttons=[
            dict(label='▶ Play', method='animate',
                 args=[None, dict(frame=dict(duration=40, redraw=False),
                                  fromcurrent=True, mode='immediate')]),
            dict(label='⏸ Pause', method='animate',
                 args=[[None], dict(frame=dict(duration=0, redraw=False),
                                    mode='immediate')]),
        ],
    )
    slider = dict(
        steps=[
            dict(method='animate',
                 args=[[str(i)], dict(mode='immediate',
                                      frame=dict(duration=0, redraw=False))],
                 label=f'{t_vals[i]:.2f}')
            for i in range(n_frames)
        ],
        transition=dict(duration=0),
        x=0.05, y=0, len=0.9,
        currentvalue=dict(prefix='t = ', font=dict(size=14)),
    )

    fig_anim = go.Figure(
        data=frames[0].data,
        layout=go.Layout(
            title=frame_title(0),
            xaxis=dict(title='x [m]', range=[-0.02, config.L + 0.02]),
            yaxis=dict(title='T [°C]', range=[zmin - 1, zmax + 1]),
            legend=dict(x=0.75, y=0.95),
            updatemenus=[play_pause],
            sliders=[slider],
        ),
        frames=frames,
    )
    anim_path = _next_path('heat_animation', '.html')
    fig_anim.write_html(anim_path)
    print(f"Animation saved → {anim_path}")


def _save_time_series(T_pred, T_fdm_ds, x_vals, t_vals):
    """Write T(t) at x=0, x=L/2 and x=L for both PINN and FDM."""
    nx = len(x_vals)
    # nx // 2 is the exact midpoint only for an odd number of uniformly
    # spaced points; the grid point nearest L/2 matches the label in general.
    idx_xmid = int(np.argmin(np.abs(x_vals - 0.5 * config.L)))
    points = [
        (0, 'x=0', 'blue'),
        (idx_xmid, 'x=L/2', 'green'),
        (nx - 1, 'x=L', 'red'),
    ]

    fig_ts = go.Figure()
    for idx, label, color in points:
        fig_ts.add_trace(go.Scatter(
            x=t_vals, y=T_pred[idx, :], mode='lines',
            line=dict(color=color, width=2, dash='solid'),
            name=f'PINN {label}',
        ))
        fig_ts.add_trace(go.Scatter(
            x=t_vals, y=T_fdm_ds[idx, :], mode='lines',
            line=dict(color=color, width=2, dash='dash'),
            name=f'FDM {label}',
        ))

    # Vertical dashed line at T_STEP ("Heat ON").
    fig_ts.add_vline(
        x=config.T_STEP,
        line=dict(color='red', dash='dash', width=1.5),
        annotation_text='Heat ON',
        annotation_position='top right',
    )
    fig_ts.update_layout(
        title='Heat Equation PINN vs FDM — Time Series at Fixed Points',
        xaxis_title='t',
        yaxis_title='T(x,t)',
        legend=dict(x=0.01, y=0.99),
        height=500,
    )
    comparison_path = _next_path('comparison', '.html')
    fig_ts.write_html(comparison_path)
    print(f"Time-series saved → {comparison_path}")


def _next_path(prefix, ext):
    """Return the first unused ``ANIMATIONS_DIR/<prefix>_NNN<ext>`` path.

    Numbering starts at 001. The existence check is not atomic, so two
    concurrent writers could pick the same name.
    """
    i = 1
    while True:
        path = os.path.join(ANIMATIONS_DIR, f'{prefix}_{i:03d}{ext}')
        if not os.path.exists(path):
            return path
        i += 1