25 lines
467 B
Python
25 lines
467 B
Python
|
|
import sys
|
||
|
|
import os
|
||
|
|
# Make the parent of this file's directory the first import location, so the
# project-local modules the fixtures import (``engine``, ``model``) resolve
# regardless of the working directory pytest is launched from.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import torch
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(scope="session")
def device():
    """Provide the device returned by ``engine._get_device``.

    Session-scoped: the lookup runs once per test session and every test that
    requests ``device`` receives the same object.
    """
    from engine import _get_device as pick_device

    chosen = pick_device()
    return chosen
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def small_data():
    """Build a deliberately small dataset with ``engine.prepare_data``.

    Requests ``N_f=200`` and ``N_ic=N_bc=50``. These are presumably
    collocation, initial-condition and boundary-condition point counts;
    confirm against ``engine.prepare_data``. The fixture is function-scoped,
    so the data is rebuilt for every test that requests it.
    """
    import engine

    sizes = {"N_f": 200, "N_ic": 50, "N_bc": 50}
    return engine.prepare_data(**sizes)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def pinn_model(device):
    """Return a freshly constructed ``HeatPINN`` moved onto ``device``.

    The model is built under a forked CPU RNG seeded with a fixed value, so
    its initial parameters are the same on every run. Without this, tests
    depending on the weights are nondeterministic. Forking restores the
    previous CPU RNG state afterwards, so other tests' random streams are
    unaffected.

    Args:
        device: the session-scoped device fixture.

    Returns:
        The ``HeatPINN`` instance returned by ``.to(device)``.
    """
    from model import HeatPINN

    # NOTE(review): assumes HeatPINN draws its initial parameters from torch's
    # CPU default generator (as standard torch.nn layers do). Confirm in model.py.
    # devices=[] forks only the CPU generator. The model is constructed on CPU
    # and moved afterwards, so CUDA RNG state is neither needed nor touched.
    with torch.random.fork_rng(devices=[]):
        # Seed only the CPU generator; torch.manual_seed would also reseed
        # CUDA generators, which fork_rng(devices=[]) does not restore.
        torch.default_generator.manual_seed(0)
        model = HeatPINN()
    return model.to(device)
|