Files
pinn/tests/test_engine_data.py
T

68 lines
2.1 KiB
Python
Raw Normal View History

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import config
from engine import set_seed, _get_device, prepare_data
def test_set_seed_reproducibility():
    """Re-seeding with the same value must replay the same torch random stream."""
    samples = []
    for _ in range(2):
        # Reset the seed before every draw so both draws start from the same state.
        set_seed(42)
        samples.append(torch.rand(10))
    first, second = samples
    torch.testing.assert_close(first, second)
def test_get_device_valid():
    """_get_device returns a torch.device whose type is one of the supported backends."""
    supported_types = ('cpu', 'cuda', 'mps')
    dev = _get_device()
    assert isinstance(dev, torch.device)
    assert dev.type in supported_types
def test_prepare_data_keys():
    """prepare_data returns exactly the expected keys: no missing and no extra entries."""
    expected_keys = {'device', 'x_f', 't_f', 'x_ic', 't_bc'}
    result = prepare_data(N_f=100, N_ic=50, N_bc=50)
    actual_keys = set(result.keys())
    assert actual_keys == expected_keys
def test_prepare_data_shapes():
    """Every sampled tensor is 1-D and holds the expected number of points."""
    n_colloc, n_init, n_bound = 100, 50, 50
    data = prepare_data(N_f=n_colloc, N_ic=n_init, N_bc=n_bound)
    # engine.py appends 2 * (N_f // 4) extra clustering points to the collocation set.
    n_colloc_total = n_colloc + 2 * (n_colloc // 4)
    expected_shapes = {
        'x_f': (n_colloc_total,),
        't_f': (n_colloc_total,),
        'x_ic': (n_init,),
        't_bc': (n_bound,),
    }
    for key, shape in expected_shapes.items():
        assert data[key].shape == shape
def test_prepare_data_x_bounds():
    """All spatial samples (collocation and initial-condition) lie in [0, config.L]."""
    data = prepare_data(N_f=500, N_ic=100, N_bc=100)
    for key in ('x_f', 'x_ic'):
        xs = data[key]
        assert xs.min().item() >= 0.0
        assert xs.max().item() <= config.L
def test_prepare_data_t_bounds():
    """All temporal samples lie in [0, config.T_END].

    Checks both time tensors returned by prepare_data: the collocation times
    ``t_f`` and the boundary-condition times ``t_bc``. This mirrors the spatial
    bounds test, which checks every x tensor.
    """
    data = prepare_data(N_f=500, N_ic=100, N_bc=100)
    # NOTE(review): assumes t_bc is sampled over the same time domain
    # [0, T_END] as t_f. Confirm against engine.prepare_data.
    for key in ('t_f', 't_bc'):
        ts = data[key]
        assert ts.min().item() >= 0.0, f"{key} has values below 0"
        assert ts.max().item() <= config.T_END, f"{key} has values above T_END"
def test_prepare_data_device_consistency():
    """Every tensor returned by prepare_data lives on the device it reports."""
    data = prepare_data(N_f=100, N_ic=50, N_bc=50)
    target_type = data['device'].type
    tensor_keys = ('x_f', 't_f', 'x_ic', 't_bc')
    for key in tensor_keys:
        actual_type = data[key].device.type
        # The failure message is Italian for "<key> on the wrong device".
        assert actual_type == target_type, f"{key} sul device sbagliato"
def test_prepare_data_deterministic():
    """Two calls with identical arguments produce identical data.

    prepare_data fixes its own seed internally, so two consecutive calls must
    return the same samples for every tensor, including the boundary times
    ``t_bc``. Without internal seeding, the second call would draw from an
    advanced RNG state and differ.
    """
    d1 = prepare_data(N_f=100, N_ic=50, N_bc=50)
    d2 = prepare_data(N_f=100, N_ic=50, N_bc=50)
    for key in ('x_f', 't_f', 'x_ic', 't_bc'):
        torch.testing.assert_close(d1[key], d2[key])