Commit 54c18a6c by tqchen

Add map container and tests

parent cd307dda
Subproject commit ea1a81be8baa43665f6ebd4d75d51c081283ebc8
Subproject commit adaea9e85bc0a213d4eb63edfa4762f2147c73ec
Subproject commit 39007ac49b6087339dc3104324cb4e0de47f1c5f
Subproject commit f294fc2271b27b0b6e2b117003ed2dc3d3ba8fda
......@@ -6,7 +6,7 @@
#ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_
#include <tvm/array.h>
#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>
#include <vector>
......
......@@ -115,6 +115,14 @@ def convert(value):
elif isinstance(value, (list, tuple)):
value = [convert(x) for x in value]
return _function_internal._Array(*value)
elif isinstance(value, dict):
vlist = []
for it in value.items():
if not isinstance(it[0], NodeBase):
raise ValueError("key of map must already been a container type")
vlist.append(it[0])
vlist.append(convert(it[1]))
return _function_internal._Map(*vlist)
elif isinstance(value, SliceBase):
return value.tensor(*value.indices)
else:
......
......@@ -17,6 +17,24 @@ class Array(NodeBase):
def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']'
@register_node
class Map(NodeBase):
def __getitem__(self, k):
return _function_internal._MapGetItem(self, k)
def __contains__(self, k):
return _function_internal._MapCount(self, k) != 0
def items(self):
akvs = _function_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
def __len__(self):
return _function_internal._MapSize(self)
def __repr__(self):
return '{' + (", ".join(str(x[0]) + ": " +str(x[1]) for x in self.items())) + '}'
@register_node
class Range(NodeBase):
......
......@@ -45,7 +45,7 @@ struct APIAttrGetter : public AttrVisitor {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, uint64_t* value) final {
CHECK_LE(value[0], std::numeric_limits<int64_t>::max())
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
<< "cannot return too big constant";
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
......
......@@ -63,6 +63,71 @@ TVM_REGISTER_API(_ArraySize)
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_Map)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK_EQ(args.size() % 2, 0U);
MapNode::ContainerType data;
for (size_t i = 0; i < args.size(); i += 2) {
CHECK(args.at(i).type_id == kNodeHandle)
<< "need content of array to be NodeBase";
CHECK(args.at(i + 1).type_id == kNodeHandle)
<< "need content of array to be NodeBase";
data.emplace(std::make_pair(args.at(i).sptr, args.at(i + 1).sptr));
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_MapSize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
});
TVM_REGISTER_API(_MapGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto it = n->data.find(args.at(1).sptr);
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
ret->sptr = (*it).second;
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_MapCount)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
CHECK(args.at(1).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.count(args.at(1).sptr));
});
TVM_REGISTER_API(_MapItems)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
ret->sptr = rkvs;
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(Range)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 1) {
......
......@@ -53,17 +53,17 @@ Array<IterVar> ComputeOpNode::root_iter_vars() const {
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0);
CHECK_EQ(i, 0U);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0);
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0);
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < dim_var.size(); ++i) {
const Range& r = dim_var[i]->dom;
......
......@@ -4,6 +4,7 @@
* \brief The bound inference logic.
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include "./int_set.h"
#include "./bound.h"
......@@ -22,8 +23,7 @@ void PassDown(const Schedule& s,
std::unordered_map<IterVar, Range>* p_state) {
auto& state = *p_state;
// forwar iteration on relations
for (size_t i = 0; i < s->relations.size(); ++i) {
IterVarRelation rel = s->relations[i];
for (IterVarRelation rel : s->relations) {
if (rel.as<SplitNode>()) {
const SplitNode* r = rel.as<SplitNode>();
CHECK(state.count(r->parent));
......@@ -89,6 +89,59 @@ void PassUp(const Schedule& s,
}
}
void PassBound(
const Tensor& tensor,
const std::vector<IntSet>& arg_bounds,
std::unordered_map<IterVar, std::vector<IntSet> >* result) {
if (tensor->op.as<ComputeOpNode>()) {
auto root_iter_vars = tensor->op->root_iter_vars();
CHECK_EQ(tensor.ndim(), root_iter_vars.size());
for (size_t i = 0; i < tensor.ndim(); ++i) {
(*result)[root_iter_vars[i]].push_back(arg_bounds[i]);
}
} else {
LOG(FATAL) << "unknown operation mode";
}
}
void PassBound(
Operation op,
std::unordered_map<IterVar, IntSet>* ebound) {
if (op.as<ComputeOpNode>()) {
auto fvisit = [ebound](const NodeRef& n) {
auto *call = n.as<ir::Call>();
if (call != nullptr && call->func.defined()) {
Tensor t(call->func.node_);
std::vector<IntSet> arg_bounds;
for (size_t i = 0; i < t.ndim(); ++i) {
arg_bounds.push_back(Eval(call->args[i], *ebound));
}
}
};
ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
} else {
LOG(FATAL) << "unknown operation mode";
}
}
void InferBound(const Schedule& sch,
std::unordered_map<IterVar, Range>* rmap) {
CHECK_NE(sch->attach_type, kNone);
if (sch->attach_type == kInline) return;
if (sch->attach_type == kRoot) {
auto root_iter_vars = sch->op->root_iter_vars();
for (size_t i = 0; i < root_iter_vars.size(); ++i) {
auto v = root_iter_vars[i];
CHECK(v->dom.defined());
CHECK(!rmap->count(v));
(*rmap)[v] = v->dom;
}
}
// get range of all child iter vars.
PassDown(sch, rmap);
// pass iteration variable to children
}
std::unordered_map<IterVar, Range> InferBound(Schedule sch) {
return {};
......
......@@ -122,7 +122,7 @@ inline Range Negation(Range a) {
}
inline IntSet Negation(IntSet a) {
CHECK_EQ(a->concrete.size(), 0);
CHECK_EQ(a->concrete.size(), 0U);
auto n = std::make_shared<IntSetNode>();
n->base = Negation(a->base);
for (size_t i = 0; i < a->domain.size(); ++i) {
......@@ -182,6 +182,11 @@ IntSet IntSet::make(Range dom) {
return IntSet(n);
}
IntSet IntSet::make_all_set() {
LOG(FATAL) << "TODO";
return IntSet();
}
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
......
......@@ -18,10 +18,44 @@ TEST(Array, Mutate) {
auto z = max(x + 1 + 2, 100);
Array<Expr> list{x, z, z};
auto list2 = list;
list[1] = x;
list.Set(1, x);
CHECK(list[1].same_as(x));
CHECK(list2[1].same_as(z));
}
LOG(INFO) << list[1];
LOG(INFO) << list2[1];
TEST(Map, Expr) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
auto zz = z + 1;
Map<Expr, Expr> dict{{x, z}, {z, 2}};
CHECK(dict.size() == 2);
CHECK(dict[x].same_as(z));
CHECK(dict.count(z));
CHECK(!dict.count(zz));
}
TEST(Map, Mutate) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Map<Expr, Expr> dict{{x, z}, {z, 2}};
auto zz = z + 1;
CHECK(dict[x].same_as(z));
dict.Set(x, zz);
auto dict2 = dict;
CHECK(dict2.count(z) == 1);
dict.Set(zz, x);
CHECK(dict2.count(zz) == 0);
CHECK(dict.count(zz) == 1);
auto it = dict.find(zz);
CHECK(it != dict.end() && (*it).second.same_as(x));
it = dict2.find(zz);
CHECK(it == dict.end());
LOG(INFO) << dict;
}
int main(int argc, char ** argv) {
......
......@@ -38,8 +38,6 @@ def test_basic():
c = a + b
assert str(c) == '(%s + %s)' % (a.name, b.name)
def test_array():
a = tvm.convert([1,2,3])
def test_stmt():
x = tvm.make.Evaluate(0)
......
import tvm
def test_array():
a = tvm.convert([1,2,3])
assert len(a) == 3
def test_map():
a = tvm.Var('a')
b = tvm.Var('b')
amap = tvm.convert({a: 2,
b: 3})
assert a in amap
assert len(amap) == 2
dd = dict(amap.items())
assert str(dd) == str(amap)
assert a + 1 not in amap
if __name__ == "__main__":
test_array()
test_map()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment