68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
|
|
import sys
|
||
|
|
import os
|
||
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import config
|
||
|
|
from engine import set_seed, _get_device, prepare_data
|
||
|
|
|
||
|
|
|
||
|
|
def test_set_seed_reproducibility():
    """Re-seeding with the same value must reproduce the same torch.rand draw."""

    def _draw_after_seed(seed):
        # Seed first, then draw; the pair is repeated below with an identical seed.
        set_seed(seed)
        return torch.rand(10)

    first = _draw_after_seed(42)
    second = _draw_after_seed(42)
    torch.testing.assert_close(first, second)
|
||
|
|
|
||
|
|
|
||
|
|
def test_get_device_valid():
    """_get_device() must return a torch.device whose type is cpu, cuda or mps."""
    supported = ('cpu', 'cuda', 'mps')
    dev = _get_device()
    assert isinstance(dev, torch.device)
    assert dev.type in supported
|
||
|
|
|
||
|
|
|
||
|
|
def test_prepare_data_keys():
    """prepare_data() must return exactly the documented set of keys, no more, no less."""
    expected_keys = {'device', 'x_f', 't_f', 'x_ic', 't_bc'}
    result = prepare_data(N_f=100, N_ic=50, N_bc=50)
    assert expected_keys == set(result.keys())
|
||
|
|
|
||
|
|
|
||
|
|
def test_prepare_data_shapes():
    """Each returned sample tensor must be 1-D with the length implied by the requested counts."""
    n_f, n_ic, n_bc = 100, 50, 50
    batch = prepare_data(N_f=n_f, N_ic=n_ic, N_bc=n_bc)

    # engine.py appends 2 * (N_f // 4) extra clustering points to the 'f' samples.
    n_f_total = n_f + 2 * (n_f // 4)

    expected_lengths = (
        ('x_f', n_f_total),
        ('t_f', n_f_total),
        ('x_ic', n_ic),
        ('t_bc', n_bc),
    )
    for key, length in expected_lengths:
        assert batch[key].shape == (length,)
|
||
|
|
|
||
|
|
|
||
|
|
def test_prepare_data_x_bounds():
    """Values in 'x_f' and 'x_ic' must lie within [0, config.L]."""
    data = prepare_data(N_f=500, N_ic=100, N_bc=100)
    for name in ('x_f', 'x_ic'):
        coords = data[name]
        assert coords.min().item() >= 0.0
        assert coords.max().item() <= config.L
|
||
|
|
|
||
|
|
|
||
|
|
def test_prepare_data_t_bounds():
    """Values in 't_f' must lie within [0, config.T_END]."""
    t_f = prepare_data(N_f=500, N_ic=100, N_bc=100)['t_f']
    lowest, highest = t_f.min().item(), t_f.max().item()
    assert lowest >= 0.0
    assert highest <= config.T_END
|
||
|
|
|
||
|
|
|
||
|
|
def test_prepare_data_device_consistency():
    """Every sample tensor must sit on a device of the same type as data['device']."""
    data = prepare_data(N_f=100, N_ic=50, N_bc=50)
    target_type = data['device'].type
    tensor_keys = ['x_f', 't_f', 'x_ic', 't_bc']
    for key in tensor_keys:
        actual_type = data[key].device.type
        # Message text is kept as-is (Italian: "<key> on the wrong device").
        assert actual_type == target_type, f"{key} sul device sbagliato"
|
||
|
|
|
||
|
|
|
||
|
|
def test_prepare_data_deterministic():
|
||
|
|
"""Due chiamate con lo stesso seed (fissato in prepare_data) producono dati identici."""
|
||
|
|
d1 = prepare_data(N_f=100, N_ic=50, N_bc=50)
|
||
|
|
d2 = prepare_data(N_f=100, N_ic=50, N_bc=50)
|
||
|
|
torch.testing.assert_close(d1['x_f'], d2['x_f'])
|
||
|
|
torch.testing.assert_close(d1['t_f'], d2['t_f'])
|
||
|
|
torch.testing.assert_close(d1['x_ic'], d2['x_ic'])
|