diff --git a/engine.py b/engine.py index 6690a15..cc3d6ea 100644 --- a/engine.py +++ b/engine.py @@ -80,8 +80,8 @@ def train_model(data, epochs=None, patience=None): patience_counter = 0 print(f"\n--- Heat PINN Training (Adam) on {device} ---") + model.train() for epoch in range(epochs): - model.train() optimizer.zero_grad() loss, L_pde, L_ic, L_bc = heat_pinn_loss( model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc'] @@ -109,7 +109,7 @@ def train_model(data, epochs=None, patience=None): # L-BFGS fine-tuning phase (standard PINN practice for convergence to better minima) print("\n--- L-BFGS fine-tuning ---") - ckpt = torch.load(MODEL_SAVE_PATH, map_location=device) + ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True) model.load_state_dict(ckpt['state_dict']) lbfgs = optim.LBFGS(model.parameters(), lr=config.LR_LBFGS, max_iter=50, @@ -145,7 +145,7 @@ def _load_model(device): if not os.path.exists(MODEL_SAVE_PATH): print("Error: model not found. Train the model first.") return None - ckpt = torch.load(MODEL_SAVE_PATH, map_location=device) + ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True) model = HeatPINN().to(device) model.load_state_dict(ckpt['state_dict']) model.eval()