Commit dd248af6 by Ding Committed by Tianqi Chen

[LANGUAGE] Verify Compute with respect to Reduce operations (#1006)

parent 6bcf95f2
......@@ -24,6 +24,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode *op);
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
return (a->combiner.same_as(b->combiner)) &&
(a->source.same_as(b->source)) &&
......@@ -116,15 +119,9 @@ Operation ComputeOpNode::make(std::string name,
n->body = body;
if (n->body[0]->is_type<ir::Reduce>()) {
const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
for (size_t i = 1; i < n->body.size(); ++i) {
const ir::Reduce* reduce_ = n->body[i].as<ir::Reduce>();
CHECK(reduce_);
CHECK(ReduceEqual(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
n->reduce_axis = reduce->axis;
}
VerifyComputeOp(n.get());
return Operation(n);
}
......@@ -151,18 +148,11 @@ Operation ComputeOpNode::ReplaceInputs(
const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
VerifyComputeOp(this);
Array<Expr> arr;
if (this->body[0]->is_type<ir::Reduce>()) {
// Specially handle reduce so the replaced op
// still share all the components
const ir::Reduce* reduce = this->body[0].as<ir::Reduce>();
for (size_t i = 1; i < this->body.size(); ++i) {
const ir::Reduce* reduce_ = this->body[i].as<ir::Reduce>();
CHECK(reduce_);
CHECK(ReduceEqual(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}\
Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
if (!new_reduce.same_as(this->body[0])) {
const ir::Reduce* r = new_reduce.as<ir::Reduce>();
......@@ -466,4 +456,78 @@ ComputeLoopNest ComputeLoopNest::make(
// copy elison here.
return ret;
}
namespace {
/*!
* \brief Verify if ComputeOp is valid with respect to Reduce operations.
*
* The following two properties are verified:
* (1) All Reduce operations must exist at top level.
* (2) For a list of operations, if one is Reduce, then the others
* must be Reduce as well; and their inputs should have the
* same attribute except value_index.
*/
class ComputeVerifier final : protected ir::IRVisitor {
public:
/// Special member functions
//@{
explicit ComputeVerifier(const ComputeOpNode* compute)
: compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
virtual ~ComputeVerifier() = default;
ComputeVerifier(const ComputeVerifier&) = delete;
ComputeVerifier(ComputeVerifier&&) = delete;
ComputeVerifier& operator=(const ComputeVerifier&) = delete;
ComputeVerifier& operator=(ComputeVerifier&&) = delete;
//@}
/// Interface to perform compute verification
void Run() {
for (const Expr e : compute_->body) {
// Check for consistency of top level reductions
const ir::Reduce* reduce = e.as<ir::Reduce>();
CHECK((reduce && reduce_) || (!reduce && !reduce_))
<< "All ComputeOp should be consistent "
<< "with being Reduce operation or not.";
if (reduce && reduce_) {
CHECK(ReduceEqual(reduce, reduce_))
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
level_ = 0;
ir::IRVisitor::Visit(e);
}
}
protected:
/// Visitor implementation
//@{
void Visit(const NodeRef& n) final {
++level_;
ir::IRVisitor::Visit(n);
--level_;
}
void Visit_(const ir::Reduce* op) final {
// Check for non top level reductions
CHECK(0 == level_)
<< "Reductions are only allowed at the top level of compute. "
<< "Please create another tensor for further composition.";
}
//@}
private:
const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
const ir::Reduce* reduce_{nullptr}; ///< Top level Reduce operation
int level_{0}; ///< Level of op being processed
};
} // namespace
/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op) {
ComputeVerifier v(op);
v.Run();
}
} // namespace tvm
import tvm
def test_verify_compute():
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), "k")
k_ = tvm.reduce_axis((0, m-1), "k_")
f1 = lambda i: tvm.sum(A[i, k], axis=k)
f2 = lambda i: A[i,0] + 1
f3 = lambda i: tvm.sum(A[i, k], axis=k) + 1
f4 = lambda i: A[i,0] * (tvm.sum(A[i, k], axis=k) + 1)
f5 = lambda i: (tvm.sum(A[i, k], axis=k), A[i,0] + 1)
f6 = lambda i: (tvm.sum(A[i, k], axis=k), tvm.sum(A[i, k_], axis=k_))
#
# Valid compute
try:
B = tvm.compute((n,), f1, name="B")
except tvm._ffi.base.TVMError as ex:
assert False
#
# Valid compute
try:
B = tvm.compute((n,), f2, name="B")
except tvm._ffi.base.TVMError as ex:
assert False
#
# Invalid compute with non top level reduction
try:
B = tvm.compute((n,), f3, name="B")
assert False
except tvm._ffi.base.TVMError as ex:
pass
#
# Invalid compute with non top level reduction
try:
B = tvm.compute((n,), f4, name="B")
assert False
except tvm._ffi.base.TVMError as ex:
pass
#
# Invalid compute with reduction and non-reduction batch ops
try:
B0, B1 = tvm.compute((n,), f5, name="B")
assert False
except tvm._ffi.base.TVMError as ex:
pass
#
# Invalid compute with unequal batch reduction ops
try:
B0, B1 = tvm.compute((n,), f6, name="B")
assert False
except tvm._ffi.base.TVMError as ex:
pass
if __name__ == "__main__":
test_verify_compute()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment