Commit 93057c8a by Tianqi Chen, committed by GitHub

[BUILD] add with_api_wrapper to lower (#95)

parent 3b8ad0a2
@@ -12,10 +12,12 @@ from . import ir_pass
 from . import collections
 from . import codegen


 def lower(sch,
           args,
           name="default_function",
           binds=None,
+          with_api_wrapper=True,
           max_auto_unroll_step=8):
     """Lowering step before build into target.
@@ -34,13 +36,17 @@ def lower(sch,
         Dictionary that maps the binding of symbolic buffer to Tensor.
         By default, a new buffer is created for each tensor in the argument.

+    with_api_wrapper : bool, optional
+        Whether to add the API wrapper during lowering.
+
     max_auto_unroll_step: int, optional
         Maximum step to perform automatic unrolling

     Returns
     -------
-    f : LoweredFunc
-        The result function.
+    f : LoweredFunc or Stmt
+        The result function. If with_api_wrapper=False,
+        the Stmt before MakeAPI is returned instead.
     """
     binds = {} if binds is None else binds.copy()
     arg_list = []
@@ -67,8 +73,9 @@ def lower(sch,
     stmt = ir_pass.LiftAllocate(stmt)
     stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
     stmt = ir_pass.Simplify(stmt)
-    fapi = ir_pass.MakeAPI(stmt, name, arg_list, 0)
-    return fapi
+    if not with_api_wrapper:
+        return stmt
+    return ir_pass.MakeAPI(stmt, name, arg_list, 0)


 def build(sch,
...
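For context, here is a minimal usage sketch of the new flag. It is not part of the commit, and it assumes the Python API of this era (tvm.var, tvm.placeholder, tvm.compute, tvm.Schedule) plus a top-level tvm.lower export; adjust the names to your checkout.

import tvm

# Hypothetical example; identifiers such as tvm.Schedule are assumptions
# about the contemporary API, not taken from this commit.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name="B")
s = tvm.Schedule(B.op)

# Default path: the lowered Stmt is wrapped by ir_pass.MakeAPI into a
# LoweredFunc that follows the packed-function calling convention.
f = tvm.lower(s, [A, B], name="addone")

# New path: with_api_wrapper=False skips MakeAPI and returns the raw
# Stmt, which is convenient for inspecting the IR after Simplify.
stmt = tvm.lower(s, [A, B], name="addone", with_api_wrapper=False)
print(stmt)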
@@ -170,7 +170,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
   if (idx < leaf_vars->data.size()) {
     // insert rebase
     IterVar rebased = IterVarNode::make(
-        Range(), iv->var.copy_with_suffix(".rb"), iv->iter_type);
+        Range(), iv->var.copy_with_suffix(""), iv->iter_type);
     s->relations.push_back(RebaseNode::make(iv, rebased));
     leaf_vars->data[idx] = rebased.node_;
     rebase_map[iv] = rebased;
...
@@ -162,7 +162,7 @@ class SchedulePostProc : public IRMutator {
     // delete duplicated thread extent attr
     auto it = thread_extent_scope_.find(op->node.get());
     if (it != thread_extent_scope_.end()) {
-      CHECK(is_zero(ir::Simplify(it->second- op->value)));
+      CHECK(is_zero(ir::Simplify(it->second - op->value)));
       return this->Mutate(op->body);
     } else {
       thread_extent_scope_[op->node.get()] = op->value;
...
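The check in this hunk treats two thread extents as duplicates when their difference simplifies to zero. Below is a sketch of the same idea from Python; it assumes that ir_pass.Simplify (imported by the first file in this commit) also accepts expressions and that tvm.var exists at this commit.

import tvm
from tvm import ir_pass

# Sketch: mirror the C++ test is_zero(ir::Simplify(a - b)) by checking
# that the simplified difference of two extents is the constant 0.
n = tvm.var("n")
a = n * 2
b = n + n
print(ir_pass.Simplify(a - b))  # expected to print 0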
@@ -17,8 +17,7 @@ import numpy as np
 # Vector Add Example
 # ------------------
 # In this tutorial, we will use a vector addition example to demonstrate
-# the workflow in TVM. We will demonstrate how we can describe and compile
-# vector addition code that runs on GPU.
+# the workflow.
 #
 ######################################################################
...