Commit d17b10f0 by Tianqi Chen Committed by GitHub

[LANG/CODEGEN] Intrinsics and Extern Math (#101)

* [LANG/CODEGEN] Intrinsics and Extern Math

* fix lint
parent 181edb4a
......@@ -5,6 +5,7 @@ Python API
:maxdepth: 2
tvm
intrin
tensor
schedule
build
......
tvm.intrin
----------
.. automodule:: tvm.intrin
.. autosummary::
tvm.call_packed
tvm.call_pure_intrin
tvm.call_pure_extern
tvm.register_intrin_rule
tvm.exp
tvm.log
.. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_pure_intrin
.. autofunction:: tvm.call_pure_extern
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.exp
.. autofunction:: tvm.log
......@@ -13,7 +13,6 @@ The user facing API for computation declaration.
tvm.compute
tvm.scan
tvm.extern
tvm.call_packed
tvm.decl_buffer
tvm.reduce_axis
tvm.thread_axis
......@@ -30,7 +29,6 @@ The user facing API for computation declaration.
.. autofunction:: tvm.compute
.. autofunction:: tvm.scan
.. autofunction:: tvm.extern
.. autofunction:: tvm.call_packed
.. autofunction:: tvm.decl_buffer
.. autofunction:: tvm.reduce_axis
.. autofunction:: tvm.thread_axis
......
......@@ -241,6 +241,14 @@ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
* \return Transformed function.
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
} // namespace ir
} // namespace tvm
......
......@@ -335,8 +335,10 @@ TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
*
* \param name The name of the function.
* \param f The function to be registered.
* \param override Whether allow override already registered function.
*/
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);
TVM_DLL int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override);
/*!
* \brief Get a global function.
......
......@@ -50,8 +50,10 @@ class Registry {
/*!
* \brief Register a function with given name
* \param name The name of the function.
* \param override Whether allow oveeride existing function.
* \return Reference to theregistry.
*/
static Registry& Register(const std::string& name); // NOLINT(*)
static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
/*!
* \brief Erase global function from registry, if exist.
* \param name The name of the function.
......
......@@ -22,6 +22,7 @@ from ._ctypes._function import Function
from ._base import TVMError
from ._base import __version__
from .api import *
from .intrin import *
from .node import register_node
from .schedule import create_schedule
from .build import build, lower
......@@ -75,7 +75,7 @@ def convert_to_tvm_func(pyfunc):
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
return Function(handle)
return Function(handle, False)
def _make_tvm_args(args, temp_args):
......@@ -160,7 +160,7 @@ class Function(object):
"""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
def __init__(self, handle, is_global=False):
def __init__(self, handle, is_global):
"""Initialize the function with handle
Parameters
......@@ -168,8 +168,8 @@ class Function(object):
handle : FunctionHandle
the handle to the underlying function.
is_global : bool, optional
Whether it is global function
is_global : bool
Whether this is a global function in python
"""
self.handle = handle
self.is_global = is_global
......@@ -242,7 +242,7 @@ class ModuleBase(object):
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return Function(ret_handle)
return Function(ret_handle, False)
def import_module(self, module):
"""Add module to the import list of current one.
......@@ -308,7 +308,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
def register_func(func_name, f=None):
def register_func(func_name, f=None, override=False):
"""Register global function
Parameters
......@@ -319,6 +319,9 @@ def register_func(func_name, f=None):
f : function, optional
The function to be registered.
override: boolean optional
Whether override existing entry.
Returns
-------
fregister : function
......@@ -350,12 +353,14 @@ def register_func(func_name, f=None):
if not isinstance(func_name, str):
raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf):
"""internal register function"""
if not isinstance(myf, Function):
myf = convert_to_tvm_func(myf)
check_call(_LIB.TVMFuncRegisterGlobal(
c_str(func_name), myf.handle))
c_str(func_name), myf.handle, ioverride))
if f:
register(f)
else:
......@@ -377,7 +382,7 @@ def get_global_func(name):
"""
handle = FunctionHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
return Function(handle, True)
return Function(handle, False)
def list_global_func_names():
......@@ -401,6 +406,7 @@ def list_global_func_names():
def _get_api(f):
flocal = f
flocal.is_global = True
def my_api_func(*args):
"""
......
......@@ -284,19 +284,6 @@ def extern(shape, inputs, fcompute,
return res[0] if len(res) == 1 else res
def call_packed(*args):
"""Build expression by call an external packed function
Parameters
----------
args : list
Positional arguments.
"""
args = convert(args)
return _make.Call(
int32, "tvm_call_packed", args, 4, None, 0)
def decl_buffer(shape, dtype=None,
name="buffer",
data=None,
......
......@@ -140,6 +140,7 @@ def build(sch,
warp_size = 32 if target == "cuda" else 1
fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
if len(fsplits) > 1:
mhost = codegen.build_module(fsplits[0], target_host)
if target:
......
"""Intrinsics and math functions in TVM."""
from __future__ import absolute_import as _abs
from .expr import Call as _Call
from . import make as _make
from ._ctypes._function import register_func as _register_func
from .api import convert
def call_packed(*args):
"""Build expression by call an external packed function
Parameters
----------
args : list
Positional arguments.
"""
return _make.Call(
"int32", "tvm_call_packed", args, _Call.Intrinsic, None, 0)
def call_pure_intrin(dtype, func_name, *args):
"""Build expression by calling a pure intrinsic function.
Intrinsics can be overloaded with multiple data types via
the intrinsic translation rule.
Parameters
----------
dtype : str
The data type of the result.
func_name: str
The intrinsic function name.
args : list
Positional arguments.
"""
args = convert(args)
return _make.Call(
dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0)
def call_pure_extern(dtype, func_name, *args):
"""Build expression by calling a pure extern function.
Parameters
----------
dtype : str
The data type of the result.
func_name: str
The intrinsic function name.
args : list
Positional arguments.
"""
return _make.Call(
dtype, func_name, convert(args), _Call.PureExtern, None, 0)
def exp(x):
"""Take exponetial of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "exp", x)
def log(x):
"""Take log of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "log", x)
# Intrinsic rule related code
def register_intrin_rule(target, intrin, f=None, override=False):
"""Register an intrinsic function generation rule.
Intrinsic generation rules are callback functions for
code generator to get device specific calls.
This function simply translates to.
:code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)`
TVM may already pre-register intrinsic rules in the backend.
However, user can use this function to change the intrinsic translation
behavior or add new intrinsic rules during runtime.
Parameters
----------
target : str
The name of codegen target.
intrin : str
The name of the instrinsic.
f : function, optional
The function to be registered.
override: boolean optional
Whether override existing entry.
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
The following code registers exp expansion rule for opencl.
.. code-block:: python
register_intrin_rule("opencl", "exp", my_exp_rule, override=True)
"""
return _register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)
def _rule_float_suffix(op):
"""Intrinsic rule: Add float suffix if it is float32.
This is an example intrinsic generation rule.
Parameters
----------
op : Expr
The call expression of original intrinsic.
Returns
-------
ret : Expr
The translated intrinsic rule.
Return same op if no translation is possible.
See Also
--------
register_intrin_rule : The registeration function for intrin rule.
"""
if op.dtype == "float32":
return call_pure_extern(op.dtype, "%sf" % op.name, *op.args)
elif op.dtype == "float64":
return call_pure_extern(op.dtype, op.name, *op.args)
else:
return op
def _rule_float_direct(op):
"""Intrinsic rule: Directly call pure extern function for floats.
This is an example intrinsic generation rule.
Parameters
----------
op : Expr
The call expression of original intrinsic.
Returns
-------
ret : Expr
The translated intrinsic rule.
Return same op if no translation is possible.
See Also
--------
register_intrin_rule : The registeration function for intrin rule.
"""
if str(op.dtype).startswith("float"):
return call_pure_extern(op.dtype, op.name, *op.args)
else:
return None
# opencl pattern for exp
register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
......@@ -19,6 +19,12 @@ class TensorSlice(SliceBase, _expr.ExprOp):
indices = (indices,)
return TensorSlice(self.tensor, self.indices + indices)
@property
def dtype(self):
"""Data content of the tensor."""
return self.tensor.dtype
itervar_cls = None
@register_node
......
......@@ -73,7 +73,8 @@ REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin);
} // namespace ir
} // namespace tvm
......@@ -4,6 +4,7 @@
* \brief Common utilities to generated C style code.
*/
#include <tvm/codegen.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
......@@ -18,11 +19,16 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
mode = mode.substr(0, pos);
}
std::string build_f_name = "codegen.build_" + mode;
// Lower intrinsic functions
Array<LoweredFunc> func_list;
for (LoweredFunc f : funcs) {
func_list.push_back(ir::LowerIntrin(f, target));
}
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "Target " << target << " is not enabled";
runtime::Module m = (*bf)(funcs, target);
runtime::Module m = (*bf)(func_list, target);
return m;
}
......
......@@ -396,7 +396,17 @@ void CodeGenC::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*)
}
void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(Call::bitwise_and)) {
if (op->call_type == Call::Extern ||
op->call_type == Call::PureExtern) {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
this->PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
}
}
os << ")";
} else if (op->is_intrinsic(Call::bitwise_and)) {
PrintBinaryIntrinsitc(op, " & ", os, this);
} else if (op->is_intrinsic(Call::bitwise_xor)) {
PrintBinaryIntrinsitc(op, " ^ ", os, this);
......@@ -462,19 +472,18 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
this->PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
}
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
LOG(FATAL) << "Unresolved intrinsic " << op->name
<< " with return type " << op->type;
} else {
LOG(FATAL) << "Unresolved call type " << op->call_type;
}
os << ")";
}
}
void CodeGenC::PrintVecBinaryOp(
const std::string&op, Type t,
const std::string& op, Type t,
Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
if (isalpha(op[0])) {
os << op << "(";
......
......@@ -202,5 +202,7 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
CodeGenC::VisitStmt_(op);
}
}
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_default.cc
* \brief Default intrinsic rules.
*/
#include "./intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule.h
* \brief Utility to generate intrinsic rules
*/
#ifndef TVM_CODEGEN_INTRIN_RULE_H_
#define TVM_CODEGEN_INTRIN_RULE_H_
#include <tvm/ir.h>
#include <tvm/expr.h>
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <string>
namespace tvm {
namespace codegen {
namespace intrin {
using namespace ir;
// Add float suffix to the intrinsics
struct FloatSuffix {
std::string operator()(Type t, std::string name) const {
if (t == Float(32)) {
return name + 'f';
} else if (t == Float(64)) {
return name;
} else {
return "";
}
}
};
// Add float suffix to the intrinsics
struct FloatDirect {
std::string operator()(Type t, std::string name) const {
if (t.is_float()) {
return name;
} else {
return "";
}
}
};
// Directly call pure extern function for floats.
template<typename T>
inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
std::string name = T()(call->type, call->name);
if (name.length() != 0) {
*rv = Call::make(
call->type, name, call->args, Call::PureExtern);
} else {
*rv = e;
}
}
} // namespace intrin
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_INTRIN_RULE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_cuda.cc
* \brief CUDA intrinsic rules.
*/
#include "./intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
// Add float suffix to the intrinsics, CUDA fast math.
struct CUDAFastMath {
std::string operator()(Type t, std::string name) const {
if (t.lanes() == 1) {
if (t.is_float()) {
switch (t.bits()) {
case 64: return name;
case 32: return "__" + name + 'f';
case 16: return 'h' + name;
default: return "";
}
}
}
return "";
}
};
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAFastMath>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_opencl.cc
* \brief OpenCL intrinsic rules.
*/
#include "./intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<FloatDirect>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
.set_body(DispatchExtern<FloatDirect>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
.set_body(DispatchExtern<FloatDirect>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -597,7 +597,19 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
}
llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
if (op->is_intrinsic(Call::bitwise_and)) {
if (op->is_intrinsic("llvm_intrin")) {
std::vector<llvm::Value*> arg_values;
std::vector<llvm::Type*> arg_types;
for (size_t i = 1; i < op->args.size(); ++i) {
llvm::Value* v = MakeValue(op->args[i]);
arg_values.push_back(v);
arg_types.push_back(v->getType());
}
auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, arg_types);
return builder_->CreateCall(f, arg_values);
} else if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateAnd(
MakeValue(op->args[0]), MakeValue(op->args[1]));
......
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_llvm.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/codegen.h>
#include <string>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
namespace llvm {
using namespace ir;
template<unsigned id>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(UIntImm::make(UInt(32), id));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = Call::make(
call->type, "llvm_intrin", cargs, Call::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::exp>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>);
} // namespace llvm
} // namespace codegen
} // namespace tvm
#endif // LLVM_VERSION
......@@ -12,6 +12,7 @@
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Constants.h>
......
/*!
* Copyright (c) 2017 by Contributors
* Lower intrinsic calls to device specific ir when possible.
* \file lower_intrin.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include "./ir_util.h"
namespace tvm {
namespace ir {
class IntrinInjecter : public IRMutator {
public:
explicit IntrinInjecter(std::string target) {
patterns_.push_back("tvm.intrin.rule." + target + ".");
if (!strncmp(target.c_str(), "llvm", 4) && target != "llvm") {
patterns_.push_back("tvm.intrin.rule.llvm.");
}
patterns_.push_back("tvm.intrin.rule.default.");
}
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
for (size_t i = 0; i < patterns_.size(); ++i) {
std::string& p = patterns_[i];
size_t psize = p.length();
p.resize(psize + op->name.length());
op->name.copy(&p[0] + psize, op->name.length());
const runtime::PackedFunc* f = runtime::Registry::Get(p);
p.resize(psize);
// if pattern exists.
if (f != nullptr) {
Expr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
return this->Mutate(r);
}
}
}
}
return IRMutator::Mutate_(op, e);
}
private:
std::vector<std::string> patterns_;
};
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = IntrinInjecter(target).Mutate(n->body);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
......@@ -34,16 +34,20 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
return *this;
}
Registry& Registry::Register(const std::string& name) { // NOLINT(*)
Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*)
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
CHECK(it == m->fmap.end())
<< "Global PackedFunc " << name << " is already registered";
if (it == m->fmap.end()) {
Registry* r = new Registry();
r->name_ = name;
m->fmap[name] = r;
return *r;
} else {
CHECK(override)
<< "Global PackedFunc " << name << " is already registered";
return *it->second;
}
}
bool Registry::Remove(const std::string& name) {
......@@ -89,9 +93,10 @@ struct TVMFuncThreadLocalEntry {
typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override) {
API_BEGIN();
tvm::runtime::Registry::Register(name)
tvm::runtime::Registry::Register(name, override != 0)
.set_body(*static_cast<tvm::runtime::PackedFunc*>(f));
API_END();
}
......@@ -102,7 +107,7 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
tvm::runtime::Registry::Get(name);
CHECK(fp != nullptr)
<< "Cannot find global function " << name;
*out = (TVMFunctionHandle)(fp); // NOLINT(*)
*out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*)
API_END();
}
......
import tvm
import numpy as np
def test_exp():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: tvm.exp(A(*i)), name='B')
s = tvm.create_schedule(B.op)
# create iter var and assign them tags.
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
return
if not tvm.codegen.enabled(device):
return
fexp = tvm.build(s, [A, B],
device, host,
name="myexp")
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
# launch the kernel.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fexp(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda", "llvm")
check_device("opencl")
def test_log_llvm():
# graph
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: tvm.log(A(*i)), name='B')
s = tvm.create_schedule(B.op)
# create iter var and assign them tags.
bx, tx = s[B].split(B.op.axis[0], factor=32)
# one line to build the function.
if not tvm.codegen.enabled("llvm"):
return
flog = tvm.build(s, [A, B],
"llvm", name="mylog")
ctx = tvm.cpu(0)
# launch the kernel.
n = 1028
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
flog(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.log(a.asnumpy()), rtol=1e-5)
def test_add():
# graph
n = tvm.convert(1024)
......@@ -46,4 +106,6 @@ def test_add():
if __name__ == "__main__":
test_log_llvm()
test_exp()
test_add()
"""
Intrinsics and Math Functions
=============================
**Author**: `Tianqi Chen <https://tqchen.github.io>`_
While tvm support basic arithmetic operations. In many cases
usually we will need more complicated buildin functions.
For example :code:`exp` to take the exponetial of the function.
These functions are target system dependent and may have different
names of different target platforms. In this tutorial, we will learn
how we can invoke these target specific functions, and how we can unify
the interface via tvm's intrinsic API.
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
######################################################################
# Direct Declare Extern Math Call
# -------------------------------
# The most straight-forward way to call target specific function is via
# extern function call construct in tvm.
# In th following example, we use :any:`tvm.call_pure_extern` to call
# :code:`__expf` function, which is only available under CUDA.
#
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape,
lambda i: tvm.call_pure_extern("float32", "__expf", A[i]),
name="B")
s = tvm.create_schedule(B.op)
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], "cuda", name="myexp")
print(f.imported_modules[0].get_source())
######################################################################
# Unified Intrinsic Call
# ----------------------
# The above code verifies that direct external call can be used to
# call into device specific functions.
# However, the above way only works for CUDA target with float type.
# Ideally, we want to write same code for any device and any data type.
#
# TVM intrinsic provides the user a mechanism to achieve this, and this
# is the recommended way to solve the problem.
# The following code use tvm.exp instead, which create an intrinsic call
# :any:`tvm.exp` to do the exponential.
#
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: tvm.exp(A[i]), name="B")
s = tvm.create_schedule(B.op)
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fcuda = tvm.build(s, [A, B], "cuda", name="myexp")
print(fcuda.imported_modules[0].get_source())
######################################################################
# We can find that the code works for both CUDA and opencl.
# The same tvm.exp can also be used for float64 data types.
#
tvm.module.init_opencl()
fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
print(fopencl.imported_modules[0].get_source())
######################################################################
# Intrinsic Lowering Rule
# -----------------------
# When :any:`tvm.exp` is called, TVM creates an intrinsic Call Expr.
# TVM uses transformation rules to transform the intrinsic
# call to device specific extern calls.
#
# TVM also allows user to customize the rules during runtime.
# The following example customizes CUDA lowering rule for :code:`exp`.
#
def my_cuda_math_rule(op):
"""Customized CUDA intrinsic lowering rule"""
assert isinstance(op, tvm.expr.Call)
if op.dtype == "float32":
# call float function
return tvm.call_pure_extern("float32", "%sf" % op.name, op.args[0])
elif op.dtype == "float64":
# call double function
return tvm.call_pure_extern("float32", op.name, op.args[0])
else:
# cannot do translation, return self.
return op
tvm.register_intrin_rule("cuda", "exp", my_cuda_math_rule, override=True)
######################################################################
# Register the rule to TVM with override option to override existing rule.
# Notice the difference between the printed code from previous one:
# our new rule uses math function :code:`expf` instead of
# fast math version :code:`__expf`.
#
fcuda = tvm.build(s, [A, B], "cuda", name="myexp")
print(fcuda.imported_modules[0].get_source())
######################################################################
# Add Your Own Intrinsic
# ----------------------
# If there is an instrinsic that is not provided by TVM.
# User can easily add new intrinsic by using the intrinsic rule system.
# The following example add an intrinsic :code:`mylog` to the system.
#
def mylog(x):
"""customized log intrinsic function"""
return tvm.call_pure_intrin(x.dtype, "mylog", x)
def my_cuda_mylog_rule(op):
"""CUDA lowering rule for log"""
if op.dtype == "float32":
return tvm.call_pure_extern("float32", "logf", op.args[0])
elif op.dtype == "float64":
return tvm.call_pure_extern("float32", "log", op.args[0])
else:
return op
tvm.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda i: mylog(A[i]), name="B")
s = tvm.create_schedule(B.op)
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
fcuda = tvm.build(s, [A, B], "cuda", name="mylog")
print(fcuda.imported_modules[0].get_source())
######################################################################
# Summary
# -------
# - TVM call call extern target dependent math function.
# - Use intrinsic to defined a unified interface for the functions.
# - For more intrinsics available in tvm, take a look at :any:`tvm.intrin`
# - You can customize the intrinsic behavior by defining your own rules.
#
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment