# pinn/tests/conftest.py — shared pytest fixtures for the PINN test suite.

import sys
import os
# Prepend the directory one level above this file's folder (the parent of
# tests/) to sys.path, so the fixtures' imports of ``engine`` and ``model``
# resolve regardless of the working directory pytest is started from.
sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
)
import pytest
import torch
@pytest.fixture(scope="session")
def device():
    """Return the compute device shared by every test in the run.

    The value is produced by ``engine._get_device``. Because the fixture is
    session-scoped, pytest calls it only once per run and reuses the result.
    NOTE(review): presumably a ``torch.device`` (it is passed to ``.to()``
    elsewhere in this file) — confirm against engine.
    """
    from engine import _get_device as resolve_device

    return resolve_device()
@pytest.fixture
def small_data():
    """Build a small dataset with ``engine.prepare_data``.

    The fixture has no explicit scope, so pytest's default function scope
    applies and every requesting test triggers its own ``prepare_data`` call.
    """
    from engine import prepare_data as build

    # NOTE(review): the keys presumably count interior collocation points
    # (N_f), initial-condition points (N_ic) and boundary-condition points
    # (N_bc); the small sizes presumably keep tests fast — confirm in engine.
    point_counts = {"N_f": 200, "N_ic": 50, "N_bc": 50}
    return build(**point_counts)
@pytest.fixture
def pinn_model(device):
    """Return a newly constructed ``HeatPINN`` moved onto ``device``.

    This fixture uses the default function scope, so each test receives a
    fresh instance. The ``device`` it depends on is session-scoped and shared.
    NOTE(review): HeatPINN is presumably a ``torch.nn.Module`` (it supports
    ``.to(device)``) — confirm in model.
    """
    from model import HeatPINN as Net

    net = Net()
    return net.to(device)