Files
davide 9e77deffd5 PINN: risolve problemi minori — sigma in config, scale precompilate, closure fuori loop
- config.py: aggiunge GAUSS_SIGMA = 0.02 nella sezione parametri fisici
- model.py: T_char, grad_char, pde_scale diventano costanti di modulo (_T_char,
  _grad_char, _pde_scale) calcolate una sola volta all'import
- engine.py: closure L-BFGS definita una volta sola fuori dal loop

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-14 14:23:33 +02:00

197 lines
6.8 KiB
Python

import os
import random
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import config
from model import HeatPINN, heat_pinn_loss
from visualizer import visualize_heat_field
# Absolute path of the directory that contains this module; every path below
# is anchored here rather than on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Sub-directory of BASE_DIR that holds saved checkpoints.
MODELS_DIR = os.path.join(BASE_DIR, 'models')
# Checkpoint file written by train_model and read back by _load_model.
MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'best_heat_pinn_model.pth')
def set_seed(seed=42):
    """Seed the random number generators used by this module.

    Calls ``random.seed``, ``np.random.seed`` and ``torch.manual_seed`` with
    the same value and, when CUDA is available, also
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _get_device():
if torch.cuda.is_available():
try:
t = torch.zeros(2, 2).cuda()
torch.mm(t, t)
return torch.device('cuda')
except RuntimeError:
pass
if torch.backends.mps.is_available():
return torch.device('mps')
return torch.device('cpu')
def prepare_data(N_f=None, N_ic=None, N_bc=None):
    """Sample the collocation, initial-condition and boundary-condition points.

    The generators are re-seeded with 42 on every call, so repeated calls with
    the same sizes on the same device draw the same points.

    Args:
        N_f: Number of uniformly sampled collocation points; ``config.N_F``
            when None. Two further batches of ``N_f // 4`` points each are
            appended (see below).
        N_ic: Number of initial-condition x samples; ``config.N_IC`` when None.
        N_bc: Number of boundary-condition t samples; ``config.N_BC`` when None.

    Returns:
        dict with keys ``'device'``, ``'x_f'``, ``'t_f'``, ``'x_ic'`` and
        ``'t_bc'``; all tensors are 1-D and live on the selected device.
    """
    N_f = config.N_F if N_f is None else N_f
    N_ic = config.N_IC if N_ic is None else N_ic
    N_bc = config.N_BC if N_bc is None else N_bc

    set_seed(42)
    device = _get_device()

    def unit_samples(count):
        # Uniform samples in [0, 1) created directly on the target device.
        return torch.rand(count, device=device)

    # The order of the draws below fixes the sampled values; keep it stable.

    # Uniform collocation points over [0, L) x [0, T_END).
    x_uniform = unit_samples(N_f) * config.L
    t_uniform = unit_samples(N_f) * config.T_END

    n_extra = N_f // 4

    # Extra points in a band of width 0.1 * L centred on X_SRC (steep
    # gradient at the source), clamped to the domain, at any time.
    x_src_band = config.X_SRC + (unit_samples(n_extra) - 0.5) * config.L * 0.1
    x_src_band = x_src_band.clamp(0, config.L)
    t_src_band = unit_samples(n_extra) * config.T_END

    # Extra points at any x, in a band of width 0.1 centred on T_STEP
    # (flux step), clamped to [0, T_END].
    x_step_band = unit_samples(n_extra) * config.L
    t_step_band = config.T_STEP + (unit_samples(n_extra) - 0.5) * 0.1
    t_step_band = t_step_band.clamp(0, config.T_END)

    x_ic = unit_samples(N_ic) * config.L
    t_bc = unit_samples(N_bc) * config.T_END

    return {
        'device': device,
        'x_f': torch.cat([x_uniform, x_src_band, x_step_band]),
        't_f': torch.cat([t_uniform, t_src_band, t_step_band]),
        'x_ic': x_ic,
        't_bc': t_bc,
    }
def train_model(data, epochs=None, patience=None):
    """Train a HeatPINN with Adam, then fine-tune it with L-BFGS.

    Phase 1 runs full-batch Adam with a ReduceLROnPlateau scheduler and early
    stopping. Phase 2 reloads the checkpoint from MODEL_SAVE_PATH and runs
    ``config.LBFGS_STEPS`` L-BFGS steps. The checkpoint is overwritten
    whenever the tracked loss improves.

    Args:
        data: dict providing ``'device'``, ``'x_f'``, ``'t_f'``, ``'x_ic'``
            and ``'t_bc'`` (the keys read here).
        epochs: Maximum number of Adam epochs; ``config.EPOCHS`` when None.
        patience: Number of consecutive non-improving Adam epochs after which
            the Adam phase stops; ``config.PATIENCE`` when None.

    Returns:
        None. Creates MODELS_DIR if needed, writes MODEL_SAVE_PATH and prints
        progress to stdout.
    """
    if epochs is None:
        epochs = config.EPOCHS
    if patience is None:
        patience = config.PATIENCE

    device = data['device']
    model = HeatPINN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.LR_ADAM)
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                  factor=config.SCHED_FACTOR,
                                  patience=config.SCHED_PATIENCE,
                                  min_lr=config.SCHED_MIN_LR)
    os.makedirs(MODELS_DIR, exist_ok=True)
    best_loss = float('inf')
    patience_counter = 0

    print(f"\n--- Heat PINN Training (Adam) on {device} ---")
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss, L_pde, L_ic, L_bc = heat_pinn_loss(
            model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc']
        )
        loss.backward()
        optimizer.step()

        # Read the scalar once: each .item() is a separate device-to-host
        # transfer, and the value is needed by the scheduler, the
        # best-checkpoint test and the log line.
        loss_val = loss.item()
        scheduler.step(loss_val)

        # An improvement must beat the best loss by more than 1e-7 to reset
        # the early-stopping counter and refresh the checkpoint.
        if loss_val < best_loss - 1e-7:
            best_loss = loss_val
            patience_counter = 0
            torch.save({'state_dict': model.state_dict()}, MODEL_SAVE_PATH)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

        if (epoch + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{epochs}] | Loss: {loss_val:.6f} "
                f"| PDE: {L_pde.item():.6f} | IC: {L_ic.item():.6f} | BC: {L_bc.item():.6f}"
            )

    # L-BFGS fine-tuning phase (standard PINN practice for convergence to better minima)
    print("\n--- L-BFGS fine-tuning ---")
    # NOTE(review): if the Adam phase never saved (epochs == 0, or a loss that
    # never compares below inf such as NaN), this loads whatever checkpoint is
    # already on disk, or raises if there is none.
    ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
    model.load_state_dict(ckpt['state_dict'])
    lbfgs = optim.LBFGS(model.parameters(), lr=config.LR_LBFGS, max_iter=50,
                        history_size=50, tolerance_grad=1e-7,
                        line_search_fn='strong_wolfe')

    # Scalars from the most recent closure evaluation, kept for checkpointing
    # and logging after each lbfgs.step call.
    _last = {}

    def closure():
        # Re-evaluates loss and gradients; called by lbfgs.step.
        lbfgs.zero_grad()
        loss, L_pde, L_ic, L_bc = heat_pinn_loss(
            model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc']
        )
        loss.backward()
        _last['loss'] = loss.item()
        _last['pde'] = L_pde.item()
        _last['ic'] = L_ic.item()
        _last['bc'] = L_bc.item()
        return loss

    for step in range(config.LBFGS_STEPS):
        lbfgs.step(closure)
        if _last['loss'] < best_loss:
            best_loss = _last['loss']
            torch.save({'state_dict': model.state_dict()}, MODEL_SAVE_PATH)
        if (step + 1) % 5 == 0:
            print(f"L-BFGS step {step+1}/{config.LBFGS_STEPS} | Loss: {_last['loss']:.6f} "
                  f"| PDE: {_last['pde']:.6f} | IC: {_last['ic']:.6f} | BC: {_last['bc']:.6f}")

    print("Training complete! Model saved.")
def _load_model(device):
    """Load the checkpoint at MODEL_SAVE_PATH into a fresh HeatPINN.

    Args:
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model in eval mode, or None (after printing an error message)
        when the checkpoint file does not exist.
    """
    if os.path.exists(MODEL_SAVE_PATH):
        checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
        net = HeatPINN().to(device)
        net.load_state_dict(checkpoint['state_dict'])
        net.eval()
        return net

    print("Error: model not found. Train the model first.")
    return None
def _predict_grid(model, device, nx=100, nt=100):
    """Evaluate the model on a regular (x, t) grid.

    Args:
        model: Callable applied to an ``(nx * nt, 2)`` tensor whose columns
            are x and t.
        device: Device on which the grid tensors are created.
        nx: Number of grid points along x, spanning ``[0, config.L]``.
        nt: Number of grid points along t, spanning ``[0, config.T_END]``.

    Returns:
        Tuple ``(T_pred, x_vals, t_vals)`` of NumPy arrays. ``T_pred`` has
        shape ``(nx, nt)`` with x along the first axis; the other two are the
        1-D grid coordinates.
    """
    x_axis = torch.linspace(0, config.L, nx, device=device)
    t_axis = torch.linspace(0, config.T_END, nt, device=device)

    # 'ij' indexing keeps x on axis 0 and t on axis 1.
    grid_x, grid_t = torch.meshgrid(x_axis, t_axis, indexing='ij')
    columns = [grid_x.reshape(-1), grid_t.reshape(-1)]
    points = torch.stack(columns, dim=1)

    with torch.no_grad():
        flat_prediction = model(points)
        field = flat_prediction.reshape(nx, nt).cpu().numpy()

    return field, x_axis.cpu().numpy(), t_axis.cpu().numpy()
def evaluate_model(data):
    """Print error metrics of the saved model against the FDM solution.

    Loads the saved model, predicts on the default grid, samples the FDM
    field down to the same shape and prints the relative L2 error and the
    maximum absolute error. Returns early, printing no metrics, when the
    model cannot be loaded.

    Args:
        data: dict providing ``'device'``.

    Returns:
        None.
    """
    device = data['device']
    model = _load_model(device)
    if model is None:
        return
    T_pred, _, _ = _predict_grid(model, device)

    # FDM reference; imported here rather than at module level.
    from fdm.solver import solve as fdm_solve
    T_fdm, _, _ = fdm_solve()  # assumes the field comes first, shaped (space, time) -- TODO confirm

    # Pick evenly spaced (integer-truncated) indices along each FDM axis so
    # the reference has the same shape as the PINN prediction.
    n_x, n_t = T_pred.shape
    row_idx = np.linspace(0, T_fdm.shape[0] - 1, n_x, dtype=int)
    col_idx = np.linspace(0, T_fdm.shape[1] - 1, n_t, dtype=int)
    T_ref = T_fdm[np.ix_(row_idx, col_idx)]

    residual = T_pred - T_ref
    rms_residual = np.sqrt(np.mean(residual ** 2))
    rms_reference = np.sqrt(np.mean(T_ref ** 2))
    l2_rel = rms_residual / rms_reference
    max_err = np.max(np.abs(residual))

    print(f"\n--- Evaluation vs FDM ---")
    print(f"Relative L2 error vs FDM: {l2_rel:.6f}")
    print(f"Max absolute error: {max_err:.6f} °C\n")
def generate_visualization(data):
    """Plot the saved model's prediction together with the FDM solution.

    Returns early, without plotting, when the model cannot be loaded.

    Args:
        data: dict providing ``'device'``.

    Returns:
        None.
    """
    device = data['device']
    net = _load_model(device)
    if net is None:
        return

    pinn_field, x_axis, t_axis = _predict_grid(net, device)

    # FDM reference; imported here rather than at module level.
    from fdm.solver import solve as fdm_solve
    fdm_field, _, _ = fdm_solve()

    visualize_heat_field(pinn_field, x_axis, t_axis, fdm_field)