# Listing metadata (from the file viewer): 187 lines, 6.2 KiB, Python
import os
|
|
import random
|
|
import numpy as np
|
|
import torch
|
|
import torch.optim as optim
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
|
import config
|
|
from model import HeatPINN, heat_pinn_loss
|
|
from visualizer import visualize_heat_field
|
|
|
|
# Filesystem layout, resolved relative to this file rather than the current
# working directory: checkpoints are kept in a ``models`` folder beside it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, 'models')
# Single checkpoint file that training overwrites and evaluation reads back.
MODEL_SAVE_PATH = os.path.join(MODELS_DIR, 'best_heat_pinn_model.pth')
|
|
|
|
|
|
def set_seed(seed=42):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    default generator. When CUDA is available, the generators of all CUDA
    devices are seeded too.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def _get_device():
    """Pick the best usable torch device: CUDA first, then Apple MPS, then CPU.

    CUDA is only selected after a tiny matrix multiply succeeds on the GPU, so
    an installation where ``torch.cuda.is_available()`` is true but kernels
    cannot actually run falls through to the next option instead of failing
    later during training.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        try:
            probe = torch.zeros(2, 2).cuda()
            torch.mm(probe, probe)
        except RuntimeError as err:
            # Still a best-effort fallback, but report why the GPU was rejected
            # instead of swallowing the error silently.
            print(f"CUDA is reported as available but failed a smoke test ({err}); "
                  "falling back to another device.")
        else:
            return torch.device('cuda')

    # Older PyTorch releases have no ``torch.backends.mps`` module at all;
    # look it up defensively so they reach the CPU fallback instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')
|
|
|
|
|
|
def prepare_data(N_f=4000, N_ic=400, N_bc=400):
    """Sample the training points used by the heat-equation PINN.

    All RNGs are re-seeded with 42 via ``set_seed`` before any point is drawn,
    and the device is chosen by ``_get_device``.

    Args:
        N_f: Number of uniformly distributed collocation points. Two extra
            refinement sets of ``N_f // 4`` points each are appended, so the
            returned ``x_f`` / ``t_f`` hold ``N_f + 2 * (N_f // 4)`` entries.
        N_ic: Number of x samples for the initial condition.
        N_bc: Number of t samples for the boundary condition.

    Returns:
        dict with keys ``'device'``, ``'x_f'``, ``'t_f'``, ``'x_ic'`` and
        ``'t_bc'``; the tensors are 1-D and live on ``'device'``.
    """
    set_seed(42)
    device = _get_device()

    def unit(count):
        # Uniform samples in [0, 1), created directly on the target device.
        return torch.rand(count, device=device)

    # NOTE: the order of the draws below is part of the behavior (it fixes
    # which random numbers each tensor receives) - do not reorder.

    # Base collocation set, uniform over [0, L] x [0, T_END].
    x_uniform = unit(N_f) * config.L
    t_uniform = unit(N_f) * config.T_END

    # Extra points near x=0 (steep Neumann gradient) and around t=T_STEP
    # (flux step).
    n_extra = N_f // 4
    x_left = unit(n_extra) * config.L * 0.1  # x in [0, 0.1 * L)
    t_left = unit(n_extra) * config.T_END
    x_jump = unit(n_extra) * config.L
    # t within +/-0.05 of T_STEP, then clipped back into [0, T_END].
    t_jump = (config.T_STEP + (unit(n_extra) - 0.5) * 0.1).clamp(0, config.T_END)

    return {
        'device': device,
        'x_f': torch.cat([x_uniform, x_left, x_jump]),
        't_f': torch.cat([t_uniform, t_left, t_jump]),
        'x_ic': unit(N_ic) * config.L,
        't_bc': unit(N_bc) * config.T_END,
    }
|
|
|
|
|
|
def train_model(data, epochs=5000, patience=100, lbfgs_steps=20):
    """Train a HeatPINN with Adam, then refine it with L-BFGS.

    Phase 1 runs Adam (lr=1e-3) with a ReduceLROnPlateau scheduler and early
    stopping. Phase 2 restarts from the best Adam checkpoint and runs
    ``lbfgs_steps`` outer L-BFGS steps. Whenever the total loss improves, the
    weights that produced that loss are written to ``MODEL_SAVE_PATH``.

    Args:
        data: Dict providing ``'device'``, ``'x_f'``, ``'t_f'``, ``'x_ic'``
            and ``'t_bc'`` (the keys read here and passed to
            ``heat_pinn_loss``).
        epochs: Maximum number of Adam epochs.
        patience: Number of consecutive non-improving Adam epochs after which
            the Adam phase stops early.
        lbfgs_steps: Number of outer L-BFGS steps (each runs up to 50 inner
            iterations). Defaults to the previously hard-coded 20.

    Returns:
        None. Side effects: creates ``MODELS_DIR`` if needed, writes the
        checkpoint file and prints progress to stdout.
    """
    device = data['device']
    model = HeatPINN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=30, min_lr=1e-6)

    os.makedirs(MODELS_DIR, exist_ok=True)
    best_loss = float('inf')
    patience_counter = 0
    # Tracks whether THIS run wrote the checkpoint, so that a stale file left
    # by an earlier run is never mistaken for the current best model.
    checkpoint_written = False

    def compute_losses():
        # Returns (total, pde, ic, bc) loss tensors for the current weights.
        return heat_pinn_loss(model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc'])

    def save_checkpoint():
        torch.save({'state_dict': model.state_dict()}, MODEL_SAVE_PATH)

    print(f"\n--- Heat PINN Training (Adam) on {device} ---")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss, L_pde, L_ic, L_bc = compute_losses()
        loss.backward()
        # Read the scalar once; every .item() call forces a device sync.
        loss_value = loss.item()

        improved = loss_value < best_loss - 1e-7
        if improved:
            # Save BEFORE optimizer.step(): `loss` was measured on the current
            # weights, and stepping first would store weights whose loss has
            # never been evaluated.
            best_loss = loss_value
            patience_counter = 0
            save_checkpoint()
            checkpoint_written = True
        else:
            patience_counter += 1

        optimizer.step()
        scheduler.step(loss_value)

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

        if (epoch + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{epochs}] | Loss: {loss_value:.6f} "
                f"| PDE: {L_pde.item():.6f} | IC: {L_ic.item():.6f} | BC: {L_bc.item():.6f}"
            )

    # L-BFGS fine-tuning phase (standard PINN practice for convergence to better minima)
    print("\n--- L-BFGS fine-tuning ---")
    if checkpoint_written:
        ckpt = torch.load(MODEL_SAVE_PATH, map_location=device)
        model.load_state_dict(ckpt['state_dict'])
    # Otherwise (epochs == 0, or the loss was never finite/improving) there is
    # no checkpoint from this run; continue from the current weights instead
    # of crashing on a missing file or loading an unrelated old one.

    lbfgs = optim.LBFGS(model.parameters(), lr=0.1, max_iter=50,
                        history_size=50, tolerance_grad=1e-7, line_search_fn='strong_wolfe')

    def closure():
        lbfgs.zero_grad()
        total, _, _, _ = compute_losses()
        total.backward()
        return total

    for step in range(lbfgs_steps):
        lbfgs.step(closure)

        # With a line search the closure can be evaluated at trial points, so
        # the value from its last call need not belong to the parameters that
        # step() finally keeps. Re-evaluate so the checkpoint decision and the
        # log describe the weights that are actually stored.
        loss, L_pde, L_ic, L_bc = compute_losses()
        loss_value = loss.item()

        if loss_value < best_loss:
            best_loss = loss_value
            save_checkpoint()
            checkpoint_written = True
        if (step + 1) % 5 == 0:
            print(f"L-BFGS step {step+1}/{lbfgs_steps} | Loss: {loss_value:.6f} "
                  f"| PDE: {L_pde.item():.6f} | IC: {L_ic.item():.6f} | BC: {L_bc.item():.6f}")

    if checkpoint_written:
        print("Training complete! Model saved.")
    else:
        print("Training finished without any improving loss; no model was saved.")
|
|
|
|
|
|
def _load_model(device):
    """Rebuild a HeatPINN from the checkpoint stored at ``MODEL_SAVE_PATH``.

    Args:
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model in eval mode, or ``None`` (after printing an error message)
        when the checkpoint file does not exist.
    """
    checkpoint_missing = not os.path.exists(MODEL_SAVE_PATH)
    if checkpoint_missing:
        print("Error: model not found. Train the model first.")
        return None

    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
    net = HeatPINN().to(device)
    # The checkpoint is a dict whose 'state_dict' entry holds the weights.
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()
    return net
|
|
|
|
|
|
def _predict_grid(model, device, nx=100, nt=100):
    """Evaluate the model on a regular grid over [0, L] x [0, T_END].

    Args:
        model: Callable network taking an (N, 2) tensor of (x, t) rows.
        device: Device on which the grid is built and the model is evaluated.
        nx: Number of grid points along x.
        nt: Number of grid points along t.

    Returns:
        Tuple ``(T_pred, x_vals, t_vals)`` of NumPy arrays, where ``T_pred``
        has shape ``(nx, nt)`` (axis 0 is x, axis 1 is t) and the other two
        are the 1-D grid coordinates.
    """
    x_axis = torch.linspace(0, config.L, nx, device=device)
    t_axis = torch.linspace(0, config.T_END, nt, device=device)

    # 'ij' indexing keeps x on axis 0 and t on axis 1.
    grid_x, grid_t = torch.meshgrid(x_axis, t_axis, indexing='ij')
    # One (x, t) row per grid node.
    points = torch.stack([grid_x.reshape(-1), grid_t.reshape(-1)], dim=1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(points)

    T_pred = prediction.reshape(nx, nt).cpu().numpy()
    return T_pred, x_axis.cpu().numpy(), t_axis.cpu().numpy()
|
|
|
|
|
|
def evaluate_model(data):
    """Compare the saved PINN with the FDM reference and print error metrics.

    Loads the checkpointed model, predicts on the default grid, subsamples the
    FDM solution in time to the same number of columns and prints the relative
    L2 error and the maximum absolute error. Returns early (printing nothing
    further) when no checkpoint can be loaded.

    Args:
        data: Dict whose ``'device'`` entry selects the evaluation device.
    """
    device = data['device']
    model = _load_model(device)
    if model is None:
        return
    T_pred, _, _ = _predict_grid(model, device)

    # FDM reference (imported here, only once a model is actually available).
    from fdm.solver import solve as fdm_solve
    T_fdm, _, _ = fdm_solve()

    # Subsample the FDM time axis to the PINN's column count. Per the original
    # note the FDM output is (NX=100, NT=5000) against a 100x100 PINN grid;
    # only the time axis is resampled, so the x resolutions must already match.
    _, nt = T_pred.shape
    keep = np.linspace(0, T_fdm.shape[1] - 1, nt, dtype=int)
    T_ref = T_fdm[:, keep]

    residual = T_pred - T_ref
    l2_rel = np.sqrt(np.mean(residual ** 2)) / np.sqrt(np.mean(T_ref ** 2))
    max_err = np.max(np.abs(residual))

    print("\n--- Evaluation vs FDM ---")
    print(f"Relative L2 error vs FDM: {l2_rel:.6f}")
    print(f"Max absolute error: {max_err:.6f} °C\n")
|
|
|
|
|
|
def generate_visualization(data):
    """Plot the PINN prediction together with the FDM reference solution.

    Loads the checkpointed model, predicts on the default grid, computes the
    FDM solution and hands everything to ``visualize_heat_field``. Does
    nothing beyond the loader's error message when no checkpoint is available.

    Args:
        data: Dict whose ``'device'`` entry selects the evaluation device.
    """
    device = data['device']
    model = _load_model(device)
    if model is None:
        return

    pinn_field, x_grid, t_grid = _predict_grid(model, device)

    # Imported lazily, only once a model has actually been loaded.
    from fdm.solver import solve as fdm_solve
    fdm_field, _, _ = fdm_solve()

    visualize_heat_field(pinn_field, x_grid, t_grid, fdm_field)
|