Commit 7bcb3f53 by Tianqi Chen Committed by ziheng

[REFACTOR] collections->container, RPC returns func, time_evaluator r… (#244)

* [REFACTOR] collections->container, RPC returns func, time_evaluator returns struct

* fix executor
parent c324494f
......@@ -37,7 +37,7 @@ def bind(g, ctx):
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
return f if isinstance(
f, (tvm.collections.Array, tuple, list)) else [f]
f, (tvm.container.Array, tuple, list)) else [f]
@tvm.register_func("tvm_graph.build_target")
......
tvm.collections
---------------
.. automodule:: tvm.collections
:members:
tvm.container
-------------
.. automodule:: tvm.container
:members:
......@@ -9,7 +9,7 @@ from . import stmt
from . import make
from . import ir_pass
from . import codegen
from . import collections
from . import container
from . import schedule
from . import module
from . import node
......
......@@ -15,7 +15,7 @@ from . import make as _make
from . import expr as _expr
from . import tensor as _tensor
from . import schedule as _schedule
from . import collections as _collections
from . import container as _container
from . import tag as _tag
int32 = "int32"
......@@ -493,7 +493,7 @@ def _IterVar(dom, name, iter_type, thread_tag=''):
raise TypeError("need to be list of ranges")
dom = Range(dom[0], dom[1])
if not isinstance(dom, _collections.Range):
if not isinstance(dom, _container.Range):
raise TypeError("dom need to be Range")
name = name if name else 'iter'
v = var(name)
......@@ -628,7 +628,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
if isinstance(expr, _collections.Array):
if isinstance(expr, _container.Array):
size = len(expr)
larr = []
rarr = []
......
......@@ -9,7 +9,7 @@ from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import collections
from . import container
from . import module
from . import codegen
from . import ndarray
......@@ -276,19 +276,19 @@ def build(sch,
flist = lower(sch, args,
name=name,
binds=binds)
if isinstance(flist, collections.LoweredFunc):
if isinstance(flist, container.LoweredFunc):
flist = [flist]
elif isinstance(sch, collections.LoweredFunc):
elif isinstance(sch, container.LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc")
flist = [sch]
elif isinstance(sch, (list, tuple, collections.Array)):
elif isinstance(sch, (list, tuple, container.Array)):
flist = sch
else:
raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
fname_set = set()
for x in flist:
if not isinstance(x, collections.LoweredFunc):
if not isinstance(x, container.LoweredFunc):
raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
......@@ -296,7 +296,7 @@ def build(sch,
fhost = []
fdevice = []
for func in flist:
if func.func_type == collections.LoweredFunc.MixedFunc:
if func.func_type == container.LoweredFunc.MixedFunc:
if BuildConfig.current.detect_global_barrier:
func = ir_pass.StorageSync(func, "global")
func = ir_pass.StorageSync(func, "shared")
......@@ -306,9 +306,9 @@ def build(sch,
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == collections.LoweredFunc.HostFunc:
elif func.func_type == container.LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == collections.LoweredFunc.DeviceFunc:
elif func.func_type == container.LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
......
"""Collections contains data structures used in TVM DSL."""
"""Container data structures used in TVM DSL."""
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from . import _api_internal
......
......@@ -6,7 +6,7 @@ from . import stmt as _stmt
from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from . import collections as _collections
from . import container as _container
from ._ffi.base import string_types
from ._ffi.node import NodeGeneric
from .expr import Call as _Call
......@@ -280,7 +280,7 @@ class IRBuilder(object):
The buffer var representing the buffer.
"""
buffer_var = _api.var(name, dtype="handle")
if not isinstance(shape, (list, tuple, _collections.Array)):
if not isinstance(shape, (list, tuple, _container.Array)):
shape = [shape]
if scope:
self.scope_attr(buffer_var, "storage_scope", scope)
......
"""Container of compiled functions of TVM."""
from __future__ import absolute_import as _abs
from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api
from .contrib import cc_compiler as _cc, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"])
class Module(ModuleBase):
"""Module container of all TVM generated functions"""
def __repr__(self):
......@@ -120,8 +124,14 @@ class Module(ModuleBase):
and return a float representing seconds per function call.
"""
try:
return _RPCTimeEvaluator(
feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number)
def evaluator(*args):
"""Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future.
mean = feval(*args)
return ProfileResult(mean=mean)
return evaluator
except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled")
......
......@@ -4,7 +4,7 @@ from ._ffi.node import NodeBase, register_node
from . import _api_internal
from . import tensor as _tensor
from . import expr as _expr
from . import collections as _collections
from . import container as _container
from ._ffi.function import _init_api
@register_node
......@@ -74,7 +74,7 @@ def create_schedule(ops):
sch : schedule.Schedule
The created schedule.
"""
if not isinstance(ops, (list, _collections.Array)):
if not isinstance(ops, (list, _container.Array)):
ops = [ops]
return _api_internal._CreateSchedule(ops)
......
......@@ -13,16 +13,28 @@ namespace runtime {
// Wrapped remote function to packed func.
struct RPCWrappedFunc {
public:
RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess)
: handle_(handle), sess_(sess) {}
RPCWrappedFunc(void* handle,
std::shared_ptr<RPCSession> sess)
: handle_(handle), sess_(sess) {
fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
WrapRemote(sess, args.values[0].v_handle, args.type_codes[0], rv);
});
}
void operator()(TVMArgs args, TVMRetValue *rv) const {
sess_->CallFunc(handle_, args, rv);
sess_->CallFunc(handle_, args, rv, &fwrap_);
}
~RPCWrappedFunc() {
sess_->CallRemote(RPCCode::kFreeFunc, handle_);
}
static void WrapRemote(std::shared_ptr<RPCSession> sess,
void* handle,
int tcode,
TVMRetValue* rv);
private:
PackedFunc fwrap_;
void* handle_{nullptr};
std::shared_ptr<RPCSession> sess_;
};
......@@ -94,8 +106,28 @@ class RPCModuleNode final : public ModuleNode {
void* module_handle_{nullptr};
// The local channel
std::shared_ptr<RPCSession> sess_;
// Wrap function to wrap remote module/function.
PackedFunc fwrap_;
};
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
void* handle,
int tcode,
TVMRetValue *rv) {
if (handle == nullptr) return;
if (tcode == kFuncHandle) {
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
*rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv);
});
} else {
CHECK_EQ(tcode, kModuleHandle);
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(handle, sess);
*rv = Module(n);
}
}
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, sess);
......
......@@ -69,7 +69,10 @@ class RPCSession::EventHandler {
void FinishCopyAck() {
this->SwitchToState(kRecvCode);
}
RPCCode HandleNextEvent(TVMRetValue* rv) {
RPCCode HandleNextEvent(TVMRetValue* rv,
bool client_mode,
const PackedFunc* fwrap) {
std::swap(client_mode_, client_mode);
while (this->Ready()) {
switch (state_) {
case kRecvCode: HandleRecvCode(); break;
......@@ -110,19 +113,29 @@ class RPCSession::EventHandler {
case kReturnReceived: {
CHECK_EQ(arg_buf_->value.size(), 1U);
TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
if (argv.type_code() == kFuncHandle ||
argv.type_code() == kModuleHandle) {
CHECK(fwrap != nullptr) << "function/module wrapper not available";
fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
} else {
*rv = argv;
}
arg_buf_.reset();
this->SwitchToState(kRecvCode);
std::swap(client_mode_, client_mode);
return RPCCode::kReturn;
}
case kCopyAckReceived: {
std::swap(client_mode_, client_mode);
return RPCCode::kCopyAck;
}
case kShutdownReceived: {
std::swap(client_mode_, client_mode);
return RPCCode::kShutdown;
}
}
}
std::swap(client_mode_, client_mode);
return RPCCode::kNone;
}
// Reset and clear all states.
......@@ -161,6 +174,8 @@ class RPCSession::EventHandler {
writer_->Write(&value, sizeof(TVMValue));
break;
}
case kFuncHandle:
case kModuleHandle:
case kHandle: {
// always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
......@@ -231,6 +246,8 @@ class RPCSession::EventHandler {
int arg_index_;
// The stage of each argument receiver.
int arg_recv_stage_;
// Whether current handler is client or server mode.
bool client_mode_{false};
// Argument buffer
std::unique_ptr<RPCArgBuffer> arg_buf_;
// Temp byte buffer.
......@@ -305,7 +322,15 @@ class RPCSession::EventHandler {
case kHandle:
case kStr:
case kBytes:
case kTVMContext: this->RequestBytes(sizeof(TVMValue)); break;
case kTVMContext: {
this->RequestBytes(sizeof(TVMValue)); break;
}
case kFuncHandle:
case kModuleHandle: {
CHECK(client_mode_)
<< "Only client can receive remote functions";
this->RequestBytes(sizeof(TVMValue)); break;
}
case kNull: break;
case kArrayHandle: {
this->RequestBytes(sizeof(uint64_t));
......@@ -337,6 +362,8 @@ class RPCSession::EventHandler {
this->SwitchToState(kRecvPackedSeqArg);
break;
}
case kFuncHandle:
case kModuleHandle:
case kHandle: {
// always send handle in 64 bit.
uint64_t handle;
......@@ -558,6 +585,13 @@ class RPCSession::EventHandler {
ret_value.v_handle = &arr;
ret_tcode = kBytes;
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else if (rv.type_code() == kFuncHandle ||
rv.type_code() == kModuleHandle) {
// always send handle in 64 bit.
CHECK(!client_mode_)
<< "Only server can send function and module handle back.";
rv.MoveToCHost(&ret_value, &ret_tcode);
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else {
ret_value = rv.value();
ret_tcode = rv.type_code();
......@@ -634,7 +668,8 @@ struct RPCSessTable {
std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
};
RPCCode RPCSession::HandleUntilReturnEvent(TVMRetValue* rv) {
RPCCode RPCSession::HandleUntilReturnEvent(
TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) {
RPCCode code = RPCCode::kCallFunc;
while (code != RPCCode::kReturn &&
code != RPCCode::kShutdown &&
......@@ -657,7 +692,7 @@ RPCCode RPCSession::HandleUntilReturnEvent(TVMRetValue* rv) {
}
}
}
code = handler_->HandleNextEvent(rv);
code = handler_->HandleNextEvent(rv, client_mode, fwrap);
}
return code;
}
......@@ -668,7 +703,7 @@ void RPCSession::Init() {
// Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
RPCCode code = HandleUntilReturnEvent(rv);
RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
});
}
......@@ -712,7 +747,7 @@ void RPCSession::Shutdown() {
void RPCSession::ServerLoop() {
std::lock_guard<std::recursive_mutex> lock(mutex_);
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kShutdown);
CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
LOG(INFO) << "Shutdown...";
channel_.reset(nullptr);
}
......@@ -722,7 +757,7 @@ bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
reader_.Write(bytes.c_str(), bytes.length());
TVMRetValue rv;
RPCCode code = handler_->HandleNextEvent(&rv);
RPCCode code = handler_->HandleNextEvent(&rv, false, nullptr);
while (writer_.bytes_available() != 0) {
writer_.ReadWithCallback([this](const void *data, size_t size) {
return channel_->Send(data, size);
......@@ -733,13 +768,18 @@ bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
}
// Get remote function with name
void RPCSession::CallFunc(void* h, TVMArgs args, TVMRetValue* rv) {
void RPCSession::CallFunc(void* h,
TVMArgs args,
TVMRetValue* rv,
const PackedFunc* fwrap) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc;
writer_.Write(&code, sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(h);
writer_.Write(&handle, sizeof(handle));
call_remote_.CallPacked(args, rv);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
code = HandleUntilReturnEvent(rv, true, fwrap);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
}
void RPCSession::CopyToRemote(void* from,
......@@ -761,7 +801,7 @@ void RPCSession::CopyToRemote(void* from,
writer_.Write(&ctx_to, sizeof(ctx_to));
writer_.Write(reinterpret_cast<char*>(from) + from_offset, data_size);
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kReturn);
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
}
void RPCSession::CopyFromRemote(void* from,
......@@ -782,7 +822,7 @@ void RPCSession::CopyFromRemote(void* from,
writer_.Write(&size, sizeof(size));
writer_.Write(&ctx_from, sizeof(ctx_from));
TVMRetValue rv;
CHECK(HandleUntilReturnEvent(&rv) == RPCCode::kCopyAck);
CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
reader_.Reserve(data_size);
while (reader_.bytes_available() < data_size) {
size_t bytes_needed = data_size - reader_.bytes_available();
......
......@@ -95,13 +95,16 @@ class RPCSession {
bool ServerOnMessageHandler(const std::string& bytes);
/*!
* \brief Call into remote function
* \param sptr_to_self shared_ptr to self.
* \param handle The function handle
* \param args The arguments
* \param rv The return value.
* \param fwrapper Wrapper function to turn Function/Module handle into real return.
*/
void CallFunc(RPCFuncHandle handle,
TVMArgs args,
TVMRetValue* rv);
TVMRetValue* rv,
const PackedFunc* fwrap);
/*!
* \brief Copy bytes into remote array content.
* \param from The source host data.
......@@ -146,6 +149,7 @@ class RPCSession {
int nstep);
/*!
* \brief Call a remote defined system function with arguments.
* \param sptr_to_self shared_ptr to self.
* \param fcode The function code.
* \param args The arguments
* \return The returned remote value.
......@@ -178,7 +182,8 @@ class RPCSession {
class EventHandler;
// Handle events until receives a return
// Also flushes channels so that the function advances.
RPCCode HandleUntilReturnEvent(TVMRetValue* rv);
RPCCode HandleUntilReturnEvent(
TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap);
// Initalization
void Init();
// Shutdown
......@@ -191,7 +196,7 @@ class RPCSession {
common::RingBuffer reader_, writer_;
// Event handler.
std::shared_ptr<EventHandler> handler_;
// call remote with the specified function coede.
// call remote with specified function code.
PackedFunc call_remote_;
// The index of this session in RPC session table.
int table_index_{0};
......
......@@ -95,8 +95,8 @@ def test_add():
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
vbias = np.random.uniform()
vscale = np.random.uniform()
ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1000)
tcost = ftimer(a, b, c, vbias, vscale)
ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=10)
tcost = ftimer(a, b, c, vbias, vscale).mean
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)
......
......@@ -78,8 +78,8 @@ def test_gemm():
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
ftimer = f.time_evaluator(f.entry_name, ctx, number=20)
tcost = ftimer(a, b, c)
ftimer = f.time_evaluator(f.entry_name, ctx, number=1)
tcost = ftimer(a, b, c).mean
print("%s: exec=%g sec/op" % (ctx, tcost))
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
......
......@@ -11,7 +11,7 @@ def test_flatten2():
xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
......
......@@ -11,7 +11,7 @@ def test_storage_share():
s = tvm.create_schedule(B.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
......@@ -47,7 +47,7 @@ def test_storage_share_gpu():
s[A[2*t+1]].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.decl_buffer(A[0].shape, A[0].dtype, name='B')
......
......@@ -15,7 +15,7 @@ def test_storage_sync():
s[A1].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
......
......@@ -14,7 +14,7 @@ def test_virtual_thread():
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
......
......@@ -11,7 +11,7 @@ def test_rpc_simple():
def addone(x):
return x + 1
@tvm.register_func("rpc.test.strcat")
def addone(name, x):
def strcat(name, x):
return "%s:%d" % (name, x)
@tvm.register_func("rpc.test.except")
......@@ -83,16 +83,26 @@ def test_rpc_remote_module():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func")
def addone(x):
return lambda y: x+y
server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.remote_func")
fadd = f1(10)
assert fadd(12) == 22
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_rpc_return_func()
test_rpc_file_exchange()
exit(0)
test_rpc_array()
test_rpc_remote_module()
test_rpc_simple()
......@@ -11,7 +11,7 @@ def test_bound1():
xo, xi = s[A2].split(s[A2].op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8)
def test_bound2():
......@@ -26,7 +26,7 @@ def test_bound2():
_ = s.normalize()
s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
assert(bounds[A1.op.axis[0]].extent.value == 8)
assert(bounds[A1.op.axis[1]].extent.value == 8)
......@@ -49,7 +49,7 @@ def test_bound3():
s[A1].compute_at(s[A2], yo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)
......
......@@ -9,7 +9,7 @@ def test_schedule0():
s = tvm.create_schedule(A1.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
def test_schedule1():
......@@ -21,7 +21,7 @@ def test_schedule1():
s = tvm.create_schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......@@ -36,7 +36,7 @@ def test_schedule2():
xo, xi = s[A2].split(A2.op.axis[0], 8)
s[A1].compute_at(s[A2], xo)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......@@ -79,7 +79,7 @@ def test_schedule_const_bound():
s = tvm.create_schedule(A1.op)
xo, xi = s[A1].split(A1.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
......
......@@ -2,6 +2,8 @@
export PYTHONPATH=python
rm -rf python/tvm/*.pyc
TVM_FFI=ctypes python -m nose -v tests/python/unittest || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/unittest || exit -1
make cython || exit -1
......
......@@ -40,7 +40,7 @@ def test_rpc_array():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
cost = time_f(a, b)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
......
......@@ -115,13 +115,13 @@ def test_depthwise_conv2d_map():
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# Measure time cost of kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=10000)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
# Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=10000)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
# Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=10000)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
print("Input shape = " + str(get_const_tuple(Input.shape)))
print("Filter shape = " + str(get_const_tuple(Filter.shape)))
print("Stride = (%d, %d)" % (stride_h, stride_w))
......
......@@ -79,13 +79,13 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_mul
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# launch kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
# launch kernel 2 (depthwise_conv2d + scale_shift)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
# launch kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
# correctness with scipy
depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
......
......@@ -151,7 +151,7 @@ np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
# device and returns the measured cost.
#
time_f = f.time_evaluator(f.entry_name, ctx, number=10)
cost = time_f(a, b)
cost = time_f(a, b).mean
print('%g secs/op' % cost)
# terminate the server after experiment
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment