295057e80b
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
213 lines
7.5 KiB
Python
213 lines
7.5 KiB
Python
import os
|
||
import sys
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
import random
|
||
import numpy as np
|
||
import torch
|
||
import torch.optim as optim
|
||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||
|
||
import config
|
||
from inverse.config_inverse import (
|
||
N_F, N_IC, N_BC,
|
||
EPOCHS_INV, LR_ADAM_INV, PATIENCE_INV,
|
||
SCHED_FACTOR, SCHED_PATIENCE, SCHED_MIN_LR,
|
||
MODEL_SAVE_PATH, MODELS_DIR,
|
||
)
|
||
from inverse.model import InverseHeatPINN
|
||
from inverse.loss import inverse_heat_pinn_loss
|
||
|
||
|
||
def _get_device():
    """Return the best usable torch device: CUDA, then Apple MPS, then CPU.

    CUDA is chosen only if a tiny allocation + matmul on the GPU succeeds;
    a ``RuntimeError`` during that probe makes the function fall back to the
    next backend instead of failing later in training.
    """
    def _cuda_works():
        # is_available() alone can be True on a broken driver/toolkit setup,
        # so actually exercise the device once.
        try:
            probe = torch.zeros(2, 2).cuda()
            torch.mm(probe, probe)
        except RuntimeError:
            return False
        return True

    if torch.cuda.is_available() and _cuda_works():
        return torch.device('cuda')
    use_mps = torch.backends.mps.is_available()
    return torch.device('mps' if use_mps else 'cpu')
