diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index a5cefd88eda4a..032710df4e8c1 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -3,6 +3,7 @@ from .transformer import TaichiSyntaxError from .ndrange import ndrange, GroupedNDRange from copy import deepcopy as _deepcopy +import functools import os core = taichi_lang_core @@ -208,6 +209,13 @@ def _get_or_make_arch_checkers(kwargs): def all_archs_with(**kwargs): kwargs = _deepcopy(kwargs) def decorator(test): + # @pytest.mark.parametrize decorator only knows about regular function args, + # without *args or **kwargs. By decorating with @functools.wraps, the + # signature of |test| is preserved, so that @ti.all_archs can be used after + # the parametrization decorator. + # + # Full discussion: https://github.com/pytest-dev/pytest/issues/6810 + @functools.wraps(test) def wrapped(*test_args, **test_kwargs): import taichi as ti can_run_on = test_kwargs.pop( @@ -252,6 +260,7 @@ def archs_excluding(*excluded_archs, **kwargs): excluded_archs = set(excluded_archs) def decorator(test): + @functools.wraps(test) def wrapped(*test_args, **test_kwargs): def checker(arch): return arch not in excluded_archs _get_or_make_arch_checkers(test_kwargs).register(checker) @@ -275,6 +284,7 @@ def require(*exts): assert all([isinstance(e, core.Extension) for e in exts]) def decorator(test): + @functools.wraps(test) def wrapped(*test_args, **test_kwargs): def checker(arch): return all([is_supported(arch, e) for e in exts]) _get_or_make_arch_checkers(test_kwargs).register(checker) diff --git a/taichi/platform/metal/metal_runtime.cpp b/taichi/platform/metal/metal_runtime.cpp index 4113349de2c7d..37ce4f90213e2 100644 --- a/taichi/platform/metal/metal_runtime.cpp +++ b/taichi/platform/metal/metal_runtime.cpp @@ -278,6 +278,7 @@ class MetalRuntime::Impl { auto *llvm_ctx = params.llvm_ctx; auto *llvm_rtm = params.llvm_runtime; + TI_ASSERT(llvm_ctx != nullptr && llvm_rtm != nullptr); const size_t rtm_root_mem_size = llvm_ctx->lookup_function( "Runtime_get_root_mem_size")(llvm_rtm); if (rtm_root_mem_size > 0) { @@ -286,6 +287,7 @@ class MetalRuntime::Impl { rtm_root_mem_size); auto *rtm_root_mem = params.llvm_ctx->lookup_function( "Runtime_get_root")(llvm_rtm); + TI_ASSERT(rtm_root_mem != nullptr); root_buffer_ = new_mtl_buffer_no_copy(device_.get(), rtm_root_mem, rtm_root_mem_size); } else { diff --git a/taichi/program.cpp b/taichi/program.cpp index 7efce04fa87a8..d278a8d8a3f75 100644 --- a/taichi/program.cpp +++ b/taichi/program.cpp @@ -190,7 +190,8 @@ void Program::materialize_layout() { std::unique_ptr scomp = StructCompiler::make(this, Arch::x64); scomp->run(*snode_root, true); - if (arch_is_cpu(config.arch) || config.arch == Arch::cuda) { + if (arch_is_cpu(config.arch) || config.arch == Arch::cuda || + config.arch == Arch::metal) { initialize_runtime_system(scomp.get()); } diff --git a/tests/python/test_types.py b/tests/python/test_types.py index b6b010f5256a1..62b4a5eb9b464 100644 --- a/tests/python/test_types.py +++ b/tests/python/test_types.py @@ -1,73 +1,86 @@ import taichi as ti +import pytest -def all_data_types_and_test(foo): - def wrapped(): - tests = [] - for dt in [ti.i32, ti.i64, ti.i8, ti.i16, ti.u8, ti.u16, ti.u32, ti.u64, ti.f32, ti.f64]: - tests.append(foo(dt)) - for test in tests: - # variables are expected to be declared before kernel invocation, discuss at: - # https://github.com/taichi-dev/taichi/pull/505#issuecomment-588644274 - test() - return wrapped +_TI_TYPES = [ti.i8, ti.i16, ti.i32, ti.u8, ti.u16, ti.u32, ti.f32] +_TI_64_TYPES = [ti.i64, ti.u64, ti.f64] -@ti.all_archs -@all_data_types_and_test -def test_type_assign_argument(dt): +def _test_type_assign_argument(dt): x = ti.var(dt, shape=()) - def tester(): - @ti.kernel - def func(value: dt): - x[None] = value + @ti.kernel + def func(value: dt): + x[None] = value + + func(3) + assert x[None] == 3 - func(3) - assert x[None] == 3 +@pytest.mark.parametrize('dt', _TI_TYPES) +# Metal backend doesn't support arg type other than 32-bit yet. +@ti.archs_excluding(ti.metal) +def test_type_assign_argument(dt): + _test_type_assign_argument(dt) - return tester +@pytest.mark.parametrize('dt', _TI_64_TYPES) +@ti.require(ti.extension.data64) @ti.all_archs -@all_data_types_and_test -def test_type_operator(dt): +def test_type_assign_argument64(dt): + _test_type_assign_argument(dt) + +def _test_type_operator(dt): x = ti.var(dt, shape=()) y = ti.var(dt, shape=()) add = ti.var(dt, shape=()) mul = ti.var(dt, shape=()) - def tester(): - @ti.kernel - def func(): - add[None] = x[None] + y[None] - mul[None] = x[None] * y[None] - - for i in range(0, 3): - for j in range(0, 3): - x[None] = i - y[None] = j - func() - assert add[None] == x[None] + y[None] - assert mul[None] == x[None] * y[None] - - return tester + @ti.kernel + def func(): + add[None] = x[None] + y[None] + mul[None] = x[None] * y[None] + + for i in range(0, 3): + for j in range(0, 3): + x[None] = i + y[None] = j + func() + assert add[None] == x[None] + y[None] + assert mul[None] == x[None] * y[None] + +@pytest.mark.parametrize('dt', _TI_TYPES) +@ti.all_archs +def test_type_operator(dt): + _test_type_operator(dt) +@pytest.mark.parametrize('dt', _TI_64_TYPES) +@ti.require(ti.extension.data64) @ti.all_archs -@all_data_types_and_test -def test_type_tensor(dt): +def test_type_operator64(dt): + _test_type_operator(dt) + +def _test_type_tensor(dt): x = ti.var(dt, shape=(3, 2)) - def tester(): - @ti.kernel - def func(i: ti.i32, j: ti.i32): - x[i, j] = 3 + @ti.kernel + def func(i: ti.i32, j: ti.i32): + x[i, j] = 3 - for i in range(0, 3): - for j in range(0, 2): - func(i, j) - assert x[i, j] == 3 + for i in range(0, 3): + for j in range(0, 2): + func(i, j) + assert x[i, j] == 3 - return tester +@pytest.mark.parametrize('dt', _TI_TYPES) @ti.all_archs +def test_type_tensor(dt): + _test_type_tensor(dt) + +@pytest.mark.parametrize('dt', _TI_64_TYPES) +@ti.require(ti.extension.data64) +@ti.all_archs +def test_type_tensor64(dt): + _test_type_tensor(dt) + def _test_overflow(dt, n): a = ti.var(dt, shape=()) b = ti.var(dt, shape=()) @@ -90,12 +103,23 @@ def func(): else: assert c[None] == 2 ** n // 3 * 2 # does not overflow -def test_overflow(): - _test_overflow(ti.i8, 8) - _test_overflow(ti.u8, 8) - _test_overflow(ti.i16, 16) - _test_overflow(ti.u16, 16) - _test_overflow(ti.i32, 32) - _test_overflow(ti.u32, 32) - _test_overflow(ti.i64, 64) - _test_overflow(ti.u64, 64) +@pytest.mark.parametrize('dt,n', [ + (ti.i8, 8), + (ti.u8, 8), + (ti.i16, 16), + (ti.u16, 16), + (ti.i32, 32), + (ti.u32, 32), +]) +@ti.all_archs +def test_overflow(dt, n): + _test_overflow(dt, n) + +@pytest.mark.parametrize('dt,n', [ + (ti.i64, 64), + (ti.u64, 64), +]) +@ti.require(ti.extension.data64) +@ti.all_archs +def test_overflow64(dt, n): + _test_overflow(dt, n)