Commit b0b16a07 by Zhao Wu Committed by Tianqi Chen

[CodeGen] Add build config option disable_assert to control whether to generate assert (#4340)

parent a8e6ee9b
...@@ -229,6 +229,9 @@ class BuildConfigNode : public Node { ...@@ -229,6 +229,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable loop vectorization. */ /*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false; bool disable_vectorize = false;
/*! \brief Whether to disable assert stmt generation. */
bool disable_assert = false;
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment); v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor); v->Visit("offset_factor", &offset_factor);
...@@ -244,6 +247,7 @@ class BuildConfigNode : public Node { ...@@ -244,6 +247,7 @@ class BuildConfigNode : public Node {
v->Visit("instrument_bound_checkers", &instrument_bound_checkers); v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting); v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize); v->Visit("disable_vectorize", &disable_vectorize);
v->Visit("disable_assert", &disable_assert);
} }
static constexpr const char* _type_key = "BuildConfig"; static constexpr const char* _type_key = "BuildConfig";
......
...@@ -564,6 +564,13 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); ...@@ -564,6 +564,13 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
LoweredFunc InferFragment(LoweredFunc f); LoweredFunc InferFragment(LoweredFunc f);
/*! /*!
* \brief skip assert stmt generation
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc SkipAssert(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type. * \brief Verify if memory accesses are legal for a specific target device type.
* *
* In the case that tgt is cuda, if not all workload is bound with * In the case that tgt is cuda, if not all workload is bound with
......
...@@ -144,7 +144,8 @@ class BuildConfig(NodeBase): ...@@ -144,7 +144,8 @@ class BuildConfig(NodeBase):
"dump_pass_ir": False, "dump_pass_ir": False,
"instrument_bound_checkers": False, "instrument_bound_checkers": False,
"disable_select_rewriting": False, "disable_select_rewriting": False,
"disable_vectorize": False "disable_vectorize": False,
"disable_assert": False
} }
_dump_ir = DumpIR() _dump_ir = DumpIR()
......
...@@ -672,6 +672,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -672,6 +672,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize; p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << "disable_assert=" << op->disable_assert;
p->stream << ")"; p->stream << ")";
}); });
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/build_module.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <sstream> #include <sstream>
#include <iostream> #include <iostream>
...@@ -40,12 +41,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, ...@@ -40,12 +41,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
if (pos != std::string::npos) { if (pos != std::string::npos) {
mode = mode.substr(0, pos); mode = mode.substr(0, pos);
} }
Array<LoweredFunc> transformed_funcs;
for (const auto& x : funcs) {
if (BuildConfig::Current()->disable_assert) {
auto func = ir::SkipAssert(x);
transformed_funcs.push_back(func);
}
}
std::string build_f_name = "codegen.build_" + mode; std::string build_f_name = "codegen.build_" + mode;
// the build function. // the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name); const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr) CHECK(bf != nullptr)
<< "Target " << target << " is not enabled"; << "Target " << target << " is not enabled";
runtime::Module m = (*bf)(funcs, target); runtime::Module m = transformed_funcs.empty() ?
(*bf)(funcs, target) :
(*bf)(transformed_funcs, target);
return m; return m;
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
namespace tvm {
namespace ir {
class AssertSkipper : public IRMutator {
public:
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AssertStmt>();
return op->body;
}
};
Stmt SkipAssert(Stmt stmt) {
return AssertSkipper().Mutate(stmt);
}
LoweredFunc SkipAssert(LoweredFunc f) {
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = SkipAssert(f->body);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment