e868c47190
- forward(): divide t per T_END prima di passarlo alla rete, evita saturazione di Tanh per t∈[0,10] e migliora la sensibilità temporale del modello - _pde_scale: include il picco gaussiano della sorgente come denominatore; con GAUSS_SIGMA=0.01 il picco (~60 °C/s) supera T_char/T_END (15), rendendo la loss PDE non normalizzata senza questa correzione - PATIENCE 100→500, SCHED_PATIENCE 30→150: il training ha ora spazio per convergere prima che l'early stopping o lo scheduler blocchino l'ottimizzatore Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
73 lines
3.5 KiB
Python
73 lines
3.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import config
|
|
|
|
|
|
class HeatPINN(nn.Module):
    """Fully connected Tanh network predicting temperature from ``(x, t)``.

    The architecture is read from ``config``: ``N_HIDDEN_LAYERS`` hidden
    ``Linear`` + ``Tanh`` stages of width ``HIDDEN_SIZE``, followed by a
    single-output ``Linear`` layer.
    """

    def __init__(self):
        """Build the MLP stored in ``self.net``."""
        super().__init__()
        width = config.HIDDEN_SIZE

        # Input width of every hidden stage: 2 for the first one (x and t),
        # then ``width`` for each of the remaining ones.
        stage_inputs = [2] + [width] * (config.N_HIDDEN_LAYERS - 1)

        modules = []
        for n_in in stage_inputs:
            modules.extend((nn.Linear(n_in, width), nn.Tanh()))
        modules.append(nn.Linear(width, 1))

        self.net = nn.Sequential(*modules)

    def forward(self, xt):
        """Return the predicted temperature for each ``(x, t)`` row of ``xt``.

        Column 0 of ``xt`` is taken as ``x`` and the remaining column(s) as
        ``t``. The raw network output is scaled by ``Q_VAL * L / K`` and
        shifted by ``T_AMB``.
        """
        position = xt[:, :1]
        # Rescale time by T_END before it reaches the network, so that large
        # raw t values (up to T_END) do not saturate the Tanh activations.
        scaled_time = xt[:, 1:] / config.T_END
        features = torch.cat([position, scaled_time], dim=1)

        amplitude = config.Q_VAL * config.L / config.K
        return config.T_AMB + amplitude * self.net(features)
|
|
|
|
|
|
# Loss normalization scales, computed once at import time because they depend
# only on constants from ``config``.

# Temperature scale Q*L/K (about 150 °C according to the original author's
# note; the actual value depends on config).
_T_char = config.Q_VAL * config.L / config.K

# Squared gradient scale (Q/K)^2 (about 22500 per the original note).
_grad_char = (config.Q_VAL / config.K) ** 2

# Peak value of the Gaussian source term, alpha * Q / (K * sigma * sqrt(2*pi)).
# With a small sigma this peak dominates the dT/dt scale.
_src_peak_denominator = config.K * config.GAUSS_SIGMA * (2 * torch.pi) ** 0.5
_src_peak = config.ALPHA * config.Q_VAL / _src_peak_denominator

# The PDE scale must cover both the dT/dt magnitude and the source peak, so
# take the larger of the two squared scales; the epsilon guards against a
# zero denominator.
_dTdt_scale_sq = (_T_char / config.T_END) ** 2
_pde_scale = max(_dTdt_scale_sq, _src_peak ** 2) + 1e-8
|
|
|
|
|
|
def _robin_bc_loss(model, x_boundary, t_bc):
    """Return the normalized mean-squared Robin residual at one boundary.

    The residual is ``dT/dx + (H_CONV / K) * (T - T_AMB)`` evaluated at
    ``x = x_boundary`` for every time in ``t_bc``. Its mean square is divided
    by the module-level gradient scale ``_grad_char``.

    Args:
        model: Callable that receives an ``(N, 2)`` tensor of ``(x, t)`` rows
            and returns the predicted temperatures.
        x_boundary: Scalar x-coordinate of the boundary.
        t_bc: 1-D tensor of boundary collocation times (detached before use).

    Returns:
        Scalar tensor holding the normalized boundary loss.
    """
    t_b = t_bc.detach()
    # Create x with the same dtype/device as the times so the stacked network
    # input is homogeneous (mirrors how the source term follows t_f's dtype).
    x_b = torch.full(
        (t_b.shape[0],), x_boundary, device=t_b.device, dtype=t_b.dtype
    ).requires_grad_(True)

    T_b = model(torch.stack([x_b, t_b], dim=1))
    # create_graph=True keeps the derivative differentiable, so the loss can
    # be backpropagated to the model parameters.
    dT_dx_b = torch.autograd.grad(T_b.sum(), x_b, create_graph=True)[0]

    residual = dT_dx_b + (config.H_CONV / config.K) * (T_b.squeeze() - config.T_AMB)
    return (residual ** 2).mean() / _grad_char


def heat_pinn_loss(model, x_f, t_f, x_ic, t_bc,
                   w_pde=None, w_ic=None, w_bc=None):
    """Compute the weighted PINN loss for the 1-D heat equation with a source.

    Three normalized terms are combined:

    * PDE residual ``dT/dt - ALPHA * d2T/dx2 - source(x, t)`` at the
      collocation points ``(x_f, t_f)``, divided by ``_pde_scale``.
    * Initial condition ``T(x, 0) = T0`` at ``x_ic``, divided by
      ``_T_char ** 2``.
    * Robin boundary residuals at ``x = 0`` and ``x = L`` for the times
      ``t_bc``, each divided by ``_grad_char``.

    Args:
        model: Callable that receives an ``(N, 2)`` tensor of ``(x, t)`` rows
            and returns the predicted temperatures.
        x_f: 1-D tensor of interior collocation x-coordinates.
        t_f: 1-D tensor of interior collocation times, same length as ``x_f``.
        x_ic: 1-D tensor of x-coordinates for the initial condition.
        t_bc: 1-D tensor of times for the boundary conditions.
        w_pde: Weight of the PDE term; defaults to ``config.W_PDE``.
        w_ic: Weight of the IC term; defaults to ``config.W_IC``.
        w_bc: Weight of the BC term; defaults to ``config.W_BC``.

    Returns:
        Tuple ``(total, L_pde, L_ic, L_bc)`` of scalar tensors, where
        ``total = w_pde * L_pde + w_ic * L_ic + w_bc * L_bc``.
    """
    if w_pde is None:
        w_pde = config.W_PDE
    if w_ic is None:
        w_ic = config.W_IC
    if w_bc is None:
        w_bc = config.W_BC

    # --- PDE residual: dT/dt - alpha * d2T/dx2 - source(x, t) = 0 ----------
    # Detach and re-enable grad so derivatives are taken with respect to
    # fresh leaf tensors, independent of any upstream graph.
    x_f = x_f.detach().requires_grad_(True)
    t_f = t_f.detach().requires_grad_(True)
    T_f = model(torch.stack([x_f, t_f], dim=1))
    dT_dt, dT_dx = torch.autograd.grad(T_f.sum(), [t_f, x_f], create_graph=True)
    d2T_dx2 = torch.autograd.grad(dT_dx.sum(), x_f, create_graph=True)[0]

    # Step activation in time: the source is off before T_STEP, Q_VAL after.
    Q_t_f = torch.where(
        t_f >= config.T_STEP,
        torch.tensor(config.Q_VAL, device=t_f.device, dtype=t_f.dtype),
        torch.tensor(0.0, device=t_f.device, dtype=t_f.dtype),
    )
    # Unit-area Gaussian in space, centred on X_SRC.
    sigma = config.GAUSS_SIGMA
    gauss = torch.exp(-0.5 * ((x_f - config.X_SRC) / sigma) ** 2) / (sigma * (2 * torch.pi) ** 0.5)
    source_term = (config.ALPHA / config.K) * Q_t_f * gauss

    L_pde = ((dT_dt - config.ALPHA * d2T_dx2 - source_term) ** 2).mean() / _pde_scale

    # --- Initial condition: T(x, 0) = T0, normalized by T_char^2 -----------
    T_ic_pred = model(torch.stack([x_ic, torch.zeros_like(x_ic)], dim=1))
    T_ic_true = torch.full_like(T_ic_pred, config.T0)
    L_ic = ((T_ic_pred - T_ic_true) ** 2).mean() / (_T_char ** 2 + 1e-8)

    # --- Robin boundary conditions at x = 0 and x = L ----------------------
    # NOTE(review): the same residual form (plus sign on the convective term)
    # is applied at both ends. For convective cooling with an outward normal
    # the x = 0 condition would usually carry the opposite sign — confirm
    # against the problem statement / reference solver before changing it.
    L_bc_left = _robin_bc_loss(model, 0.0, t_bc)
    L_bc_right = _robin_bc_loss(model, config.L, t_bc)
    L_bc = L_bc_left + L_bc_right

    total = w_pde * L_pde + w_ic * L_ic + w_bc * L_bc
    return total, L_pde, L_ic, L_bc
|