Files
pinn/inverse/engine.py
T

213 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import random
import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import config
from inverse.config_inverse import (
N_F, N_IC, N_BC,
EPOCHS_INV, LR_ADAM_INV, PATIENCE_INV,
SCHED_FACTOR, SCHED_PATIENCE, SCHED_MIN_LR,
MODEL_SAVE_PATH, MODELS_DIR,
)
from inverse.model import InverseHeatPINN
from inverse.loss import inverse_heat_pinn_loss
def _get_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU.

    CUDA is selected only if a tiny matmul actually runs on it. A RuntimeError
    during that probe makes the function fall through to the next candidate.
    """
    if torch.cuda.is_available():
        try:
            probe = torch.zeros(2, 2).cuda()
            torch.mm(probe, probe)
        except RuntimeError:
            pass  # CUDA is reported but not usable; try the other backends
        else:
            return torch.device('cuda')
    return torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
def _set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and torch (plus all CUDA devices if present)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def prepare_data_inverse():
    """Sample the training point sets for the inverse PINN (deterministic: seed 42).

    Returns a dict with:
        'device': the device chosen by _get_device(); all tensors live there.
        'x_f', 't_f': N_F uniform collocation points over [0, L) x [0, T_END),
            followed by N_F // 4 points with x within +/-5% of L around X_SRC
            (clamped to [0, L]) and N_F // 4 points with t within +/-0.05 of
            T_STEP (clamped to [0, T_END]).
        'x_ic': N_IC uniform positions in [0, L).
        't_bc': N_BC uniform times in [0, T_END).
    """
    _set_seed(42)
    device = _get_device()

    def uniform(n, upper):
        # n samples on [0, upper), drawn directly on the target device
        return torch.rand(n, device=device) * upper

    def centred(n):
        # n samples on [-0.5, 0.5)
        return torch.rand(n, device=device) - 0.5

    # Uniform coverage of the whole space-time domain.
    x_uniform = uniform(N_F, config.L)
    t_uniform = uniform(N_F, config.T_END)

    # Extra points clustered near X_SRC and T_STEP (same strategy as the forward PINN).
    # NOTE: the draw order below must not change, it fixes the seeded samples.
    n_extra = N_F // 4
    x_src_band = (config.X_SRC + centred(n_extra) * config.L * 0.1).clamp(0, config.L)
    t_src_band = uniform(n_extra, config.T_END)
    x_step_band = uniform(n_extra, config.L)
    t_step_band = (config.T_STEP + centred(n_extra) * 0.1).clamp(0, config.T_END)

    # Dict values are evaluated left to right, so the IC/BC draws still follow the ones above.
    return {
        'device': device,
        'x_f': torch.cat([x_uniform, x_src_band, x_step_band]),
        't_f': torch.cat([t_uniform, t_src_band, t_step_band]),
        'x_ic': uniform(N_IC, config.L),
        't_bc': uniform(N_BC, config.T_END),
    }
def train_inverse(data, x_s, t_s, T_meas, identify=('alpha', 'k', 'h_conv'), init_vals=None, epochs=None,
                  w_pde=None, w_ic=None, w_bc=None, w_data=None):
    """Train an InverseHeatPINN with Adam and checkpoint the best model to MODEL_SAVE_PATH.

    Args:
        data: dict as returned by prepare_data_inverse(), with keys 'device', 'x_f', 't_f',
            'x_ic', 't_bc'.
        x_s, t_s, T_meas: measurement inputs forwarded unchanged to inverse_heat_pinn_loss
            (shapes as that function expects).
        identify: parameter names forwarded to InverseHeatPINN as the set to identify.
        init_vals: initial guesses forwarded to InverseHeatPINN.
        epochs: maximum number of Adam iterations; defaults to EPOCHS_INV.
        w_pde, w_ic, w_bc, w_data: optional loss weights. Any left as None is not passed,
            so inverse_heat_pinn_loss uses its own default for it.

    Side effects:
        Writes a checkpoint (state_dict, identify, scalar alpha/k/h_conv) whenever the total
        loss improves by more than 1e-8. On KeyboardInterrupt it saves the *current* state and
        re-raises. Stops early after PATIENCE_INV epochs without improvement.
    """
    device = data['device']
    if epochs is None:
        epochs = EPOCHS_INV
    model = InverseHeatPINN(identify=identify, init_vals=init_vals).to(device)
    print(f"\n--- Inverse PINN Training (Adam) su {device} ---")
    print(f"Stime iniziali: α={model.alpha.item():.4f} k={model.k.item():.4f} h={model.h_conv.item():.4f}")
    print(f"Valori veri: α={config.ALPHA:.4f} k={config.K:.4f} h={config.H_CONV:.4f}\n")

    optimizer = optim.Adam(model.parameters(), lr=LR_ADAM_INV)
    # Reduce the LR by SCHED_FACTOR when the total loss stops decreasing for SCHED_PATIENCE epochs.
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                  factor=SCHED_FACTOR,
                                  patience=SCHED_PATIENCE,
                                  min_lr=SCHED_MIN_LR)
    os.makedirs(MODELS_DIR, exist_ok=True)

    # The loss weights are fixed for the whole run, so build the kwargs once instead of every epoch.
    loss_weights = {
        name: value
        for name, value in (('w_pde', w_pde), ('w_ic', w_ic), ('w_bc', w_bc), ('w_data', w_data))
        if value is not None
    }

    best_loss = float('inf')
    patience_counter = 0
    epoch = -1  # keeps the interrupt message well-defined even if no epoch has started

    def _save_best(model):
        # Store `identify` so loaders can rebuild the same model configuration.
        torch.save({
            'state_dict': model.state_dict(),
            'identify': model.identify,
            'alpha': model.alpha.item(),
            'k': model.k.item(),
            'h_conv': model.h_conv.item(),
        }, MODEL_SAVE_PATH)

    model.train()
    try:
        for epoch in range(epochs):
            optimizer.zero_grad()
            loss, L_pde, L_ic, L_bc, L_data = inverse_heat_pinn_loss(
                model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc'],
                x_s, t_s, T_meas,
                **loss_weights,
            )
            loss.backward()
            optimizer.step()

            # Read the scalar once: every .item() forces a device->host synchronisation.
            loss_val = loss.item()
            scheduler.step(loss_val)

            # Checkpoint only on a meaningful improvement; otherwise count towards early stopping.
            if loss_val < best_loss - 1e-8:
                best_loss = loss_val
                patience_counter = 0
                _save_best(model)
            else:
                patience_counter += 1
            if patience_counter >= PATIENCE_INV:
                print(f"Early stopping a epoca {epoch + 1}")
                break

            if (epoch + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{epochs}] | Loss: {loss_val:.6f} "
                    f"| PDE: {L_pde.item():.5f} IC: {L_ic.item():.5f} "
                    f"BC: {L_bc.item():.5f} Data: {L_data.item():.5f} "
                    f"| α={model.alpha.item():.5f} k={model.k.item():.5f} h={model.h_conv.item():.4f}"
                )
    except KeyboardInterrupt:
        # Deliberately saves the current state, which may be worse than the last best checkpoint.
        print(f"\nInterrotto a epoca {epoch + 1}. Salvo stato corrente...")
        _save_best(model)
        raise

    if best_loss == float('inf'):
        # No epoch improved on +inf (e.g. epochs == 0, or a NaN loss, since NaN comparisons are
        # False), so this run wrote nothing to MODEL_SAVE_PATH.
        print("\nTraining completato, ma nessun modello salvato (loss mai migliorata).")
    else:
        print("\nTraining completato. Modello salvato.")
def evaluate_inverse():
    """Report identified vs. true parameters and the temperature-field error against FDM.

    Loads the checkpoint at MODEL_SAVE_PATH and rebuilds the model with the `identify` set it
    was saved with. It then prints:
      * a table of alpha / k / h_conv (identified vs. config values, with % error relative to |true|;
        nan when the true value is 0);
      * the relative L2 error and the max absolute error of the PINN field against the
        fdm.solver solution, sampled on up to 100 x 100 FDM nodes.

    Returns None. Prints a message and returns early if no checkpoint exists.
    """
    device = _get_device()
    if not os.path.exists(MODEL_SAVE_PATH):
        print("Modello non trovato. Esegui prima il training.")
        return
    ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
    # Rebuild with the trained `identify` set, as generate_visualization_inverse does. A model
    # built with a different set may not match the saved state_dict. Checkpoints without the key
    # fall back to the constructor default.
    model_kwargs = {'identify': ckpt['identify']} if 'identify' in ckpt else {}
    model = InverseHeatPINN(**model_kwargs).to(device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    print("\n--- Parametri identificati vs veri ---")
    print(f"{'Param':<10} {'Vero':>10} {'Identificato':>14} {'Errore %':>10}")
    print("-" * 48)
    for name, true_val, id_val in [
        ('alpha', config.ALPHA, model.alpha.item()),
        ('k', config.K, model.k.item()),
        ('h_conv', config.H_CONV, model.h_conv.item()),
    ]:
        # Relative error w.r.t. |true|; undefined (nan) when the true value is exactly 0.
        err = abs(id_val - true_val) / abs(true_val) * 100 if true_val != 0 else float('nan')
        print(f"{name:<10} {true_val:>10.5f} {id_val:>14.5f} {err:>9.2f}%")

    # L2 error of the T field against the FDM reference.
    from fdm.solver import solve as fdm_solve
    T_fdm, x_fdm, t_fdm = fdm_solve()
    nx_fdm, nt_fdm = T_fdm.shape[0], T_fdm.shape[1]
    nx, nt = 100, 100
    x_idx = np.linspace(0, nx_fdm - 1, nx, dtype=int)
    t_idx = np.linspace(0, nt_fdm - 1, nt, dtype=int)
    T_fdm_ds = T_fdm[np.ix_(x_idx, t_idx)]

    def _node_coords(grid, idx, n_nodes, extent):
        # Coordinates of the sampled FDM nodes along one axis. Uses the solver's grid when it is
        # a 1-D array aligned with this axis. Otherwise assumes a uniform grid on [0, extent].
        grid = np.asarray(grid)
        if grid.ndim == 1 and grid.size == n_nodes:
            return grid[idx]
        return idx * (extent / max(n_nodes - 1, 1))

    # Evaluate the PINN exactly at the sampled FDM nodes. Truncating indices to int makes those
    # nodes drift up to one FDM cell away from an independent linspace(0, L, 100) grid.
    dtype = torch.get_default_dtype()
    x_pts = torch.as_tensor(_node_coords(x_fdm, x_idx, nx_fdm, config.L), dtype=dtype).to(device)
    t_pts = torch.as_tensor(_node_coords(t_fdm, t_idx, nt_fdm, config.T_END), dtype=dtype).to(device)
    xx, tt = torch.meshgrid(x_pts, t_pts, indexing='ij')
    inp = torch.stack([xx.reshape(-1), tt.reshape(-1)], dim=1)
    with torch.no_grad():
        T_pred = model(inp).reshape(nx, nt).cpu().numpy()

    diff = T_pred - T_fdm_ds
    l2_rel = np.sqrt(np.mean(diff ** 2)) / np.sqrt(np.mean(T_fdm_ds ** 2))
    max_err = np.max(np.abs(diff))
    print(f"\nErrore L2 relativo campo T vs FDM: {l2_rel:.6f}")
    print(f"Errore massimo assoluto: {max_err:.4f} °C")
def generate_visualization_inverse():
    """Plot the trained inverse-PINN temperature field against the FDM reference.

    Loads the checkpoint at MODEL_SAVE_PATH and rebuilds the model with the checkpoint's
    `identify` set (all three parameters if the key is absent). It evaluates T on a 100 x 100
    (x, t) grid over [0, L] x [0, T_END] and passes that field, the grid axes, the FDM solution
    and the identified parameters to inverse.visualizer.visualize_inverse.
    Prints a message and returns early if no checkpoint exists.
    """
    device = _get_device()
    if not os.path.exists(MODEL_SAVE_PATH):
        print("Modello non trovato. Esegui prima il training.")
        return

    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
    trained_set = checkpoint.get('identify', ['alpha', 'k', 'h_conv'])
    net = InverseHeatPINN(identify=trained_set).to(device)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    n_x = n_t = 100
    xs = torch.linspace(0, config.L, n_x, device=device)
    ts = torch.linspace(0, config.T_END, n_t, device=device)
    grid_x, grid_t = torch.meshgrid(xs, ts, indexing='ij')
    points = torch.stack([grid_x.reshape(-1), grid_t.reshape(-1)], dim=1)
    with torch.no_grad():
        field = net(points).reshape(n_x, n_t).cpu().numpy()

    from fdm.solver import solve as fdm_solve
    T_fdm, _, _ = fdm_solve()

    identified = {
        'α': net.alpha.item(),
        'k': net.k.item(),
        'h': net.h_conv.item(),
    }
    from inverse.visualizer import visualize_inverse
    visualize_inverse(field, xs.cpu().numpy(), ts.cpu().numpy(), T_fdm, identified)