Commit 163c4795 by Tianqi Chen Committed by GitHub

[CODEGEN] Bugfix multiple condition generation (#558)

parent 10faa893
...@@ -131,26 +131,29 @@ class CodeGenAMDGPU : public CodeGenLLVM { ...@@ -131,26 +131,29 @@ class CodeGenAMDGPU : public CodeGenLLVM {
} }
}; };
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { inline int DetectROCMComputeVersion() {
CHECK(target.length( TVMContext tvm_ctx;
) >= 4 && tvm_ctx.device_type = kROCM;
target.substr(0, 4) == "rocm"); tvm_ctx.device_id = 0;
TVMContext tvmCtx;
tvmCtx.device_type = kROCM;
tvmCtx.device_id = 0;
TVMRetValue val; TVMRetValue val;
tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kExist, &val); tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) { if (val.operator int() == 1) {
tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kComputeVersion, &val); tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
return val.operator int();
} else { } else {
val = 803; return 803;
} }
}
llvm::TargetMachine* tm = \ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" + \ CHECK(target.length() >= 4 &&
std::to_string(val.operator int())+ target.substr(4, target.length() - 4)); target.substr(0, 4) == "rocm");
std::ostringstream config;
config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
<< DetectROCMComputeVersion()
<< target.substr(4, target.length() - 4);
llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU()); std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext()); std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init(funcs[0]->name, tm, ctx.get(), false, false); cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
...@@ -159,7 +162,6 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { ...@@ -159,7 +162,6 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
} }
std::unique_ptr<llvm::Module> module = cg->Finish(); std::unique_ptr<llvm::Module> module = cg->Finish();
llvm::SmallString<8> dataObj, data_ll, dataAsm; llvm::SmallString<8> dataObj, data_ll, dataAsm;
llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm); llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
destObj.SetUnbuffered(); destObj.SetUnbuffered();
......
...@@ -582,14 +582,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { ...@@ -582,14 +582,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block); builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
builder_->SetInsertPoint(then_block); builder_->SetInsertPoint(then_block);
llvm::Value* then_value = MakeValue(op->args[1]); llvm::Value* then_value = MakeValue(op->args[1]);
BasicBlock* then_value_block = builder_->GetInsertBlock();
builder_->CreateBr(end_block); builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block); builder_->SetInsertPoint(else_block);
llvm::Value* else_value = MakeValue(op->args[2]); llvm::Value* else_value = MakeValue(op->args[2]);
BasicBlock* else_value_block = builder_->GetInsertBlock();
builder_->CreateBr(end_block); builder_->CreateBr(end_block);
builder_->SetInsertPoint(end_block); builder_->SetInsertPoint(end_block);
llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2); llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
value->addIncoming(then_value, then_block); value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_block); value->addIncoming(else_value, else_value_block);
return value; return value;
} else { } else {
LOG(FATAL) << "unknown intrinsic " << op->name; LOG(FATAL) << "unknown intrinsic " << op->name;
......
...@@ -130,12 +130,34 @@ class CodeGenNVPTX : public CodeGenLLVM { ...@@ -130,12 +130,34 @@ class CodeGenNVPTX : public CodeGenLLVM {
} }
}; };
inline int DetectCUDAComputeVersion() {
TVMContext tvm_ctx;
tvm_ctx.device_type = kGPU;
tvm_ctx.device_id = 0;
TVMRetValue val;
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) {
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
tvm_ctx, tvm::runtime::kComputeVersion, &val);
std::string version = val;
std::istringstream is(version);
double ver;
is >> ver;
return static_cast<int>(ver * 10);
} else {
return 20;
}
}
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) { runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
CHECK(target.length() >= 5 && CHECK(target.length() >= 5 &&
target.substr(0, 5) == "nvptx"); target.substr(0, 5) == "nvptx");
llvm::TargetMachine* tm = GetLLVMTargetMachine( std::ostringstream config;
"-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" + config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
target.substr(5, target.length() - 5)); << DetectCUDAComputeVersion()
<< target.substr(5, target.length() - 5);
llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX()); std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext()); std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init(funcs[0]->name, tm, ctx.get(), false, false); cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
......
...@@ -22,6 +22,7 @@ def schedule_injective(outs): ...@@ -22,6 +22,7 @@ def schedule_injective(outs):
target = tvm.target.current_target(allow_none=False) target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm": if target.target_name != "llvm":
raise RuntimeError("schedule_injective not registered for '%s'" % target) raise RuntimeError("schedule_injective not registered for '%s'" % target)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
x = outs[0] x = outs[0]
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s) tvm.schedule.AutoInlineInjective(s)
......
...@@ -6,6 +6,7 @@ import tvm ...@@ -6,6 +6,7 @@ import tvm
def _default_schedule(outs, auto_inline): def _default_schedule(outs, auto_inline):
"""Default schedule for llvm.""" """Default schedule for llvm."""
target = tvm.target.current_target(allow_none=False) target = tvm.target.current_target(allow_none=False)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name != "llvm": if target.target_name != "llvm":
raise RuntimeError("schedule_pool not registered for '%s'" % target) raise RuntimeError("schedule_pool not registered for '%s'" % target)
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
......
...@@ -124,11 +124,11 @@ def test_gemm(): ...@@ -124,11 +124,11 @@ def test_gemm():
t = timer_f(a, b, c).mean t = timer_f(a, b, c).mean
GFLOPS = num_flops / (t * 1e3) / 1e6 GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS)) print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
for device in ['cuda', 'opencl', 'rocm']: for device in ["cuda", "opencl", "rocm"]:
with tvm.build_config(auto_unroll_max_step=32, with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0, auto_unroll_min_depth=0,
unroll_explicit=device == 'rocm'): unroll_explicit=(device != "cuda")):
check_device(device) check_device(device)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -74,11 +74,9 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): ...@@ -74,11 +74,9 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
for _ in range(1): for _ in range(1):
foo(data_tvm, out_tvm) foo(data_tvm, out_tvm)
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3) np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
check_device(device)
check_device("opencl")
check_device("cuda")
check_device("metal")
check_device("rocm")
def test_reduce_map(): def test_reduce_map():
verify_reduce_map_ele(in_shape=(128, 24, 128, 24), verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
......
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import numpy as np import numpy as np
import tvm import tvm
import topi import topi
import logging
from topi.util import get_const_tuple from topi.util import get_const_tuple
def verify_softmax(m, n): def verify_softmax(m, n):
...@@ -42,8 +43,6 @@ def verify_log_softmax(m, n): ...@@ -42,8 +43,6 @@ def verify_log_softmax(m, n):
# confirm lower works # confirm lower works
s = tvm.create_schedule([B.op]) s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True) tvm.lower(s, [A, B], simple_mode=True)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.log_softmax_python(a_np) b_np = topi.testing.log_softmax_python(a_np)
...@@ -60,13 +59,15 @@ def verify_log_softmax(m, n): ...@@ -60,13 +59,15 @@ def verify_log_softmax(m, n):
foo(a, b) foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']: for device in ["cuda", "opencl", "metal", "rocm"]:
check_device(device) check_device(device)
def test_log_softmax(): def test_log_softmax():
verify_log_softmax(32, 10) verify_log_softmax(32, 10)
verify_log_softmax(3, 4) verify_log_softmax(3, 4)
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_softmax() test_softmax()
test_log_softmax() test_log_softmax()
...@@ -21,10 +21,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): ...@@ -21,10 +21,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
foo(data_nd, out_nd) foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy) np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("opencl") for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device("cuda") check_device(device)
check_device("metal")
check_device("rocm")
def verify_tranpose(in_shape, axes): def verify_tranpose(in_shape, axes):
...@@ -45,10 +43,9 @@ def verify_tranpose(in_shape, axes): ...@@ -45,10 +43,9 @@ def verify_tranpose(in_shape, axes):
foo(data_nd, out_nd) foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy) np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda") for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device("opencl") check_device(device)
check_device("metal")
check_device("rocm")
def verify_reshape(src_shape, dst_shape): def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A") A = tvm.placeholder(shape=src_shape, name="A")
...@@ -68,10 +65,9 @@ def verify_reshape(src_shape, dst_shape): ...@@ -68,10 +65,9 @@ def verify_reshape(src_shape, dst_shape):
foo(data_nd, out_nd) foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy) np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda") for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device("opencl") check_device(device)
check_device("metal")
check_device("rocm")
def verify_squeeze(src_shape, axis): def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A") A = tvm.placeholder(shape=src_shape, name="A")
...@@ -95,10 +91,8 @@ def verify_squeeze(src_shape, axis): ...@@ -95,10 +91,8 @@ def verify_squeeze(src_shape, axis):
foo(data_nd, out_nd) foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy) np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda") for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device("opencl") check_device(device)
check_device("metal")
check_device("rocm")
def verify_concatenate(shapes, axis): def verify_concatenate(shapes, axis):
tensor_l = [] tensor_l = []
...@@ -120,10 +114,9 @@ def verify_concatenate(shapes, axis): ...@@ -120,10 +114,9 @@ def verify_concatenate(shapes, axis):
foo(*(data_nds + [out_nd])) foo(*(data_nds + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy) np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda") for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device("opencl") check_device(device)
check_device("metal")
check_device("rocm")
def verify_split(src_shape, indices_or_sections, axis): def verify_split(src_shape, indices_or_sections, axis):
A = tvm.placeholder(shape=src_shape, name="A") A = tvm.placeholder(shape=src_shape, name="A")
...@@ -144,10 +137,9 @@ def verify_split(src_shape, indices_or_sections, axis): ...@@ -144,10 +137,9 @@ def verify_split(src_shape, indices_or_sections, axis):
for out_nd, out_npy in zip(out_nds, out_npys): for out_nd, out_npy in zip(out_nds, out_npys):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy) np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("cuda") for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device("opencl") check_device(device)
check_device("metal")
check_device("rocm")
def test_expand_dims(): def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2) verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
...@@ -175,6 +167,7 @@ def test_squeeze(): ...@@ -175,6 +167,7 @@ def test_squeeze():
def test_concatenate(): def test_concatenate():
verify_concatenate([(2,), (2,), (2,)], 0)
verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1) verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1) verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
verify_concatenate([(5, 6, 7, 3), verify_concatenate([(5, 6, 7, 3),
...@@ -190,9 +183,9 @@ def test_split(): ...@@ -190,9 +183,9 @@ def test_split():
verify_split((10, 12, 24), [5, 7, 9], -1) verify_split((10, 12, 24), [5, 7, 9], -1)
if __name__ == "__main__": if __name__ == "__main__":
test_concatenate()
test_tranpose() test_tranpose()
test_expand_dims() test_expand_dims()
test_reshape() test_reshape()
test_squeeze() test_squeeze()
test_concatenate()
test_split() test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment