From bca829bd7e4be8d3383535e460bb8583b88b1e7d Mon Sep 17 00:00:00 2001 From: Davide Grilli Date: Thu, 14 May 2026 14:21:17 +0200 Subject: [PATCH] PINN: sposta model.train() fuori dal loop e aggiunge weights_only a torch.load Co-Authored-By: Claude Sonnet 4.6 --- engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/engine.py b/engine.py index 6690a15..cc3d6ea 100644 --- a/engine.py +++ b/engine.py @@ -80,8 +80,8 @@ def train_model(data, epochs=None, patience=None): patience_counter = 0 print(f"\n--- Heat PINN Training (Adam) on {device} ---") + model.train() for epoch in range(epochs): - model.train() optimizer.zero_grad() loss, L_pde, L_ic, L_bc = heat_pinn_loss( model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc'] @@ -109,7 +109,7 @@ def train_model(data, epochs=None, patience=None): # L-BFGS fine-tuning phase (standard PINN practice for convergence to better minima) print("\n--- L-BFGS fine-tuning ---") - ckpt = torch.load(MODEL_SAVE_PATH, map_location=device) + ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True) model.load_state_dict(ckpt['state_dict']) lbfgs = optim.LBFGS(model.parameters(), lr=config.LR_LBFGS, max_iter=50, @@ -145,7 +145,7 @@ def _load_model(device): if not os.path.exists(MODEL_SAVE_PATH): print("Error: model not found. Train the model first.") return None - ckpt = torch.load(MODEL_SAVE_PATH, map_location=device) + ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True) model = HeatPINN().to(device) model.load_state_dict(ckpt['state_dict']) model.eval()