Commit b0b16a07 by Zhao Wu Committed by Tianqi Chen

[CodeGen] Add build config option disable_assert to control whether to generate assert (#4340)

parent a8e6ee9b
......@@ -229,6 +229,9 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;
/*! \brief Whether to disable assert stmt generation. */
bool disable_assert = false;
void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
......@@ -244,6 +247,7 @@ class BuildConfigNode : public Node {
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
v->Visit("disable_assert", &disable_assert);
}
static constexpr const char* _type_key = "BuildConfig";
......
......@@ -564,6 +564,13 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
LoweredFunc InferFragment(LoweredFunc f);
/*!
* \brief skip assert stmt generation
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc SkipAssert(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
* In the case that tgt is cuda, if not all workload is bound with
......
......@@ -144,7 +144,8 @@ class BuildConfig(NodeBase):
"dump_pass_ir": False,
"instrument_bound_checkers": False,
"disable_select_rewriting": False,
"disable_vectorize": False
"disable_vectorize": False,
"disable_assert": False
}
_dump_ir = DumpIR()
......
......@@ -672,6 +672,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << "disable_assert=" << op->disable_assert;
p->stream << ")";
});
......
......@@ -26,6 +26,7 @@
#include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/build_module.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <iostream>
......@@ -40,12 +41,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
if (pos != std::string::npos) {
mode = mode.substr(0, pos);
}
Array<LoweredFunc> transformed_funcs;
for (const auto& x : funcs) {
if (BuildConfig::Current()->disable_assert) {
auto func = ir::SkipAssert(x);
transformed_funcs.push_back(func);
}
}
std::string build_f_name = "codegen.build_" + mode;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "Target " << target << " is not enabled";
runtime::Module m = (*bf)(funcs, target);
runtime::Module m = transformed_funcs.empty() ?
(*bf)(funcs, target) :
(*bf)(transformed_funcs, target);
return m;
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
namespace tvm {
namespace ir {
class AssertSkipper : public IRMutator {
public:
Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<AssertStmt>();
return op->body;
}
};
Stmt SkipAssert(Stmt stmt) {
return AssertSkipper().Mutate(stmt);
}
LoweredFunc SkipAssert(LoweredFunc f) {
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = SkipAssert(f->body);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment