PINN: sposta model.train() fuori dal loop e aggiunge weights_only a torch.load
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -80,8 +80,8 @@ def train_model(data, epochs=None, patience=None):
|
||||
patience_counter = 0
|
||||
|
||||
print(f"\n--- Heat PINN Training (Adam) on {device} ---")
|
||||
model.train()
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
loss, L_pde, L_ic, L_bc = heat_pinn_loss(
|
||||
model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc']
|
||||
@@ -109,7 +109,7 @@ def train_model(data, epochs=None, patience=None):
|
||||
|
||||
# L-BFGS fine-tuning phase (standard PINN practice for convergence to better minima)
|
||||
print("\n--- L-BFGS fine-tuning ---")
|
||||
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device)
|
||||
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt['state_dict'])
|
||||
|
||||
lbfgs = optim.LBFGS(model.parameters(), lr=config.LR_LBFGS, max_iter=50,
|
||||
@@ -145,7 +145,7 @@ def _load_model(device):
|
||||
if not os.path.exists(MODEL_SAVE_PATH):
|
||||
print("Error: model not found. Train the model first.")
|
||||
return None
|
||||
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device)
|
||||
ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
|
||||
model = HeatPINN().to(device)
|
||||
model.load_state_dict(ckpt['state_dict'])
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user