|
||
|
||
|
||
def _set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and torch for reproducibility.

    When CUDA is available, every CUDA device generator is seeded as well.

    Args:
        seed: integer seed shared by all generators (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
||
|
||
|
||
def prepare_data_inverse():
    """Sample the PDE collocation points and the IC/BC points for the inverse PINN.

    All draws are seeded (``_set_seed(42)``) and created on ``_get_device()``:

    * ``N_F`` uniform collocation points over ``[0, L] x [0, T_END]``;
    * ``N_F // 4`` extra points in a band of width ``0.1 * L`` centred on
      ``X_SRC`` (times uniform), clamped to ``[0, L]``;
    * ``N_F // 4`` extra points in a band of width 0.1 centred on ``T_STEP``
      (positions uniform), clamped to ``[0, T_END]`` -- the same clustering
      strategy as the forward PINN;
    * ``N_IC`` positions for the initial condition and ``N_BC`` times for the
      boundary condition.

    Returns:
        dict with keys ``'device'``, ``'x_f'``, ``'t_f'``, ``'x_ic'``, ``'t_bc'``.
    """
    _set_seed(42)
    device = _get_device()

    def uniform(count, upper):
        # Uniform samples on [0, upper).
        return torch.rand(count, device=device) * upper

    def centred(count):
        # Uniform samples on [-0.5, 0.5).
        return torch.rand(count, device=device) - 0.5

    # NOTE: the draw order below fixes the seeded sample sets; keep it.
    x_bulk = uniform(N_F, config.L)
    t_bulk = uniform(N_F, config.T_END)

    n_band = N_F // 4
    # Band around the source position.
    x_src_band = (config.X_SRC + centred(n_band) * config.L * 0.1).clamp(0, config.L)
    t_src_band = uniform(n_band, config.T_END)
    # Band around the step time.
    x_step_band = uniform(n_band, config.L)
    t_step_band = (config.T_STEP + centred(n_band) * 0.1).clamp(0, config.T_END)

    samples = dict(device=device)
    samples['x_f'] = torch.cat([x_bulk, x_src_band, x_step_band])
    samples['t_f'] = torch.cat([t_bulk, t_src_band, t_step_band])
    samples['x_ic'] = uniform(N_IC, config.L)
    samples['t_bc'] = uniform(N_BC, config.T_END)
    return samples
|
||
|
||
|
||
def train_inverse(data, x_s, t_s, T_meas, identify=('alpha', 'k', 'h_conv'), init_vals=None, epochs=None,
                  w_pde=None, w_ic=None, w_bc=None, w_data=None):
    """Train an ``InverseHeatPINN`` with Adam while identifying its physical parameters.

    Each epoch evaluates ``inverse_heat_pinn_loss``. Whenever the total loss
    improves by more than 1e-8, the checkpoint at ``MODEL_SAVE_PATH`` is
    overwritten with the weights that produced that loss. The learning rate
    is reduced on plateau (``ReduceLROnPlateau``). Training stops early after
    ``PATIENCE_INV`` epochs without improvement, or as soon as the loss
    becomes NaN/Inf.

    Args:
        data: dict as returned by ``prepare_data_inverse()`` (keys
            ``device``, ``x_f``, ``t_f``, ``x_ic``, ``t_bc``).
        x_s, t_s, T_meas: measurement coordinates and values, forwarded
            unchanged to ``inverse_heat_pinn_loss``.
        identify: names of the parameters to learn, passed to ``InverseHeatPINN``.
        init_vals: initial guesses, passed to ``InverseHeatPINN``.
        epochs: maximum number of Adam steps; ``None`` means ``EPOCHS_INV``.
        w_pde, w_ic, w_bc, w_data: optional loss weights. ``None`` leaves the
            loss function's own default in place.

    The checkpoint stores the state dict, ``identify`` and the scalar values
    of alpha, k and h_conv. On Ctrl-C the *current* state is written to the
    same path, and then the KeyboardInterrupt is re-raised.
    """
    device = data['device']
    if epochs is None:
        epochs = EPOCHS_INV
    model = InverseHeatPINN(identify=identify, init_vals=init_vals).to(device)

    print(f"\n--- Inverse PINN Training (Adam) su {device} ---")
    print(f"Stime iniziali: α={model.alpha.item():.4f} k={model.k.item():.4f} h={model.h_conv.item():.4f}")
    print(f"Valori veri: α={config.ALPHA:.4f} k={config.K:.4f} h={config.H_CONV:.4f}\n")

    optimizer = optim.Adam(model.parameters(), lr=LR_ADAM_INV)
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                  factor=SCHED_FACTOR,
                                  patience=SCHED_PATIENCE,
                                  min_lr=SCHED_MIN_LR)

    os.makedirs(MODELS_DIR, exist_ok=True)
    best_loss = float('inf')
    patience_counter = 0

    # Forward only the weights the caller set, so the loss keeps its own
    # defaults for the rest. Built once instead of on every epoch.
    loss_weights = {
        name: weight
        for name, weight in (('w_pde', w_pde), ('w_ic', w_ic), ('w_bc', w_bc), ('w_data', w_data))
        if weight is not None
    }

    def _save_checkpoint():
        # Persist the weights plus the identified scalars for quick inspection.
        torch.save({
            'state_dict': model.state_dict(),
            'identify': model.identify,
            'alpha': model.alpha.item(),
            'k': model.k.item(),
            'h_conv': model.h_conv.item(),
        }, MODEL_SAVE_PATH)

    model.train()
    epoch = -1  # keeps the interrupt handler's message well-defined in all cases
    try:
        for epoch in range(epochs):
            optimizer.zero_grad()
            loss, L_pde, L_ic, L_bc, L_data = inverse_heat_pinn_loss(
                model, data['x_f'], data['t_f'], data['x_ic'], data['t_bc'],
                x_s, t_s, T_meas,
                **loss_weights,
            )
            loss_val = loss.item()  # one host sync per epoch

            if not np.isfinite(loss_val):
                # Stepping on NaN/Inf gradients would poison every parameter;
                # stop and keep the last good checkpoint.
                print(f"Loss non finita a epoca {epoch + 1}. Training interrotto.")
                break

            # Checkpoint BEFORE optimizer.step(): loss_val was computed with
            # the current parameters, so these are the weights that earned it.
            if loss_val < best_loss - 1e-8:
                best_loss = loss_val
                patience_counter = 0
                _save_checkpoint()
            else:
                patience_counter += 1

            loss.backward()
            optimizer.step()
            scheduler.step(loss_val)

            if patience_counter >= PATIENCE_INV:
                print(f"Early stopping a epoca {epoch + 1}")
                break

            if (epoch + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{epochs}] | Loss: {loss_val:.6f} "
                    f"| PDE: {L_pde.item():.5f} IC: {L_ic.item():.5f} "
                    f"BC: {L_bc.item():.5f} Data: {L_data.item():.5f} "
                    f"| α={model.alpha.item():.5f} k={model.k.item():.5f} h={model.h_conv.item():.4f}"
                )
    except KeyboardInterrupt:
        print(f"\nInterrotto a epoca {epoch + 1}. Salvo stato corrente...")
        _save_checkpoint()
        raise

    if best_loss == float('inf'):
        # No epoch ever produced a finite loss (or epochs == 0): nothing saved.
        print("\nTraining completato, ma nessun checkpoint è stato salvato.")
    else:
        print("\nTraining completato. Modello salvato.")
