Commit 57a74936 by tqchen

Rename dim_var to axis, update testcases

parent ff26cd68
...@@ -129,12 +129,14 @@ TVM_DLL int TVMNodeFree(NodeHandle handle); ...@@ -129,12 +129,14 @@ TVM_DLL int TVMNodeFree(NodeHandle handle);
* \param handle The node handle * \param handle The node handle
* \param key The attribute name * \param key The attribute name
* \param out_value The attribute value * \param out_value The attribute value
* \param out_typeid The typeif of the attribute. * \param out_typeid The typeid of the attribute.
* \param out_success Whether get is successful.
*/ */
TVM_DLL int TVMNodeGetAttr(NodeHandle handle, TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
const char* key, const char* key,
ArgVariant* out_value, ArgVariant* out_value,
int* out_typeid); int* out_typeid,
int* out_success);
/*! /*!
* \brief get attributes names in the node. * \brief get attributes names in the node.
......
...@@ -17,8 +17,8 @@ namespace tvm { ...@@ -17,8 +17,8 @@ namespace tvm {
*/ */
class ComputeOpNode : public OperationNode { class ComputeOpNode : public OperationNode {
public: public:
/*! \brief Iteration variables over the dimensions */ /*! \brief IterVar on each axis */
Array<IterVar> dim_var; Array<IterVar> axis;
/*! \brief the compute expression */ /*! \brief the compute expression */
Expr body; Expr body;
/*! \brief constructor */ /*! \brief constructor */
...@@ -34,11 +34,11 @@ class ComputeOpNode : public OperationNode { ...@@ -34,11 +34,11 @@ class ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("dim_var", &dim_var); v->Visit("axis", &axis);
v->Visit("body", &body); v->Visit("body", &body);
} }
static Operation make(std::string name, static Operation make(std::string name,
Array<IterVar> dim_var, Array<IterVar> axis,
Expr body); Expr body);
static constexpr const char* _type_key = "ComputeOp"; static constexpr const char* _type_key = "ComputeOp";
......
...@@ -72,10 +72,18 @@ class NodeBase(object): ...@@ -72,10 +72,18 @@ class NodeBase(object):
def __getattr__(self, name): def __getattr__(self, name):
ret_val = ArgVariant() ret_val = ArgVariant()
ret_typeid = ctypes.c_int() ret_typeid = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr( check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name), self.handle, c_str(name),
ctypes.byref(ret_val), ctypes.byref(ret_typeid))) ctypes.byref(ret_val),
return RET_SWITCH[ret_typeid.value](ret_val) ctypes.byref(ret_typeid),
ctypes.byref(ret_success)))
value = RET_SWITCH[ret_typeid.value](ret_val)
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return value
def __hash__(self): def __hash__(self):
return _function_internal._raw_ptr(self) return _function_internal._raw_ptr(self)
......
...@@ -37,6 +37,7 @@ using TVMAPINode = std::shared_ptr<Node>; ...@@ -37,6 +37,7 @@ using TVMAPINode = std::shared_ptr<Node>;
struct APIAttrGetter : public AttrVisitor { struct APIAttrGetter : public AttrVisitor {
std::string skey; std::string skey;
APIVariantValue* ret; APIVariantValue* ret;
bool found_node_ref{false};
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
if (skey == key) *ret = value[0]; if (skey == key) *ret = value[0];
...@@ -62,7 +63,10 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -62,7 +63,10 @@ struct APIAttrGetter : public AttrVisitor {
if (skey == key) *ret = value[0]; if (skey == key) *ret = value[0];
} }
void Visit(const char* key, NodeRef* value) final { void Visit(const char* key, NodeRef* value) final {
if (skey == key) *ret = value[0]; if (skey == key) {
*ret = value[0];
found_node_ref = true;
}
} }
}; };
...@@ -198,7 +202,8 @@ int TVMNodeFree(NodeHandle handle) { ...@@ -198,7 +202,8 @@ int TVMNodeFree(NodeHandle handle) {
int TVMNodeGetAttr(NodeHandle handle, int TVMNodeGetAttr(NodeHandle handle,
const char* key, const char* key,
ArgVariant* ret_val, ArgVariant* ret_val,
int* ret_typeid) { int* ret_typeid,
int* ret_success) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
ret->ret_value.type_id = kNull; ret->ret_value.type_id = kNull;
...@@ -209,11 +214,14 @@ int TVMNodeGetAttr(NodeHandle handle, ...@@ -209,11 +214,14 @@ int TVMNodeGetAttr(NodeHandle handle,
if (getter.skey == "type_key") { if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key(); ret_val->v_str = (*tnode)->type_key();
*ret_typeid = kStr; *ret_typeid = kStr;
*ret_success = 1;
} else { } else {
(*tnode)->VisitAttrs(&getter); (*tnode)->VisitAttrs(&getter);
if (ret->ret_value.type_id != kNull) { if (ret->ret_value.type_id != kNull) {
ret->SetReturn(ret_val, ret_typeid); ret->SetReturn(ret_val, ret_typeid);
*ret_success = 1;
} else { } else {
*ret_success = getter.found_node_ref ? 1 : 0;
*ret_typeid = kNull; *ret_typeid = kNull;
} }
} }
......
...@@ -13,10 +13,10 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); ...@@ -13,10 +13,10 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc } // namespace dmlc
namespace tvm { namespace tvm {
Range::Range(Expr begin, Expr end) Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) { : Range(std::make_shared<Halide::IR::RangeNode>(
// TODO(tqchen) add simplify to end - begin begin,
is_zero(begin) ? end : (end - begin))) {
} }
Range Range::make_with_min_extent(Expr min, Expr extent) { Range Range::make_with_min_extent(Expr min, Expr extent) {
......
...@@ -18,27 +18,27 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) { ...@@ -18,27 +18,27 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
size_t ndim = shape.size(); size_t ndim = shape.size();
std::vector<IterVar> dim_var; std::vector<IterVar> axis;
std::vector<Var> args; std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os; std::ostringstream os;
os << "dim_var" << i; os << "ax" << i;
dim_var.push_back(IterVar(Range(0, shape[i]), os.str())); axis.emplace_back(IterVar(Range(0, shape[i]), os.str()));
args.push_back(dim_var.back()->var); args.push_back(axis.back()->var);
} }
op_node->dim_var = Array<IterVar>(dim_var); op_node->axis = Array<IterVar>(axis);
op_node->body = fcompute(args); op_node->body = fcompute(args);
op_node->name = name; op_node->name = name;
return Operation(op_node).output(0); return Operation(op_node).output(0);
} }
Operation ComputeOpNode::make(std::string name, Operation ComputeOpNode::make(std::string name,
Array<IterVar> dim_var, Array<IterVar> axis,
Expr body) { Expr body) {
auto n = std::make_shared<ComputeOpNode>(); auto n = std::make_shared<ComputeOpNode>();
n->name = name; n->name = name;
n->dim_var = dim_var; n->axis = axis;
n->body = body; n->body = body;
return Operation(n); return Operation(n);
} }
...@@ -54,7 +54,7 @@ Tensor Operation::output(size_t i) const { ...@@ -54,7 +54,7 @@ Tensor Operation::output(size_t i) const {
} }
Array<IterVar> ComputeOpNode::root_iter_vars() const { Array<IterVar> ComputeOpNode::root_iter_vars() const {
return dim_var; return axis;
} }
std::string ComputeOpNode::output_name(size_t i) const { std::string ComputeOpNode::output_name(size_t i) const {
...@@ -70,8 +70,8 @@ Type ComputeOpNode::output_dtype(size_t i) const { ...@@ -70,8 +70,8 @@ Type ComputeOpNode::output_dtype(size_t i) const {
Array<Expr> ComputeOpNode::output_shape(size_t i) const { Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U); CHECK_EQ(i, 0U);
std::vector<Expr> shape; std::vector<Expr> shape;
for (size_t i = 0; i < dim_var.size(); ++i) { for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = dim_var[i]->dom; const Range& r = axis[i]->dom;
shape.push_back(r->extent); shape.push_back(r->extent);
} }
return Array<Expr>(shape); return Array<Expr>(shape);
......
...@@ -30,7 +30,15 @@ def test_attr(): ...@@ -30,7 +30,15 @@ def test_attr():
stmt = tvm.make.AttrStmt( stmt = tvm.make.AttrStmt(
y, "stride", 10, tvm.make.Evaluate(x + 1)); y, "stride", 10, tvm.make.Evaluate(x + 1));
assert stmt.node == y assert stmt.node == y
print(stmt)
a = tvm.convert(1)
assert a.value == 1
try:
a.no_field
assert False
except AttributeError:
pass
def test_basic(): def test_basic():
a = tvm.Var('a') a = tvm.Var('a')
...@@ -48,7 +56,6 @@ def test_stmt(): ...@@ -48,7 +56,6 @@ def test_stmt():
if __name__ == "__main__": if __name__ == "__main__":
test_attr() test_attr()
test_const() test_const()
test_make() test_make()
test_ir() test_ir()
......
...@@ -8,11 +8,11 @@ def test_bound1(): ...@@ -8,11 +8,11 @@ def test_bound1():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op) sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op) sA2 = tvm.Schedule(A2.op)
xo, xi = sA2.split(A2.op.dim_var[0], 8) xo, xi = sA2.split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo) sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2) bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.dim_var[0]].extent.value == 8) assert(bounds[A1.op.axis[0]].extent.value == 8)
def test_bound2(): def test_bound2():
m = tvm.Var('m') m = tvm.Var('m')
...@@ -22,12 +22,12 @@ def test_bound2(): ...@@ -22,12 +22,12 @@ def test_bound2():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op) sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op) sA2 = tvm.Schedule(A2.op)
xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8) xo, yo, xi, yi = sA2.tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
sA1.compute_at(sA2, yo) sA1.compute_at(sA2, yo)
bounds = tvm.schedule.InferBound(sA2) bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.dim_var[0]].extent.value == 8) assert(bounds[A1.op.axis[0]].extent.value == 8)
assert(bounds[A1.op.dim_var[1]].extent.value == 8) assert(bounds[A1.op.axis[1]].extent.value == 8)
def test_bound3(): def test_bound3():
m = tvm.Var('m') m = tvm.Var('m')
...@@ -38,16 +38,16 @@ def test_bound3(): ...@@ -38,16 +38,16 @@ def test_bound3():
sA1 = tvm.Schedule(A1.op, scope="shared") sA1 = tvm.Schedule(A1.op, scope="shared")
sA2 = tvm.Schedule(A2.op) sA2 = tvm.Schedule(A2.op)
thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x") thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
xo, xi = sA2.split(A2.op.dim_var[0], 32) xo, xi = sA2.split(A2.op.axis[0], 32)
xi0, xi1 = sA2.split(xi, outer=thread_x) xi0, xi1 = sA2.split(xi, outer=thread_x)
yo, yi = sA2.split(A2.op.dim_var[1], 16) yo, yi = sA2.split(A2.op.axis[1], 16)
sA2.reorder(xo, xi0, yo, xi1, yi) sA2.reorder(xo, xi0, yo, xi1, yi)
sA1.compute_at(sA2, yo) sA1.compute_at(sA2, yo)
bounds = tvm.schedule.InferBound(sA2) bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map) assert isinstance(bounds, tvm.collections.Map)
assert(bounds[A1.op.dim_var[0]].extent.value==32) assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.dim_var[1]].extent.value==16) assert(bounds[A1.op.axis[1]].extent.value==16)
def test_create_read_graph(): def test_create_read_graph():
......
...@@ -3,11 +3,10 @@ import tvm ...@@ -3,11 +3,10 @@ import tvm
def test_inline(): def test_inline():
m = tvm.Var('m') m = tvm.Var('m')
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A(i) + 10, name='T') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
X = T(100) stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
stmt = tvm.ir_pass.Inline( stmt = tvm.ir_pass.Inline(
T, T.op.dim_var, T.op.body, stmt) T, [x.var for x in T.op.axis], T.op.body, stmt)
print(stmt) print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt)) assert(tvm.ir_pass.VerifySSA(stmt))
......
...@@ -12,14 +12,14 @@ def test_schedule_create(): ...@@ -12,14 +12,14 @@ def test_schedule_create():
sch_T = tvm.Schedule(T.op, scope="shared") sch_T = tvm.Schedule(T.op, scope="shared")
sch_A = tvm.Schedule(AA.op, scope="global") sch_A = tvm.Schedule(AA.op, scope="global")
xo, xi = sch_T.split(T.op.dim_var[0], factor=10) xo, xi = sch_T.split(T.op.axis[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2) xi1, xi2 = sch_T.split(xi, factor=2)
sch_A.compute_at(sch_T, xi1) sch_A.compute_at(sch_T, xi1)
xo, xi = sch_A.split(AA.op.dim_var[0], factor=10) xo, xi = sch_A.split(AA.op.axis[0], factor=10)
sch_T.reorder(xi2, xi1) sch_T.reorder(xi2, xi1)
assert T.op.dim_var[1] in sch_T.leaf_iter_vars assert T.op.axis[1] in sch_T.leaf_iter_vars
def test_reorder(): def test_reorder():
m = tvm.Var('m') m = tvm.Var('m')
...@@ -27,7 +27,7 @@ def test_reorder(): ...@@ -27,7 +27,7 @@ def test_reorder():
T = tvm.compute(m, lambda i: A[i+1]) T = tvm.compute(m, lambda i: A[i+1])
sch_T = tvm.Schedule(T.op, scope="shared") sch_T = tvm.Schedule(T.op, scope="shared")
xo, xi = sch_T.split(T.op.dim_var[0], factor=10) xo, xi = sch_T.split(T.op.axis[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2) xi1, xi2 = sch_T.split(xi, factor=2)
order = (xi2, xi1, xo) order = (xi2, xi1, xo)
assert tuple(sch_T.leaf_iter_vars) != order assert tuple(sch_T.leaf_iter_vars) != order
...@@ -40,7 +40,7 @@ def test_split(): ...@@ -40,7 +40,7 @@ def test_split():
T = tvm.compute((m,), lambda i: A[i]) T = tvm.compute((m,), lambda i: A[i])
sT = tvm.Schedule(T.op) sT = tvm.Schedule(T.op)
xo, xi = sT.split(T.op.dim_var[0], factor=10) xo, xi = sT.split(T.op.axis[0], factor=10)
assert tuple(sT.leaf_iter_vars) == (xo, xi) assert tuple(sT.leaf_iter_vars) == (xo, xi)
...@@ -51,7 +51,7 @@ def test_tile(): ...@@ -51,7 +51,7 @@ def test_tile():
T = tvm.compute((m, n), lambda i, j: A[i, j]) T = tvm.compute((m, n), lambda i, j: A[i, j])
sch_T = tvm.Schedule(T.op, scope="shared") sch_T = tvm.Schedule(T.op, scope="shared")
xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5) xo, yo, xi, yi = sch_T.tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi) assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,7 +10,7 @@ def test_tensor(): ...@@ -10,7 +10,7 @@ def test_tensor():
print(T) print(T)
print(T.op.body) print(T.op.body)
assert(tuple(T.shape) == (m, n, l)) assert(tuple(T.shape) == (m, n, l))
assert(A.source is None) assert(A.op is None)
def test_tensor_reduce(): def test_tensor_reduce():
m = tvm.Var('m') m = tvm.Var('m')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment