PINN: sposta model.train() fuori dal loop e aggiunge weights_only a torch.load

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-14 14:21:17 +02:00
parent 98bfc78651
commit bca829bd7e
+3 -3
View File
@@ -80,8 +80,8 @@ def train_model(data, epochs=None, patience=None):
patience_counter = 0
print(f"\n--- Heat PINN Training (Adam) on {device} ---")
model.train()
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
loss, L_pde, L_ic, L_bc = heat_pinn_loss(
model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc']
@@ -109,7 +109,7 @@ def train_model(data, epochs=None, patience=None):
# L-BFGS fine-tuning phase (standard PINN practice for convergence to better minima)
print("\n--- L-BFGS fine-tuning ---")
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device)
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
model.load_state_dict(ckpt['state_dict'])
lbfgs = optim.LBFGS(model.parameters(), lr=config.LR_LBFGS, max_iter=50,
@@ -145,7 +145,7 @@ def _load_model(device):
if not os.path.exists(MODEL_SAVE_PATH):
print("Error: model not found. Train the model first.")
return None
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device)
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
model = HeatPINN().to(device)
model.load_state_dict(ckpt['state_dict'])
model.eval()