|
||
|
||
|
||
def evaluate_inverse():
    """Report identified parameters and the PINN temperature-field error vs FDM.

    Loads the checkpoint at ``MODEL_SAVE_PATH`` and rebuilds the model with
    the ``identify`` list stored in it. It then:

    * prints the true vs identified alpha, k and h_conv with relative error;
    * compares the PINN field against ``fdm.solver.solve()`` on 100 x 100
      nodes of the FDM grid, evaluating the PINN at exactly those nodes.

    Returns:
        dict with ``'params'`` (name -> (true, identified, error %)),
        ``'l2_rel'`` and ``'max_err'``; ``None`` when no checkpoint exists.
    """
    device = _get_device()

    if not os.path.exists(MODEL_SAVE_PATH):
        print("Modello non trovato. Esegui prima il training.")
        return None

    ckpt = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
    # Rebuild with the identify list used in training (as
    # generate_visualization_inverse does); a default-constructed model need
    # not have the same parameter layout as the saved state dict.
    identify = ckpt.get('identify', ['alpha', 'k', 'h_conv'])
    model = InverseHeatPINN(identify=identify).to(device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    params = {}
    print("\n--- Parametri identificati vs veri ---")
    print(f"{'Param':<10} {'Vero':>10} {'Identificato':>14} {'Errore %':>10}")
    print("-" * 48)
    for name, true_val, id_val in (
        ('alpha', config.ALPHA, model.alpha.item()),
        ('k', config.K, model.k.item()),
        ('h_conv', config.H_CONV, model.h_conv.item()),
    ):
        # Guard a zero reference value instead of raising ZeroDivisionError.
        err = abs(id_val - true_val) / abs(true_val) * 100 if true_val else float('nan')
        params[name] = (true_val, id_val, err)
        print(f"{name:<10} {true_val:>10.5f} {id_val:>14.5f} {err:>9.2f}%")

    # Relative L2 error of the T field vs FDM
    from fdm.solver import solve as fdm_solve
    T_fdm, x_fdm, t_fdm = fdm_solve()
    T_fdm = np.asarray(T_fdm)
    n_x_fdm, n_t_fdm = T_fdm.shape[0], T_fdm.shape[1]

    nx, nt = 100, 100
    # Nearest FDM node to each of nx (nt) evenly spaced positions. Rounding
    # (not truncation) keeps the picks centred on the requested positions.
    x_idx = np.rint(np.linspace(0, n_x_fdm - 1, nx)).astype(int)
    t_idx = np.rint(np.linspace(0, n_t_fdm - 1, nt)).astype(int)
    T_fdm_ds = T_fdm[np.ix_(x_idx, t_idx)]

    # Evaluate the PINN at the very same (x, t) nodes so the two fields are
    # compared point by point.
    x_nodes = np.asarray(x_fdm, dtype=float).reshape(-1)
    t_nodes = np.asarray(t_fdm, dtype=float).reshape(-1)
    if x_nodes.size != n_x_fdm or t_nodes.size != n_t_fdm:
        # NOTE(review): FDM coordinate arrays don't match T_fdm's axes;
        # fall back to assuming a uniform grid over [0, L] x [0, T_END].
        x_nodes = np.linspace(0, config.L, n_x_fdm)
        t_nodes = np.linspace(0, config.T_END, n_t_fdm)
    dtype = torch.get_default_dtype()
    x_vals = torch.as_tensor(x_nodes[x_idx], dtype=dtype, device=device)
    t_vals = torch.as_tensor(t_nodes[t_idx], dtype=dtype, device=device)
    xx, tt = torch.meshgrid(x_vals, t_vals, indexing='ij')
    inp = torch.stack([xx.reshape(-1), tt.reshape(-1)], dim=1)
    with torch.no_grad():
        T_pred = model(inp).reshape(nx, nt).cpu().numpy()

    diff = T_pred - T_fdm_ds
    l2_rel = np.sqrt(np.mean(diff ** 2)) / np.sqrt(np.mean(T_fdm_ds ** 2))
    max_err = np.max(np.abs(diff))
    print(f"\nErrore L2 relativo campo T vs FDM: {l2_rel:.6f}")
    print(f"Errore massimo assoluto: {max_err:.4f} °C")

    return {'params': params, 'l2_rel': float(l2_rel), 'max_err': float(max_err)}
|
||
|
||
|
||
def generate_visualization_inverse():
    """Render the trained inverse PINN's temperature field next to the FDM solution.

    Loads the checkpoint at ``MODEL_SAVE_PATH``, rebuilding the model with its
    stored ``identify`` list. Evaluates the PINN on a 100 x 100 uniform grid
    over ``[0, L] x [0, T_END]``, runs ``fdm.solver.solve()`` for the
    reference field, and hands everything, plus the identified alpha/k/h,
    to ``inverse.visualizer.visualize_inverse``. Prints a message and
    returns early when no checkpoint exists.
    """
    device = _get_device()
    if not os.path.exists(MODEL_SAVE_PATH):
        print("Modello non trovato. Esegui prima il training.")
        return

    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
    net = InverseHeatPINN(identify=checkpoint.get('identify', ['alpha', 'k', 'h_conv'])).to(device)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    n_points = 100
    grid_x = torch.linspace(0, config.L, n_points, device=device)
    grid_t = torch.linspace(0, config.T_END, n_points, device=device)
    mesh_x, mesh_t = torch.meshgrid(grid_x, grid_t, indexing='ij')
    query = torch.stack([mesh_x.reshape(-1), mesh_t.reshape(-1)], dim=1)
    with torch.no_grad():
        field = net(query)
    T_pred = field.reshape(n_points, n_points).cpu().numpy()

    from fdm.solver import solve as fdm_solve
    T_fdm, _x_unused, _t_unused = fdm_solve()

    identified_params = {
        'α': net.alpha.item(),
        'k': net.k.item(),
        'h': net.h_conv.item(),
    }

    from inverse.visualizer import visualize_inverse
    visualize_inverse(T_pred, grid_x.cpu().numpy(), grid_t.cpu().numpy(), T_fdm, identified_params)
|