"""Training, evaluation, and visualization entry points for the HeatPINN model.

The module trains ``HeatPINN`` with Adam followed by L-BFGS fine-tuning,
stores the best weights under ``models/`` next to this file, and compares the
trained network against the finite-difference solver in ``fdm.solver``.
"""

import os
import random

import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import config
from model import HeatPINN, heat_pinn_loss
from visualizer import visualize_heat_field

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, 'models')
MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'best_heat_pinn_model.pth')


def set_seed(seed=42):
    """Seed the Python, NumPy, and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random``, and
            ``torch`` (plus every CUDA device when CUDA is available).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _get_device():
    """Pick the compute device: CUDA if usable, else MPS, else CPU.

    CUDA is only selected after a tiny matrix multiply succeeds on it, so a
    driver/runtime that reports "available" but cannot execute kernels falls
    through to the next backend instead of failing later during training.

    Returns:
        A ``torch.device`` for ``'cuda'``, ``'mps'``, or ``'cpu'``.
    """
    if torch.cuda.is_available():
        try:
            probe = torch.zeros(2, 2).cuda()
            torch.mm(probe, probe)
            return torch.device('cuda')
        except RuntimeError as err:
            # Best-effort fallback, but say why CUDA was skipped.
            print(f"CUDA is reported as available but failed a test op ({err}); "
                  f"falling back to another device.")

    # Older PyTorch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def prepare_data(N_f=None, N_ic=None, N_bc=None, seed=42):
    """Sample the collocation, initial-condition, and boundary-condition points.

    Args:
        N_f: Number of uniformly sampled collocation points; defaults to
            ``config.N_F``. Two extra clusters of ``N_f // 4`` points each are
            appended, so the returned ``x_f``/``t_f`` hold
            ``N_f + 2 * (N_f // 4)`` samples.
        N_ic: Number of initial-condition x samples; defaults to ``config.N_IC``.
        N_bc: Number of boundary-condition t samples; defaults to ``config.N_BC``.
        seed: Seed passed to ``set_seed`` before sampling (default 42, the
            value that was previously hard-coded).

    Returns:
        A dict with keys ``'device'``, ``'x_f'``, ``'t_f'``, ``'x_ic'``, and
        ``'t_bc'``; all tensors are 1-D and live on the selected device.
    """
    if N_f is None:
        N_f = config.N_F
    if N_ic is None:
        N_ic = config.N_IC
    if N_bc is None:
        N_bc = config.N_BC

    set_seed(seed)
    device = _get_device()

    # NOTE: the order of the torch.rand calls below is kept fixed so that a
    # given seed always produces the same point set.

    # Uniform collocation points over [0, L] x [0, T_END].
    x_f = torch.rand(N_f, device=device) * config.L
    t_f = torch.rand(N_f, device=device) * config.T_END

    # Extra points near X_SRC (steep gradient at source) and t=T_STEP (flux step).
    n_extra = N_f // 4

    # Cluster in x: a band of width 0.1 * L centred on X_SRC, any t.
    x_near_src = config.X_SRC + (torch.rand(n_extra, device=device) - 0.5) * config.L * 0.1
    x_near_src = x_near_src.clamp(0, config.L)
    t_near_src = torch.rand(n_extra, device=device) * config.T_END

    # Cluster in t: a band of absolute width 0.1 centred on T_STEP, any x.
    x_step = torch.rand(n_extra, device=device) * config.L
    t_step = config.T_STEP + (torch.rand(n_extra, device=device) - 0.5) * 0.1
    t_step = t_step.clamp(0, config.T_END)

    x_f = torch.cat([x_f, x_near_src, x_step])
    t_f = torch.cat([t_f, t_near_src, t_step])

    x_ic = torch.rand(N_ic, device=device) * config.L
    t_bc = torch.rand(N_bc, device=device) * config.T_END

    return {'device': device, 'x_f': x_f, 't_f': t_f, 'x_ic': x_ic, 't_bc': t_bc}


def train_model(data, epochs=None, patience=None):
    """Train a fresh ``HeatPINN`` with Adam, then fine-tune it with L-BFGS.

    The best weights seen (lowest total loss) are written to
    ``MODEL_SAVE_PATH`` as ``{'state_dict': ...}``.

    Args:
        data: Dict as returned by ``prepare_data`` (keys ``'device'``,
            ``'x_f'``, ``'t_f'``, ``'x_ic'``, ``'t_bc'``).
        epochs: Maximum number of Adam epochs; defaults to ``config.EPOCHS``.
        patience: Number of consecutive non-improving Adam epochs tolerated
            before early stopping; defaults to ``config.PATIENCE``.

    Returns:
        None. Progress is printed and the checkpoint is written to disk.
    """
    if epochs is None:
        epochs = config.EPOCHS
    if patience is None:
        patience = config.PATIENCE

    device = data['device']
    model = HeatPINN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.LR_ADAM)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=config.SCHED_FACTOR,
                                  patience=config.SCHED_PATIENCE,
                                  min_lr=config.SCHED_MIN_LR)

    os.makedirs(MODELS_DIR, exist_ok=True)
    best_loss = float('inf')
    patience_counter = 0
    # Tracks whether THIS run has written MODEL_SAVE_PATH, so a stale file
    # from an earlier run is never mistaken for the current best model.
    checkpoint_written = False

    print(f"\n--- Heat PINN Training (Adam) on {device} ---")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss, L_pde, L_ic, L_bc = heat_pinn_loss(
            model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc']
        )
        loss.backward()
        # One host sync per epoch instead of one per use.
        loss_val = loss.item()

        improved = loss_val < best_loss - 1e-7
        if improved:
            best_loss = loss_val
            patience_counter = 0
            # Snapshot BEFORE optimizer.step(): the loss was measured on the
            # current weights, so these are the weights that achieved it.
            torch.save({'state_dict': model.state_dict()}, MODEL_SAVE_PATH)
            checkpoint_written = True
        else:
            patience_counter += 1

        optimizer.step()
        scheduler.step(loss_val)

        if not improved and patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

        if (epoch + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{epochs}] | Loss: {loss_val:.6f} "
                f"| PDE: {L_pde.item():.6f} | IC: {L_ic.item():.6f} | BC: {L_bc.item():.6f}"
            )

    # L-BFGS fine-tuning phase (standard PINN practice for convergence to better minima)
    print("\n--- L-BFGS fine-tuning ---")
    if checkpoint_written:
        ckpt = torch.load(MODEL_SAVE_PATH, map_location=device)
        model.load_state_dict(ckpt['state_dict'])
    else:
        # Nothing was saved (e.g. epochs == 0 or the loss was never finite);
        # do not load a checkpoint left over from a previous run.
        print("No checkpoint was written during the Adam phase; "
              "fine-tuning from the current in-memory weights.")

    lbfgs = optim.LBFGS(model.parameters(), lr=config.LR_LBFGS, max_iter=50,
                        history_size=50, tolerance_grad=1e-7,
                        line_search_fn='strong_wolfe')
    # Loss components recorded by the most recent closure evaluation.
    _last = {}

    def closure():
        """Re-evaluate the loss and gradients for L-BFGS and record the terms."""
        lbfgs.zero_grad()
        loss, L_pde, L_ic, L_bc = heat_pinn_loss(
            model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc']
        )
        loss.backward()
        _last['loss'] = loss.item()
        _last['pde'] = L_pde.item()
        _last['ic'] = L_ic.item()
        _last['bc'] = L_bc.item()
        return loss

    for step in range(config.LBFGS_STEPS):
        lbfgs.step(closure)
        if _last['loss'] < best_loss:
            best_loss = _last['loss']
            torch.save({'state_dict': model.state_dict()}, MODEL_SAVE_PATH)
            checkpoint_written = True
        if (step + 1) % 5 == 0:
            print(f"L-BFGS step {step+1}/{config.LBFGS_STEPS} | Loss: {_last['loss']:.6f} "
                  f"| PDE: {_last['pde']:.6f} | IC: {_last['ic']:.6f} | BC: {_last['bc']:.6f}")

    if checkpoint_written:
        print("Training complete! Model saved.")
    else:
        print("Training complete, but no checkpoint was written "
              "(the loss never improved on its initial value).")


def _load_model(device):
    """Load the saved ``HeatPINN`` weights onto ``device`` in eval mode.

    Args:
        device: ``torch.device`` the checkpoint is mapped to.

    Returns:
        The model in ``eval()`` mode, or ``None`` (after printing an error)
        when ``MODEL_SAVE_PATH`` does not exist.
    """
    if not os.path.exists(MODEL_SAVE_PATH):
        print("Error: model not found. Train the model first.")
        return None
    ckpt = torch.load(MODEL_SAVE_PATH, map_location=device)
    model = HeatPINN().to(device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model


def _predict_grid(model, device, nx=100, nt=100):
    """Evaluate ``model`` on a regular ``nx`` x ``nt`` grid over the domain.

    Args:
        model: Network called with an ``(nx * nt, 2)`` tensor whose columns
            are x and t.
        device: Device on which the grid is built.
        nx: Number of x samples in ``[0, config.L]``.
        nt: Number of t samples in ``[0, config.T_END]``.

    Returns:
        ``(T_pred, x_vals, t_vals)`` as NumPy arrays; ``T_pred`` has shape
        ``(nx, nt)`` with x along axis 0 (``indexing='ij'``).
    """
    x_vals = torch.linspace(0, config.L, nx, device=device)
    t_vals = torch.linspace(0, config.T_END, nt, device=device)
    xx, tt = torch.meshgrid(x_vals, t_vals, indexing='ij')
    inp = torch.stack([xx.reshape(-1), tt.reshape(-1)], dim=1)
    with torch.no_grad():
        T_pred = model(inp).reshape(nx, nt).cpu().numpy()
    return T_pred, x_vals.cpu().numpy(), t_vals.cpu().numpy()


def evaluate_model(data):
    """Print the PINN's relative L2 and max absolute error against the FDM solution.

    Args:
        data: Dict containing at least ``'device'``.

    Returns:
        None. Returns early (after ``_load_model`` prints an error) when no
        saved model is available.
    """
    model = _load_model(data['device'])
    if model is None:
        return

    T_pred, x_vals, t_vals = _predict_grid(model, data['device'])

    # FDM reference (imported lazily, as in the original code).
    from fdm.solver import solve as fdm_solve
    T_fdm, _, _ = fdm_solve()

    # Downsample the FDM grid to the PINN prediction grid by picking evenly
    # spaced indices along each axis.
    nx, nt = T_pred.shape
    x_indices = np.linspace(0, T_fdm.shape[0] - 1, nx, dtype=int)
    t_indices = np.linspace(0, T_fdm.shape[1] - 1, nt, dtype=int)
    T_fdm_ds = T_fdm[np.ix_(x_indices, t_indices)]

    diff = T_pred - T_fdm_ds
    l2_rel = np.sqrt(np.mean(diff ** 2)) / np.sqrt(np.mean(T_fdm_ds ** 2))
    max_err = np.max(np.abs(diff))

    print("\n--- Evaluation vs FDM ---")
    print(f"Relative L2 error vs FDM: {l2_rel:.6f}")
    print(f"Max absolute error: {max_err:.6f} °C\n")


def generate_visualization(data):
    """Plot the PINN prediction alongside the FDM reference.

    Args:
        data: Dict containing at least ``'device'``.

    Returns:
        None. Returns early when no saved model is available.
    """
    model = _load_model(data['device'])
    if model is None:
        return

    T_pred, x_vals, t_vals = _predict_grid(model, data['device'])

    from fdm.solver import solve as fdm_solve
    T_fdm, _, _ = fdm_solve()

    visualize_heat_field(T_pred, x_vals, t_vals, T_fdm)