Unverified commit 623dd208 by Haichen Shen, committed by GitHub

[Relay][AutoTVM] Relay op strategy (#4644)

* relay op strategy

fix lint

bitpack strategy

bitserial_dense (#6)

* update strategy

* address comments

fix a few topi tests

Dense strategy (#5)

* dense

* add bifrost; remove comments

* address comment

Refactor x86 conv2d_NCHWc (#4)

* Refactor x86 conv2d

* Add x86 depthwise_conv2d_NCHWc

* Add back topi x86 conv2d_nchw

* Merge x86 conv2d_nchw and conv2d_NCHWc

* Minor fix for x86 conv2d

fix more strategy

Add x86 conv2d_NCHWc_int8 strategy (#8)

* Add x86 conv2d_NCHWc_int8 strategy

* Remove contrib_conv2d_nchwc_int8

* Fix generic conv2d_NCHWc for int8

* Fix topi arm_cpu conv2d_NCHWc_int8

update x86 conv2d

allow specifying which relay ops to tune with autotvm

add cuda conv2d strategy

add conv2d strategy for rocm

add conv2d strategy for hls

add conv2d strategy for arm cpu

add conv2d strategy for mali

add conv2d strategy for bifrost

add conv2d strategy for intel graphics

clean up and fix lint

remove template keys from autotvm

remove 2 in the func name

address comments

fix

* fix bugs

* lint

* address comments

* add name to op implementation

* Modify topi tests (#9)

* Add pooling, reorg, softmax and vision

* Add lrn

* fix topi tests

* fix more topi tests

* lint

* address comments

* x

* fix more tests & bugs

* Modify more tests (#10)

* Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn

* Minor fix

* More minor fix

* fix more tests

* try to update vta using strategy

* fix cpptest

* x

* fix rebase err

* Fix two tests (#11)

* change autotvm log format

* lint

* minor fix

* try fix vta test

* fix rebase err

* tweak

* tmp hack for vta pass

* fix tutorial

* fix

* fix more tutorials

* fix vta tutorial

* minor

* address comments

* fix

* address comments

* fix cpptest

* fix docs

* change data structure name and api

* address comments

* lint

* fix rebase err

* updates

* fix winograd test

* fix doc

* rebase

* upgrade tophub version number

* fix bug

* re-enable vta tsim test after tophub is upgraded

* fix vta test to use the correct args so the config can be found in tophub

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
parent c4c61cb7
@@ -29,6 +29,7 @@
 #include <tvm/relay/type.h>
 #include <tvm/relay/expr.h>
 #include <tvm/target/target.h>
+#include <tvm/target/generic_func.h>
 #include <tvm/tir/data_layout.h>
 #include <string>
@@ -106,8 +107,7 @@ using TShapeDataDependant = bool;
 using FTVMCompute = runtime::TypedPackedFunc<
   Array<te::Tensor>(const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target)>;
+                    const Type& out_type)>;
 /*!
  * \brief Build the computation schedule for
@@ -124,6 +124,16 @@ using FTVMSchedule = runtime::TypedPackedFunc<
               const Target& target)>;
 /*!
+ * \brief Generate the strategy of operators. This function is a generic
+ * function and can be re-defined for different targets.
+ *
+ * The function signature of generic function is:
+ *   OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
+ *              const Type& out_type, const Target& target)
+ */
+using FTVMStrategy = GenericFunc;
+
+/*!
  * \brief Alternate the layout of operators or replace the
  * operator with other expressions. This function will be invoked
  * in AlterOpLayout pass.
@@ -136,7 +146,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
 using FTVMAlterOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<te::Tensor>& tinfos)>;
+       const Array<te::Tensor>& tinfos,
+       const Type& out_type)>;
 /*!
  * \brief Convert the layout of operators or replace the
@@ -191,9 +202,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
  * \brief Gradient for a specific op.
  *
  * \param orig_call the original Expr.
- *
  * \param output_grad the gradient of the Expr.
- *
  * \return the gradient for each parameters.
  */
 using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
@@ -207,7 +216,7 @@ enum AnyCodegenStrategy {
   kVariableDimensions
 };
-/* \brief A runtime representation of shape. */
+/*! \brief A runtime representation of shape. */
 using Shape = Array<IndexExpr>;
 using FShapeFunc = runtime::TypedPackedFunc<
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/op_strategy.h
* \brief The Relay operator Strategy and related data structure.
*/
#ifndef TVM_RELAY_OP_STRATEGY_H_
#define TVM_RELAY_OP_STRATEGY_H_
#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/target.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Operator implementation that includes compute and schedule function.
*/
class OpImplementationNode : public Object {
public:
/*! \brief Compute function */
FTVMCompute fcompute;
/*! \brief Schedule function */
FTVMSchedule fschedule;
/*! \brief Name of the implementation */
std::string name;
/*! \brief Priority level */
int plevel;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("plevel", &plevel);
}
static constexpr const char* _type_key = "relay.OpImplementation";
TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
};
/*!
* \brief Operator implementation class.
*/
class OpImplementation : public ObjectRef {
public:
/*!
* \brief Invoke the operator compute function.
* \param attrs The attribute of the primitive
* \param inputs The input tensors.
* \param out_type The output type information.
* \return The output compute description of the operator.
*/
TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type);
/*!
* \brief Build the computation schedule.
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return The computation schedule.
*/
TVM_DLL te::Schedule Schedule(const Attrs& attrs,
const Array<te::Tensor>& outs,
const Target& target);
TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
};
/*!
* \brief Specialized implementations for operators under certain conditions.
*/
class OpSpecializationNode : public Object {
public:
/*! \brief List of implementations. */
Array<OpImplementation> implementations;
/*! \brief Condition to enable the specialization.
* Could be undefined to represent generic case. */
te::SpecializedCondition condition;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("implementations", &implementations);
}
static constexpr const char* _type_key = "relay.OpSpecialization";
TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, Object);
};
/*!
* \brief Operator specialization class.
*/
class OpSpecialization : public ObjectRef {
public:
/*!
* \brief Add an implementation.
* \param fcompute Compute function
* \param fschedule Schedule function
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
std::string name, int plevel);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
};
/*!
* \brief Operator strategy to choose implementation.
*/
class OpStrategyNode : public Object {
public:
/*! \brief List of operator specializations. */
Array<OpSpecialization> specializations;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("specializations", &specializations);
}
static constexpr const char* _type_key = "relay.OpStrategy";
TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, Object);
};
/*!
* \brief Operator strategy class.
*/
class OpStrategy : public ObjectRef {
public:
/*!
* \brief Add an implementation.
* \param fcompute Compute function
* \param fschedule Schedule function
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
std::string name, int plevel);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_STRATEGY_H_
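To make the classes above concrete, here is a minimal Python-side sketch of how a target-specific strategy can be assembled. The wrapper helpers and TOPI function names follow the pattern this PR introduces under python/tvm/relay/op/strategy, but are illustrative assumptions rather than part of this header:

# Minimal sketch: build an OpStrategy for nn.dense on CPU. Wrapper helper and
# TOPI function names below are illustrative assumptions.
import topi
from tvm.relay.op import OpStrategy
from tvm.relay.op.strategy.generic import wrap_compute_dense, wrap_topi_schedule

def dense_strategy_cpu(attrs, inputs, out_type, target):
    """Matches the FTVMStrategy signature: (attrs, inputs, out_type, target)."""
    strategy = OpStrategy()
    strategy.add_implementation(
        wrap_compute_dense(topi.x86.dense_nopack),           # fcompute (FTVMCompute)
        wrap_topi_schedule(topi.x86.schedule_dense_nopack),  # fschedule (FTVMSchedule)
        name="dense_nopack.x86",
        plevel=10)
    return strategy

Multiple add_implementation calls with different plevel values let the compiler (or AutoTVM) pick among candidate implementations for the same op.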
@@ -28,6 +28,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/te/tensor.h>
 #include <tvm/te/tensor_intrin.h>
+#include <tvm/support/with.h>
 #include <string>
 #include <unordered_map>
@@ -742,6 +743,53 @@ class SingletonNode : public IterVarRelationNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
 };
+/*! \brief Container for specialization conditions. */
+class SpecializedConditionNode : public Object {
+ public:
+  /*!
+   * \brief List of conditions in conjunctive normal form (CNF).
+   * Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
+   * where n, m are tvm::Var that represent a dimension in the tensor shape.
+   */
+  Array<PrimExpr> clauses;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("clauses", &clauses);
+  }
+
+  static constexpr const char* _type_key = "SpecializedCondition";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
+};
+
+/*!
+ * \brief Specialized condition to enable op specialization.
+ */
+class SpecializedCondition : public ObjectRef {
+ public:
+  /*!
+   * \brief Construct from conditions.
+   * \param conditions The clauses in the specialized condition.
+   */
+  TVM_DLL SpecializedCondition(Array<PrimExpr> conditions);  // NOLINT(*)
+
+  /*!
+   * \brief Get the current specialized condition.
+   * \return The current specialized condition.
+   */
+  TVM_DLL static SpecializedCondition Current();
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
+  class Internal;
+
+ private:
+  // enable with syntax.
+  friend class Internal;
+  friend class With<SpecializedCondition>;
+  /*! \brief Push a new specialized condition onto the thread local stack. */
+  TVM_DLL void EnterWithScope();
+  /*! \brief Pop a specialized condition off the thread local context stack. */
+  TVM_DLL void ExitWithScope();
+};
 // implementations
 inline const StageNode* Stage::operator->() const {
@@ -765,6 +813,7 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
 inline const IterVarAttrNode* IterVarAttr::operator->() const {
   return static_cast<const IterVarAttrNode*>(get());
 }
 }  // namespace te
 }  // namespace tvm
 #endif  // TVM_TE_SCHEDULE_H_
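The With<SpecializedCondition> scope above is what lets a strategy attach shape-specialized implementations. A minimal sketch of the Python side, assuming the binding mirrors the C++ constructor (a list of CNF clauses) and TVM's usual with-scope protocol:

# Minimal sketch: enter a specialized-condition scope (assumed Python binding).
from tvm import te

m = te.var("m")  # a dimension in some tensor shape
with te.SpecializedCondition([m > 16]):
    # OpStrategy.add_implementation calls made inside this scope would be
    # recorded under an OpSpecialization whose condition is "m > 16".
    pass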
@@ -41,8 +41,8 @@ from . import tophub
 from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
     LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
-from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, \
+from .task import get_config, create, ConfigSpace, ConfigEntity, \
+    register_topi_compute, register_topi_schedule, register_customized_task, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE
@@ -125,7 +125,7 @@ class RedisDatabase(Database):
         current = self.get(measure_str_key(inp))
         if current is not None:
             records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
-            results = [rec[1] for rec in records]
+            results = [rec[1] for rec in records if rec is not None]
             if get_all:
                 return results
             return max(results, key=lambda result: result.timestamp)
@@ -167,9 +167,12 @@ class RedisDatabase(Database):
             current = self.get(key)
             try:
                 records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
+                records = [rec for rec in records if rec is not None]
             except TypeError:  # got a badly formatted/old format record
                 continue
+            if not records:
+                continue
             inps, results = zip(*records)
             inp = inps[0]
             if not func(inp, results):
......
@@ -153,7 +153,10 @@ def get_flatten_name(fea):
     from .record import decode
     # flatten line to feature
     line = fea
-    inp, _ = decode(line)
+    ret = decode(line)
+    if ret is None:
+        raise ValueError("Unsupported AutoTVM log format")
+    inp, _ = ret
     target = _target.create(inp.target)
     with target:
         s, args = inp.template.instantiate(inp.config)
......
@@ -25,7 +25,6 @@ import topi
 import tvm
 from tvm import autotvm, relay
 from tvm.autotvm.task import get_config
-from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args
 from tvm.autotvm.record import encode, load_from_file
 from tvm.autotvm.measure import MeasureResult, MeasureInput
@@ -35,18 +34,16 @@ from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_i
 from ._base import INVALID_LAYOUT_TIME
-# Setup topi_op_name -> layout function
-# NOTE: To add more ops, change the following dictionary.
-OP2LAYOUT = {
-    "topi_nn_conv2d": topi.nn.conv2d_infer_layout,
-    "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
-}
+def get_infer_layout(task_name):
+    if task_name.startswith("conv2d"):
+        return topi.nn.conv2d_infer_layout
+    if task_name.startswith("depthwise_conv2d"):
+        return topi.nn.depthwise_conv2d_infer_layout
+    raise ValueError("Cannot find infer layout for task %s" % task_name)
-@autotvm.template
+@autotvm.register_customized_task("layout_transform")
 def layout_transform(*args):
     """Autotvm layout transform template."""
-    args = deserialize_args(args)
     cfg = get_config()
     cfg.add_flop(-1)
     data = args[0]
@@ -82,7 +79,7 @@ class BaseGraphTuner(object):
             Each row of this file is an encoded record pair.
             Otherwise, it is an iterator.
-        target_ops : List of str
+        target_ops : List of relay.op.Op
             Target tuning operators.
         target : str or tvm.target
@@ -104,7 +101,7 @@ class BaseGraphTuner(object):
         self._layout_transform_perf_records = {}
         self._layout_transform_interlayer_cost = {}
         self._input_shapes = input_shapes
-        self._target_ops = [op.__name__ for op in target_ops]
+        self._target_ops = target_ops
         self._name = name
         self._max_sch_num = max_sch_num
@@ -179,7 +176,7 @@ class BaseGraphTuner(object):
                 dtype = first_tensor[-1]
                 new_shape = tuple([val.value for val in node_entry["types"][0].shape])
                 actual_workload = (input_workload[0],) + \
-                    ((new_shape + (dtype,)),) + input_workload[2:]
+                    (("TENSOR", new_shape, dtype),) + input_workload[2:]
                 node_entry["workloads"].append(actual_workload)
                 if "record_candidates" not in node_entry:
                     node_entry["record_candidates"] = input_node["record_candidates"]
@@ -212,7 +209,7 @@ class BaseGraphTuner(object):
                 node_entry["record_candidates"] = cache_dict[workload]
                 continue
             record_candidates = []
-            infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+            infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
             layout_tracking_dict = {}
             for record in cfg_dict[workload]:
                 in_measure, out_measure = record
@@ -264,7 +261,7 @@ class BaseGraphTuner(object):
                 if node_entry["op"] in self._target_ops:
                     o_idx = key
-                    o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                    o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                     o_wkl = node_entry["workloads"][0]
                     i_topi_op = in_node_entry["topi_op"][0]
                     i_wkl = in_node_entry["workloads"][0]
@@ -273,14 +270,14 @@ class BaseGraphTuner(object):
                         pivot += 1
                         i_topi_op = in_node_entry["topi_op"][pivot]
                         i_wkl = in_node_entry["workloads"][pivot]
-                    i_infer_layout_func = OP2LAYOUT[i_topi_op]
+                    i_infer_layout_func = get_infer_layout(i_topi_op)
                 else:
                     o_idx = target_input_idx
                     if i <= target_input_pos:
                         continue
-                    o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                    o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                     o_wkl = node_entry["workloads"][target_input_pos]
-                    i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]]
+                    i_infer_layout_func = get_infer_layout(node_entry["topi_op"][i])
                     i_wkl = node_entry["workloads"][i]
                 if (i_idx, o_idx) in pair_tracker:
@@ -314,9 +311,8 @@ class BaseGraphTuner(object):
                                   to_sch_idx, args):
         """Create dictionary containing matrix format of layout transformation
         between nodes."""
-        sargs = serialize_args(args)
         in_layout, out_layout = args[1], args[2]
-        ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(sargs)
+        ltf_workload = autotvm.task.args_to_workload(args, 'layout_transform')
         idx_pair_key = (from_node_idx, to_node_idx)
         if in_layout == out_layout:
@@ -449,8 +445,7 @@ class BaseGraphTuner(object):
         measure_option = autotvm.measure_option(builder=builder, runner=runner)
         for args in args_list:
             data, in_layout, out_layout = args
-            args = serialize_args(args)
-            ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
+            ltf_workload = autotvm.task.args_to_workload(args, 'layout_transform')
             if ltf_workload in self._layout_transform_perf_records:
                 continue
@@ -478,9 +473,8 @@ class BaseGraphTuner(object):
                 continue
             records = []
-            task = autotvm.task.create(layout_transform, args=args, target=self._target,
+            task = autotvm.task.create("layout_transform", args=args, target=self._target,
                                        target_host=target_host)
-            task.workload = ltf_workload
             tuner = autotvm.tuner.GridSearchTuner(task)
             tuner.tune(n_trial=1, measure_option=measure_option,
                        callbacks=[_log_to_list(records)])
......
@@ -18,8 +18,6 @@
 """API for graph traversing."""
 import threading
-import topi
 import tvm
 from tvm import relay, autotvm
 from tvm.relay import transform
@@ -30,13 +28,6 @@ from tvm.autotvm.task import TaskExtractEnv
 from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
-# Setup relay op base name -> topi compute functions
-# NOTE: To add more ops, change the following dictionary.
-OP2COMPUTE = {
-    "conv2d" : [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
-}
 def expr2graph(expr, target_ops, node_dict, node_list):
     """Convert relay expr to graph data structure
     and fetch workloads of target operators.
@@ -46,8 +37,8 @@ def expr2graph(expr, target_ops, node_dict, node_list):
     expr : tvm.relay.Expr.Function
         Input relay function expression.
-    target_ops: List of str
-        List of target relay base op name
+    target_ops: List of relay.op.Op
+        List of target relay ops
     node_dict : dictionary from tvm.relay.Expr to int
         Dictionary to record node index
@@ -58,14 +49,11 @@ def expr2graph(expr, target_ops, node_dict, node_list):
         {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
          "name": str, "workloads": [tuple], "topi_op": [function]}
     """
+    # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
+    # that # autotvm tasks == # ops. But this won't be true after having relay op
+    # strategy. We need to find a solution to fix this.
     env = TaskExtractEnv.get(allow_duplicate=True)
-    topi_funcs = []
-    for op_name in target_ops:
-        if op_name not in OP2COMPUTE:
-            raise RuntimeError("Not supported relay op in graph tuner: %s"
-                               % op_name)
-        topi_funcs += OP2COMPUTE[op_name]
-    env.reset(topi_funcs)
+    env.reset(target_ops)
     # pylint: disable=not-context-manager
     with env:
         _expr2graph_impl(expr, target_ops, node_dict, node_list)
@@ -75,8 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
             task_name, args = env.task_collection[task_pos]
             task = autotvm.task.create(task_name, args,
                                        target="llvm",
-                                       target_host=None,
-                                       template_key='direct')
+                                       target_host=None)
             node_entry["workloads"] = [task.workload]
             node_entry["topi_op"] = [task_name]
             task_pos += 1
@@ -98,11 +85,11 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
             return
         node_index = len(node_list)
         node_entry = {"node": node, "inputs": [], "types": [],
-                      "op": "null", "name": None}
+                      "op": None, "name": None}
         if isinstance(node, Call):
-            op_name = node.op.name.split(".")[-1]
-            node_entry["op"] = op_name
+            op = node.op
+            node_entry["op"] = node.op
             for arg in node.args:
                 in_node_idx = node_dict[arg]
                 if isinstance(arg, (Tuple, TupleGetItem)):
@@ -118,12 +105,12 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
                     node_entry["types"].append(tupe_type)
                 else:
                     raise RuntimeError("Unsupported output type %s in operator %s"
-                                       % (type(out_type), op_name))
+                                       % (type(out_type), op.name))
             # Utilize tracing target to fetch workload with topo-order.
             # Since we only need workload, dummy target can be used to
             # create task.
-            if op_name in target_ops:
+            if op in target_ops:
                 params = []
                 for i, input_idx in enumerate(node_entry["inputs"]):
                     input_node_entry = node_list[input_idx[0]]
@@ -133,7 +120,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
                                            "operators with input node of type "
                                            "relay.expr.Var/Constant/Call. Now "
                                            "find a target op %s with input type %s"
-                                           % (op_name, str(type(input_node_entry["node"]))))
+                                           % (op, str(type(input_node_entry["node"]))))
                     free_var = relay.Var("var_%d" % i, input_type)
                     params.append(free_var)
                 call = relay.Call(node.op, params, node.attrs)
@@ -155,11 +142,9 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
             _expr2graph_impl(node, target_ops, node_dict, node_list)
             return
         elif isinstance(node, TupleGetItem):
-            node_entry["op"] = "TupleGetItem"
            in_node_idx = node_dict[node.tuple_value]
            node_entry["inputs"].append([in_node_idx, node.index, 0])
        elif isinstance(node, Tuple):
-            node_entry["op"] = "Tuple"
            for tuple_item in node:
                in_node_idx = node_dict[tuple_item]
                if isinstance(tuple_item, TupleGetItem):
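Since target ops are now Op objects rather than base-name strings, callers look them up through the op registry. A one-line sketch of the new calling convention:

# Sketch: target ops are passed as relay.op.Op objects, e.g.:
from tvm import relay
target_ops = [relay.op.get("nn.conv2d")]
# expr2graph(func, target_ops, node_dict={}, node_list=[]) fills the
# dict/list in place.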
......
@@ -47,7 +47,7 @@ def has_multiple_inputs(node_list, node_idx, input_names):
         in_idx = in_idx[0]
         in_node = node_list[in_idx]
         # Exclude parameter nodes
-        if in_node["op"] != "null" or \
+        if in_node["op"] is not None or \
                 ("name" in in_node and in_node["name"] in input_names):
             num_inputs += 1
     return num_inputs > 1
@@ -72,9 +72,10 @@ def is_boundary_node(node_entry, input_names):
         whether node is a boundary node.
     """
     # Operators dependent on original layouts.
-    _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
-                        "multibox_prior", "multibox_transform_loc", "where",
-                        "non_max_suppression", "strided_slice"]
+    _LAYOUT_FIXED_OP = [relay.op.get(name) for name in (
+        "nn.batch_flatten", "transpose", "reshape", "vision.multibox_prior",
+        "vision.multibox_transform_loc", "where", "vision.non_max_suppression",
+        "strided_slice")]
     out = node_entry["op"] in _LAYOUT_FIXED_OP or \
           ("name" in node_entry and node_entry["name"] in input_names)
@@ -95,9 +96,7 @@ def is_skipped_node(node_entry):
         whether node is skipped.
     """
     # Operators not counted in graph tuner.
-    _SKIPPED_OP = ["Tuple"]
-    return node_entry["op"] in _SKIPPED_OP
+    return isinstance(node_entry["node"], relay.Tuple)
 def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
......
@@ -28,14 +28,16 @@ import time
 import os
 import itertools
 from collections import OrderedDict
+import numpy as np
 from .. import build, lower, target as _target
+from .. import __version__
 from . import task
 from .task import ConfigEntity, ApplyHistoryBest
 from .measure import MeasureInput, MeasureResult
-AUTOTVM_LOG_VERSION = 0.1
+AUTOTVM_LOG_VERSION = 0.2
+_old_version_warning = True
 logger = logging.getLogger('autotvm')
 try:  # convert unicode to str for python2
@@ -88,27 +90,30 @@ def encode(inp, result, protocol='json'):
     if protocol == 'json':
         json_dict = {
-            "i": (str(inp.target),
-                  inp.task.name, inp.task.args, inp.task.kwargs,
-                  inp.task.workload,
-                  inp.config.to_json_dict()),
-            "r": (result.costs if result.error_no == 0 else (1e9,),
-                  result.error_no,
-                  result.all_cost,
-                  result.timestamp),
-            "v": AUTOTVM_LOG_VERSION
+            "input": (str(inp.target),
+                      inp.task.name, inp.task.args, inp.task.kwargs),
+            "config": inp.config.to_json_dict(),
+            "result": (result.costs if result.error_no == 0 else (1e9,),
+                       result.error_no,
+                       result.all_cost,
+                       result.timestamp),
+            "version": AUTOTVM_LOG_VERSION,
+            "tvm_version": __version__
         }
         return json.dumps(json_dict)
     if protocol == 'pickle':
         row = (str(inp.target),
                str(base64.b64encode(pickle.dumps([inp.task.name,
                                                   inp.task.args,
-                                                  inp.task.kwargs,
-                                                  inp.task.workload])).decode()),
+                                                  inp.task.kwargs])).decode()),
                str(base64.b64encode(pickle.dumps(inp.config)).decode()),
-               str(base64.b64encode(pickle.dumps(tuple(result))).decode()))
+               str(base64.b64encode(pickle.dumps(tuple(result))).decode()),
+               str(AUTOTVM_LOG_VERSION),
+               str(__version__))
         return '\t'.join(row)
     raise RuntimeError("Invalid log protocol: " + protocol)
@@ -119,20 +124,29 @@ def decode(row, protocol='json'):
     Parameters
     ----------
-    row: str
+    row : str
         a row in the logger file
-    protocol: str
+
+    protocol : str
         log protocol, json or pickle
     Returns
     -------
-    input: autotvm.tuner.MeasureInput
-    result: autotvm.tuner.MeasureResult
+    ret : tuple(autotvm.tuner.MeasureInput, autotvm.tuner.MeasureResult), or None
+        The tuple of input and result, or None if input uses old version log format.
     """
     # pylint: disable=unused-variable
+    global _old_version_warning
     if protocol == 'json':
         row = json.loads(row)
-        tgt, task_name, task_args, task_kwargs, workload, config = row['i']
+        if 'v' in row and row['v'] == 0.1:
+            if _old_version_warning:
+                logger.warning("AutoTVM log version 0.1 is no longer supported.")
+                _old_version_warning = False
+            return None
+        tgt, task_name, task_args, task_kwargs = row["input"]
         tgt = _target.create(str(tgt))
         def clean_json_to_python(x):
@@ -148,22 +162,27 @@ def decode(row, protocol='json'):
             return x
         tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
-        tsk.workload = clean_json_to_python(workload)
-        config = ConfigEntity.from_json_dict(config)
+        config = ConfigEntity.from_json_dict(row["config"])
         inp = MeasureInput(tgt, tsk, config)
-        result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["r"]])
+        result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["result"]])
+        config.cost = np.mean(result.costs)
         return inp, result
     if protocol == 'pickle':
         items = row.split("\t")
+        if len(items) == 4:
+            if _old_version_warning:
+                logger.warning("AutoTVM log version 0.1 is no longer supported.")
+                _old_version_warning = False
+            return None
         tgt = _target.create(items[0])
         task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
         config = pickle.loads(base64.b64decode(items[2].encode()))
-        result = pickle.loads(base64.b64decode(items[3].encode()))
+        result = MeasureResult(*pickle.loads(base64.b64decode(items[3].encode())))
+        config.cost = np.mean(result.costs)
         tsk = task.Task(task_tuple[0], task_tuple[1])
-        tsk.workload = task_tuple[3]
-        return MeasureInput(tgt, tsk, config), MeasureResult(*result)
+        return MeasureInput(tgt, tsk, config), result
     raise RuntimeError("Invalid log protocol: " + protocol)
@@ -183,7 +202,10 @@ def load_from_file(filename):
     """
     for row in open(filename):
         if row and not row.startswith('#'):
-            inp, res = decode(row)
+            ret = decode(row)
+            if ret is None:
+                continue
+            inp, res = ret
             # Avoid loading the record with an empty config. The TOPI schedule with no entities
             # will result in an empty entity map (e.g., depthwise_conv2d_nchw on x86).
             # Using an empty config will cause problems when applying alter op like NCHW to NCHWc.
@@ -208,7 +230,7 @@ def split_workload(in_file, clean=True):
     logger.info("start converting...")
     pool = multiprocessing.Pool()
-    lines = pool.map(decode, lines)
+    lines = [rec for rec in pool.map(decode, lines) if rec is not None]
     logger.info("map done %.2f", time.time() - tic)
     wkl_dict = OrderedDict()
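For reference, a v0.2 JSON record as produced by the new encode() has roughly the shape sketched below; every value is an illustrative placeholder, not real data:

# Illustrative v0.2 record layout (placeholder values, not real measurements).
record = {
    "input": ("llvm -mcpu=core-avx2",                     # str(inp.target)
              "conv2d_NCHWc.x86",                         # task name (hypothetical)
              [["TENSOR", [1, 3, 224, 224], "float32"]],  # task args (truncated)
              {}),                                        # task kwargs
    "config": {"index": 42, "code_hash": None,
               "entity": [["tile_ic", "sp", [-1, 8]]]},
    "result": ((0.0012,), 0, 1.5, 1581717354.0),  # costs, error_no, all_cost, timestamp
    "version": 0.2,
    "tvm_version": "0.7.dev0",
}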
......
@@ -22,12 +22,13 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
-from .task import Task, create, register, template, get_config, args_to_workload
+from .task import Task, create, get_config, args_to_workload, \
+    register_customized_task
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
-from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
+from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
     FallbackContext, clear_fallback_cache, ApplyGraphBest
 from .topi_integration import register_topi_compute, register_topi_schedule, \
-    TaskExtractEnv
+    TaskExtractEnv, get_workload
 from .relay_integration import extract_from_program, extract_from_multiple_program
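With template keys gone, register_topi_compute and register_topi_schedule now take a task name string, and the decorated functions receive the config directly. A minimal sketch, in which the task name "conv2d_nchw.mydevice" and the trivial schedule are placeholders:

# Sketch of the new registration style (placeholder task name and schedule).
import topi
from tvm import autotvm, te

@autotvm.register_topi_compute("conv2d_nchw.mydevice")
def conv2d_nchw(cfg, data, kernel, strides, padding, dilation, out_dtype):
    # tuning knobs would be defined on cfg here
    return topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)

@autotvm.register_topi_schedule("conv2d_nchw.mydevice")
def schedule_conv2d_nchw(cfg, outs):
    return te.create_schedule([x.op for x in outs])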
@@ -33,9 +33,6 @@ from __future__ import absolute_import as _abs
 import logging
 import numpy as np
-from decorator import decorate
-from tvm import target as _target
 from .space import FallbackConfigEntity
@@ -152,79 +149,6 @@ class DispatchContext(object):
         DispatchContext.current = self._old_ctx
-def dispatcher(fworkload):
-    """Wrap a workload dispatcher function.
-
-    Parameters
-    ----------
-    fworkload : function
-        The workload extraction function from arguments.
-
-    Returns
-    -------
-    fdispatcher : function
-        A wrapped dispatcher function, which will
-        dispatch based on DispatchContext and
-        the current workload.
-    """
-    dispatch_dict = {}
-    func_name = fworkload.__name__
-
-    def register(key, func=None, override=False):
-        """Register template function.
-
-        Parameters
-        ----------
-        key : str or List of str
-            The template key to identify the template
-            under this dispatcher.
-        func : function
-            The function to be registered.
-            The first argument of the function is always
-            cfg returned by DispatchContext,
-            the rest arguments are the same as the fworkload.
-        override : bool
-            Whether override existing registration.
-
-        Returns
-        -------
-        The register function if necessary.
-        """
-        if isinstance(key, str):
-            key = [key]
-
-        def _do_reg(myf):
-            for x in key:
-                if x in dispatch_dict and not override:
-                    raise ValueError(
-                        "Key %s is already registered for %s" % (x, func_name))
-                dispatch_dict[x] = myf
-            return myf
-
-        if func:
-            return _do_reg(func)
-        return _do_reg
-
-    def dispatch_func(func, *args, **kwargs):
-        """The wrapped dispatch function"""
-        tgt = _target.Target.current()
-        workload = func(*args, **kwargs)
-        cfg = DispatchContext.current.query(tgt, workload)
-        if cfg.is_fallback and not cfg.template_key:
-            # first try 'direct' template
-            if 'direct' in dispatch_dict:
-                return dispatch_dict['direct'](cfg, *args, **kwargs)
-            # otherwise pick a random template
-            for v in dispatch_dict.values():
-                return v(cfg, *args, **kwargs)
-        else:
-            return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
-
-    fdecorate = decorate(fworkload, dispatch_func)
-    fdecorate.register = register
-    return fdecorate
 class ApplyConfig(DispatchContext):
     """Apply a deterministic config entity for all queries.
@@ -334,7 +258,8 @@ class ApplyHistoryBest(DispatchContext):
             if key in self._best_user_defined:
                 return self._best_user_defined[key]
             if key in self.best_by_model:
-                return self.best_by_model[key][0].config
+                inp, _ = self.best_by_model[key]
+                return inp.config
         # then try matching by target key
         for k in target.keys:
@@ -342,13 +267,16 @@ class ApplyHistoryBest(DispatchContext):
             if key in self._best_user_defined:
                 return self._best_user_defined[key]
             if key in self.best_by_targetkey:
-                return self.best_by_targetkey[key][0].config
+                inp, _ = self.best_by_targetkey[key]
+                return inp.config
         return None
     def update(self, target, workload, cfg):
         model = target.model
         key = (model, workload)
+        # assume user provided config is the best
+        cfg.cost = 0
         self._best_user_defined[key] = cfg
         for k in target.keys:
@@ -481,8 +409,12 @@ class ApplyGraphBest(DispatchContext):
         """
         if self._counter < len(self._records):
             cfg = self._records[self._counter][0].config
+            wkl = self._records[self._counter][0].task.workload
+            if workload is not None:
+                assert wkl == workload
             self._counter += 1
-            self.update(target, workload, cfg)
+            self.update(target, wkl, cfg)
+            cfg.workload = wkl
             return cfg
         key = (str(target), workload)
         if key not in self._global_cfg_dict:
......
@@ -21,10 +21,9 @@ Decorator and utilities for the integration with TOPI and Relay
 """
 import threading
-import warnings
 import logging
-import tvm
 from .task import create
 from .topi_integration import TaskExtractEnv
@@ -55,8 +54,7 @@ def _lower(mod,
         compiler.lower(mod, target=target)
-def extract_from_program(mod, params, ops, target, target_host=None,
-                         template_keys=None):
+def extract_from_program(mod, params, target, target_host=None, ops=None):
     """ Extract tuning tasks from a relay program.
     This function is the single program version of extract_from_multiple_program.
@@ -67,27 +65,22 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
         The module or function to tune
     params: dict of str to numpy array
         The associated parameters of the program
-    ops: List of relay op
-        List of relay ops to be tuned
     target: tvm.target.Target
         The compilation target
     target_host: tvm.target.Target
         The host compilation target
-    template_keys: dict of topi op to str
-        The tuning template keys map for schedules, default to None.
-        Example: {topi.nn.conv2d: 'direct'}
+    ops: List[relay.op.Op] or None
+        List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
     Returns
     -------
     task: Array of autotvm.task.Task
         collected tasks
     """
-    return extract_from_multiple_program([mod], [params], ops, target, target_host,
-                                         template_keys)
+    return extract_from_multiple_program([mod], [params], target, target_host, ops)
-def extract_from_multiple_program(mods, params, ops, target, target_host=None,
-                                  template_keys=None):
+def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
     """ Extract tuning tasks from multiple relay programs.
     This function collects tuning tasks by building a list of programs
@@ -99,15 +92,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
         The list of modules or functions to tune
     params: List of dict of str to numpy array
         The associated parameters of the programs
-    ops: List of relay op
-        List of relay ops to be tuned
     target: tvm.target.Target
         The compilation target
     target_host: tvm.target.Target
         The host compilation target
-    template_keys: dict of topi op to str
-        The tuning template keys map for schedules, default to None.
-        Example: {topi.nn.conv2d: 'direct'}
+    ops: List[relay.op.Op] or None
+        List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
     Returns
     -------
@@ -115,36 +105,13 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
         collected tasks
     """
     # pylint: disable=import-outside-toplevel
-    import tvm.relay.op
     from tvm import relay
     import topi
     env = TaskExtractEnv.get()
-    # NOTE: To add more ops, you only need to change the following lists
-    # relay op -> topi compute
-    OP2TOPI = {
-        tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
-                                 topi.nn.group_conv2d_nchw,
-                                 topi.nn.conv2d_NCHWc,
-                                 topi.nn.conv2d_NCHWc_int8],
-        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
-        tvm.relay.op.nn.dense: [topi.nn.dense],
-        tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
-        tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
-        tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
-        tvm.relay.op.nn.conv3d: [topi.nn.conv3d],
-    }
-
-    topi_funcs = []
-    for op_name in ops:
-        if op_name in OP2TOPI:
-            topi_funcs.extend(OP2TOPI[op_name])
-        else:
-            warnings.warn("Op %s is not tunable, ignored" % op_name)
-
     # run compiler to collect all TOPI calls during compilation
-    env.reset(topi_funcs)
+    env.reset(ops)
     with env:
         # disable logger temporarily
         old_state = logger.disabled
@@ -164,24 +131,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
     logger.disabled = old_state
-    # convert *topi op to template key* map to *task name to template key* map
-    task_name_to_keys = {}
-    if template_keys is not None:
-        for op in template_keys.keys():
-            if op in env.topi_to_task:
-                task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
-            else:
-                logger.warning("Invalid template key, fallback to direct")
-                task_name_to_keys[env.topi_to_task[op]] = 'direct'
-
     # create tasks for target
     tasks = []
     for task_name, args in env.get_tasks():
         try:
-            key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
             tsk = create(task_name, args,
-                         target=target, target_host=target_host,
-                         template_key=key)
+                         target=target, target_host=target_host)
             tasks.append(tsk)
         except topi.InvalidShapeError:
             logger.warning("Invalid shape during AutoTVM task creation")
......
@@ -613,9 +613,9 @@ class ConfigSpace(object):
         self._entity_map = OrderedDict()  # name -> entity
         self._constraints = []
         self.errors = []
-        self.template_key = None
         self.code_hash = None
         self.flop = 0
+        self.cost = None
         self.is_fallback = False
     @staticmethod
@@ -796,7 +796,7 @@ class ConfigSpace(object):
         for name, space in self.space_map.items():
             entities[name] = space[t % len(space)]
             t //= len(space)
-        ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints)
+        ret = ConfigEntity(index, self.code_hash, entities, self._constraints)
         return ret
     def __iter__(self):
@@ -836,17 +836,14 @@ class ConfigEntity(ConfigSpace):
         index of this config in space
     code_hash: str
         hash of schedule code
-    template_key : str
-        The specific template key
     entity_map: dict
         map name to transform entity
     constraints : list
         List of constraints
     """
-    def __init__(self, index, code_hash, template_key, entity_map, constraints):
+    def __init__(self, index, code_hash, entity_map, constraints):
         super(ConfigEntity, self).__init__()
         self.index = index
-        self.template_key = template_key
         self._collect = False
         self._entity_map = entity_map
         self._space_map = None
@@ -896,9 +893,8 @@ class ConfigEntity(ConfigSpace):
            a json serializable dictionary
        """
        ret = {}
-        ret['i'] = int(self.index)
-        ret['t'] = self.template_key
-        ret['c'] = self.code_hash
+        ret['index'] = int(self.index)
+        ret['code_hash'] = self.code_hash
        entity_map = []
        for k, v in self._entity_map.items():
            if isinstance(v, SplitEntity):
@@ -911,7 +907,7 @@ class ConfigEntity(ConfigSpace):
                entity_map.append((k, 'ot', v.val))
            else:
                raise RuntimeError("Invalid entity instance: " + v)
-        ret['e'] = entity_map
+        ret['entity'] = entity_map
        return ret
    @staticmethod
@@ -930,13 +926,12 @@ class ConfigEntity(ConfigSpace):
            The corresponding config object
        """
-        index = json_dict["i"]
-        code_hash = json_dict["c"]
-        template_key = json_dict["t"]
+        index = json_dict["index"]
+        code_hash = json_dict["code_hash"]
        constraints = []
        entity_map = OrderedDict()
-        for item in json_dict["e"]:
+        for item in json_dict["entity"]:
            key, knob_type, knob_args = item
            if knob_type == 'sp':
                entity = SplitEntity(knob_args)
@@ -950,11 +945,10 @@ class ConfigEntity(ConfigSpace):
                raise RuntimeError("Invalid config knob type: " + knob_type)
            entity_map[str(key)] = entity
-        return ConfigEntity(index, code_hash, template_key, entity_map, constraints)
+        return ConfigEntity(index, code_hash, entity_map, constraints)
    def __repr__(self):
-        return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
-                                self.code_hash, self.index)
+        return "%s,%s,%d" % (str(self._entity_map)[12:-1], self.code_hash, self.index)
 class FallbackConfigEntity(ConfigSpace):
@@ -1068,4 +1062,4 @@ class FallbackConfigEntity(ConfigSpace):
        self._entity_map[name] = entity
    def __repr__(self):
-        return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
+        return "%s,%s" % (str(self._entity_map)[12:-1], self.code_hash)
@@ -46,16 +46,16 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub
 # the version of each package
 PACKAGE_VERSION = {
-    'arm_cpu': "v0.04",
-    'llvm': "v0.03",
-    'cuda': "v0.06",
-    'rocm': "v0.03",
-    'opencl': "v0.03",
-    'mali': "v0.05",
-    'intel_graphics': "v0.01",
-    'vta': "v0.06",
+    'arm_cpu': "v0.06",
+    'llvm': "v0.04",
+    'cuda': "v0.08",
+    'rocm': "v0.04",
+    'opencl': "v0.04",
+    'mali': "v0.06",
+    'intel_graphics': "v0.02",
+    'vta': "v0.08",
 }
 logger = logging.getLogger('autotvm')
@@ -189,7 +189,7 @@ def download_package(tophub_location, package_name):
 # global cache for load_reference_log
 REFERENCE_LOG_CACHE = {}
-def load_reference_log(backend, model, workload_name, template_key):
+def load_reference_log(backend, model, workload_name):
     """ Load reference log from TopHub to support fallback in template.
     Template will use these reference logs to choose fallback config.
@@ -201,8 +201,6 @@ def load_reference_log(backend, model, workload_name):
         The name of the device model
     workload_name: str
         The name of the workload. (The first item in the workload tuple)
-    template_key: str
-        The template key
     """
     backend = _alias(backend)
@@ -211,7 +209,7 @@ def load_reference_log(backend, model, workload_name):
     filename = os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)
     global REFERENCE_LOG_CACHE
-    key = (backend, model, workload_name, template_key)
+    key = (backend, model, workload_name)
     if key not in REFERENCE_LOG_CACHE:
         tmp = []
@@ -233,8 +231,7 @@ def load_reference_log(backend, model, workload_name):
         model = max(counts.items(), key=lambda k: k[1])[0]
         for inp, res in load_from_file(filename):
-            if (model == inp.target.model and inp.task.workload[0] == workload_name and
-                    inp.config.template_key == template_key):
+            if model == inp.target.model and inp.task.workload[0] == workload_name:
                 tmp.append((inp, res))
         REFERENCE_LOG_CACHE[key] = tmp
......
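With template keys gone, the reference-log cache is keyed by (backend, model, workload_name) alone. A hedged usage sketch; the model and workload names below are illustrative:

from tvm.autotvm import tophub

# Any record whose target model and first workload item match is returned.
records = tophub.load_reference_log("cuda", "1080ti", "conv2d_nchw.cuda")
for inp, res in records:
    print(inp.task.workload[0], res.costs)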
...@@ -219,8 +219,7 @@ class XGBoostCostModel(CostModel): ...@@ -219,8 +219,7 @@ class XGBoostCostModel(CostModel):
# filter data, only pick the data with the same task # filter data, only pick the data with the same task
data = [] data = []
for inp, res in records: for inp, res in records:
if inp.task.name == self.task.name and \ if inp.task.name == self.task.name:
inp.config.template_key == self.task.config_space.template_key:
data.append((inp, res)) data.append((inp, res))
logger.debug("XGB load %d entries from history log file", len(data)) logger.debug("XGB load %d entries from history log file", len(data))
......
...@@ -14,18 +14,30 @@ ...@@ -14,18 +14,30 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=len-as-condition,no-else-return,invalid-name
"""Backend code generation engine.""" """Backend code generation engine."""
from __future__ import absolute_import from __future__ import absolute_import
import logging
import numpy as np
import tvm
from ..base import register_relay_node, Object from ..base import register_relay_node, Object
from ... import target as _target from ... import target as _target
from ... import autotvm
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op
from .. import ty as _ty
from . import _backend from . import _backend
logger = logging.getLogger('compile_engine')
@register_relay_node @register_relay_node
class CachedFunc(Object): class LoweredOutput(Object):
"""Low-level tensor function to back a relay primitive function. """Lowered output"""
""" def __init__(self, outputs, implement):
self.__init_handle_by_constructor__(
_backend._make_LoweredOutput, outputs, implement)
@register_relay_node @register_relay_node
...@@ -63,6 +75,191 @@ def _get_cache_key(source_func, target): ...@@ -63,6 +75,191 @@ def _get_cache_key(source_func, target):
return source_func return source_func
def get_shape(shape):
"""Convert the shape to correct dtype and vars."""
ret = []
for dim in shape:
if isinstance(dim, tvm.expr.IntImm):
val = int(dim)
assert val <= np.iinfo(np.int32).max
ret.append(tvm.expr.IntImm("int32", val))
elif isinstance(dim, tvm.expr.Any):
ret.append(tvm.var("any_dim", "int32"))
else:
ret.append(dim)
return ret
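To illustrate get_shape: constant dims are narrowed to int32 (with an overflow assert) and Any dims become fresh int32 vars. A minimal sketch, assuming get_shape from this module is in scope:

import tvm
# A 64-bit constant dim is narrowed to int32; a plain Python int passes through.
print(get_shape([tvm.expr.IntImm("int64", 224), 3]))  # [224 (int32), 3]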
def get_valid_implementations(op, attrs, inputs, out_type, target):
"""Get all valid implementations from the op strategy.
Note that this function doesn't support ops with symbolic input shapes.
Parameters
----------
op : relay.op.Op
Relay operator.
attrs : object
The op attribute.
inputs : List[tvm.Tensor]
Input tensors to the op.
out_type : relay.Type
The output type.
target : tvm.target.Target
The target to compile the op.
Returns
-------
ret : List[relay.op.OpImplementation]
The list of all valid op implementations.
"""
fstrategy = op.get_attr("FTVMStrategy")
assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
with target:
strategy = fstrategy(attrs, inputs, out_type, target)
analyzer = tvm.arith.Analyzer()
ret = []
for spec in strategy.specializations:
if spec.condition:
# check if all the clauses in the specialized condition are true
flag = True
for clause in spec.condition.clauses:
clause = analyzer.canonical_simplify(clause)
if isinstance(clause, tvm.expr.IntImm) and clause.value:
continue
flag = False
break
if flag:
for impl in spec.implementations:
ret.append(impl)
else:
for impl in spec.implementations:
ret.append(impl)
return ret
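The filter above, distilled into plain Python for clarity (the implementation names are illustrative): an implementation survives when its specialization has no condition, or when every clause simplifies to a true constant.

specializations = [
    {"condition": None,         "impls": ["conv2d_nchw.generic"]},
    {"condition": [True, True], "impls": ["conv2d_NCHWc_int8.x86"]},
    {"condition": [False],      "impls": ["never_selected"]},
]
valid = [impl for s in specializations
         if s["condition"] is None or all(s["condition"])
         for impl in s["impls"]]
print(valid)  # ['conv2d_nchw.generic', 'conv2d_NCHWc_int8.x86']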
def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True):
"""Select the best implementation from the op strategy.
If use_autotvm is True, it'll first try to find the best implementation
based on AutoTVM profile results. If no AutoTVM profile result is found,
it'll choose the implementation with the highest plevel.
If use_autotvm is False, it'll directly choose the implementation with
the highest plevel.
Note that this function doesn't support ops with symbolic input shapes.
Parameters
----------
op : relay.op.Op
Relay operator.
attrs : object
The op attribute.
inputs : List[tvm.Tensor]
Input tensors to the op.
out_type : relay.Type
The output type.
target : tvm.target.Target
The target to compile the op.
use_autotvm : bool
Whether to query AutoTVM to pick the best implementation.
Returns
-------
ret : tuple(relay.op.OpImplementation, List[tvm.Tensor])
The best op implementation and the corresponding output tensors.
"""
all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
best_plevel_impl = None
for impl in all_impls:
if best_plevel_impl is None or impl.plevel > best_plevel_impl.plevel:
best_plevel_impl = impl
if not use_autotvm:
outs = best_plevel_impl.compute(attrs, inputs, out_type)
return best_plevel_impl, outs
outputs = {}
best_autotvm_impl = None
best_cfg = None
dispatch_ctx = autotvm.task.DispatchContext.current
for impl in all_impls:
outs = impl.compute(attrs, inputs, out_type)
outputs[impl] = outs
workload = autotvm.task.get_workload(outs)
if workload is None:
continue
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
# It's a fallback config
continue
if best_cfg is None or best_cfg.cost > cfg.cost:
best_autotvm_impl = impl
best_cfg = cfg
if best_autotvm_impl:
return best_autotvm_impl, outputs[best_autotvm_impl]
return best_plevel_impl, outputs[best_plevel_impl]
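The selection rule in miniature (names and costs are made up): any implementation with a tuned AutoTVM record beats plevel; among tuned ones the cheapest config wins; with no records at all, the highest plevel wins.

impls = [("conv2d_nchw.cuda", 10, 1.2e-4),         # tuned: has an AutoTVM record
         ("conv2d_nchw_winograd.cuda", 15, None),  # higher plevel, but untuned
         ("conv2d_nchw.generic", 10, None)]
tuned = [i for i in impls if i[2] is not None]
best = min(tuned, key=lambda i: i[2]) if tuned else max(impls, key=lambda i: i[1])
print(best[0])  # -> 'conv2d_nchw.cuda'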
@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target):
"""Lower the call expression to op implementation and tensor outputs."""
assert isinstance(call.op, _op.Op)
op = call.op
# Prepare the call_node->checked_type(). For the call node inputs, we ensure that
# the shape is Int32. The following code ensures the same for the output as well.
# TODO(@icemelon9): Support recursive tuple
ret_type = call.checked_type
if isinstance(ret_type, _ty.TensorType):
ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype)
elif isinstance(ret_type, _ty.TupleType):
new_fields = []
for field in ret_type.fields:
if isinstance(field, _ty.TensorType):
new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype))
else:
new_fields.append(field)
ret_type = _ty.TupleType(new_fields)
is_dyn = _ty.type_has_any(call.checked_type)
for arg in call.args:
is_dyn = is_dyn or _ty.type_has_any(arg.checked_type)
# check if we are in AutoTVM tracing mode, and disable tracing if the op is not in the wanted list
env = autotvm.task.TaskExtractEnv.current
reenable_tracing = False
if env is not None and env.tracing:
if env.wanted_relay_ops is not None and op not in env.wanted_relay_ops:
env.tracing = False
reenable_tracing = True
if not is_dyn:
best_impl, outputs = select_implementation(
op, call.attrs, inputs, ret_type, target)
logger.info("Use implementation %s for op %s", best_impl.name, op.name)
else:
# TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes.
# Currently, we just use the implementation with highest plevel
best_impl, outputs = select_implementation(
op, call.attrs, inputs, ret_type, target, use_autotvm=False)
# re-enable AutoTVM tracing
if reenable_tracing:
env.tracing = True
return LoweredOutput(outputs, best_impl)
@register_relay_node @register_relay_node
class CompileEngine(Object): class CompileEngine(Object):
"""CompileEngine to get lowered code. """CompileEngine to get lowered code.
......
...@@ -131,22 +131,22 @@ class ExprVisitor(ExprFunctor): ...@@ -131,22 +131,22 @@ class ExprVisitor(ExprFunctor):
The default behavior recursively traverses the AST. The default behavior recursively traverses the AST.
""" """
def visit_tuple(self, t): def visit_tuple(self, tup):
for x in t.fields: for x in tup.fields:
self.visit(x) self.visit(x)
def visit_call(self, c): def visit_call(self, call):
self.visit(c.op) self.visit(call.op)
for a in c.args: for a in call.args:
self.visit(a) self.visit(a)
def visit_var(self, v): def visit_var(self, var):
pass pass
def visit_let(self, l): def visit_let(self, let):
self.visit(l.var) self.visit(let.var)
self.visit(l.value) self.visit(let.value)
self.visit(l.body) self.visit(let.body)
def visit_function(self, f): def visit_function(self, f):
self.visit(f.body) self.visit(f.body)
......
...@@ -311,6 +311,7 @@ def _conv(opname): ...@@ -311,6 +311,7 @@ def _conv(opname):
flip_layout = True flip_layout = True
if attr['data_format'] == 'NHWC': if attr['data_format'] == 'NHWC':
in_channels = input_shape[3]
kernel_h, kernel_w, _, depth_mult = weights_shape kernel_h, kernel_w, _, depth_mult = weights_shape
attr['kernel_shape'] = (weights_shape[0], weights_shape[1]) attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv': if opname == 'conv':
...@@ -324,6 +325,7 @@ def _conv(opname): ...@@ -324,6 +325,7 @@ def _conv(opname):
attr['dilations'] = (attr['dilations'][1], attr['dilations'][2]) attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
attr['strides'] = (attr['strides'][1], attr['strides'][2]) attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW': elif attr['data_format'] == 'NCHW':
in_channels = input_shape[1]
_, depth_mult, kernel_h, kernel_w = weights_shape _, depth_mult, kernel_h, kernel_w = weights_shape
attr['kernel_shape'] = (weights_shape[2], weights_shape[3]) attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv': if opname == 'conv':
...@@ -344,7 +346,7 @@ def _conv(opname): ...@@ -344,7 +346,7 @@ def _conv(opname):
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if opname == 'depthwise': if opname == 'depthwise':
attr['groups'] = attr['channels'] attr['groups'] = in_channels
# Fix padding # Fix padding
attr['padding'] = attr['padding'].decode("utf-8") attr['padding'] = attr['padding'].decode("utf-8")
......
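This fix matters whenever depth_mult > 1: groups must be the input channel count, while channels is in_channels * depth_mult. A small illustration with hypothetical shapes:

input_shape = (1, 32, 32, 16)   # NHWC
weights_shape = (3, 3, 16, 2)   # kernel_h, kernel_w, in_channels, depth_mult
in_channels = input_shape[3]
channels = weights_shape[2] * weights_shape[3]
print(in_channels, channels)    # 16 32 -> groups must be 16, not channels (32)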
...@@ -1156,7 +1156,7 @@ class OperatorConverter(object): ...@@ -1156,7 +1156,7 @@ class OperatorConverter(object):
if is_depthwise_conv: if is_depthwise_conv:
params['channels'] = int(in_channels) params['channels'] = int(in_channels)
params['groups'] = int(in_channels) params['groups'] = int(input_c)
params['kernel_layout'] = 'HWOI' params['kernel_layout'] = 'HWOI'
else: else:
params['channels'] = int(output_channels) params['channels'] = int(output_channels)
......
...@@ -28,8 +28,8 @@ from .backend import compile_engine ...@@ -28,8 +28,8 @@ from .backend import compile_engine
def is_primitive(call): def is_primitive(call):
return hasattr(call.op, 'attrs') and hasattr(call.op.attrs, 'Primitive') and \ return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
int(call.op.attrs.Primitive) == 1 hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1
# TODO(@jroesch): port to c++ and unify with existing code # TODO(@jroesch): port to c++ and unify with existing code
class LinearizeRetType: class LinearizeRetType:
......
...@@ -17,9 +17,10 @@ ...@@ -17,9 +17,10 @@
#pylint: disable=wildcard-import, redefined-builtin #pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators.""" """Relay core operators."""
# operator defs # operator defs
from .op import get, register, register_schedule, register_compute, register_gradient, \ from .op import get, register, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \ register_pattern, register_alter_op_layout, register_legalize, \
schedule_injective, Op, OpPattern, debug Op, OpPattern, OpStrategy, debug
from . import strategy
# Operators # Operators
from .reduce import * from .reduce import *
......
...@@ -18,48 +18,14 @@ ...@@ -18,48 +18,14 @@
# pylint: disable=invalid-name,unused-argument # pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import from __future__ import absolute_import
import topi from . import strategy
from topi.util import get_const_int from .op import OpPattern, register_pattern
from ..op import OpPattern, register_compute, register_schedule, register_pattern from .op import register_strategy
@register_schedule("argsort")
def schedule_argsort(_, outs, target):
"""Schedule definition of argsort"""
with target:
return topi.generic.schedule_argsort(outs)
@register_compute("argsort")
def compute_argsort(attrs, inputs, _, target):
"""Compute definition of argsort"""
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
# argsort
register_strategy("argsort", strategy.argsort_strategy)
register_pattern("argsort", OpPattern.OPAQUE) register_pattern("argsort", OpPattern.OPAQUE)
# topk
@register_schedule("topk") register_strategy("topk", strategy.topk_strategy)
def schedule_topk(_, outs, target):
"""Schedule definition of argsort"""
with target:
return topi.generic.schedule_topk(outs)
@register_compute("topk")
def compute_topk(attrs, inputs, _, target):
"""Compute definition of argsort"""
k = get_const_int(attrs.k)
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
out = out if isinstance(out, list) else [out]
return out
register_pattern("topk", OpPattern.OPAQUE) register_pattern("topk", OpPattern.OPAQUE)
...@@ -17,33 +17,21 @@ ...@@ -17,33 +17,21 @@
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
import topi
from topi.util import get_const_int, get_const_tuple from topi.util import get_const_int, get_const_tuple
from . import op as _reg from . import op as _reg
from ...api import convert from ...api import convert
from ...hybrid import script from ...hybrid import script
_reg.register_reduce_schedule("argmax")
def _schedule_reduce(_, outs, target): _reg.register_reduce_schedule("argmin")
"""Generic schedule for reduce""" _reg.register_reduce_schedule("sum")
with target: _reg.register_reduce_schedule("all")
return topi.generic.schedule_reduce(outs) _reg.register_reduce_schedule("any")
_reg.register_reduce_schedule("max")
_reg.register_reduce_schedule("min")
_reg.register_schedule("argmax", _schedule_reduce) _reg.register_reduce_schedule("prod")
_reg.register_schedule("argmin", _schedule_reduce) _reg.register_reduce_schedule("mean")
_reg.register_schedule("sum", _schedule_reduce) _reg.register_reduce_schedule("variance")
_reg.register_schedule("all", _schedule_reduce)
_reg.register_schedule("any", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
_reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
def _create_axis_record(attrs, inputs): def _create_axis_record(attrs, inputs):
axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis)) axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
......
...@@ -19,101 +19,99 @@ ...@@ -19,101 +19,99 @@
from __future__ import absolute_import from __future__ import absolute_import
import topi import topi
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .op import register_compute, register_schedule, register_pattern, register_shape_func from .op import register_compute, register_shape_func
from .op import schedule_injective, OpPattern from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern
from ...hybrid import script from ...hybrid import script
from ...api import convert from ...api import convert
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective register_broadcast_schedule("log")
register_broadcast_schedule("cos")
register_schedule("log", schedule_broadcast) register_broadcast_schedule("sin")
register_schedule("cos", schedule_broadcast) register_broadcast_schedule("atan")
register_schedule("sin", schedule_broadcast) register_broadcast_schedule("exp")
register_schedule("atan", schedule_broadcast) register_broadcast_schedule("erf")
register_schedule("exp", schedule_broadcast) register_broadcast_schedule("sqrt")
register_schedule("erf", schedule_broadcast) register_broadcast_schedule("rsqrt")
register_schedule("sqrt", schedule_broadcast) register_broadcast_schedule("sigmoid")
register_schedule("rsqrt", schedule_broadcast) register_broadcast_schedule("floor")
register_schedule("sigmoid", schedule_broadcast) register_broadcast_schedule("ceil")
register_schedule("floor", schedule_broadcast) register_broadcast_schedule("trunc")
register_schedule("ceil", schedule_broadcast) register_broadcast_schedule("round")
register_schedule("trunc", schedule_broadcast) register_broadcast_schedule("sign")
register_schedule("round", schedule_broadcast) register_broadcast_schedule("abs")
register_schedule("sign", schedule_broadcast) register_broadcast_schedule("tanh")
register_schedule("abs", schedule_broadcast) register_broadcast_schedule("add")
register_schedule("tanh", schedule_broadcast) register_broadcast_schedule("subtract")
register_schedule("logical_not", schedule_broadcast) register_broadcast_schedule("multiply")
register_schedule("bitwise_not", schedule_broadcast) register_broadcast_schedule("divide")
register_schedule("negative", schedule_broadcast) register_broadcast_schedule("floor_divide")
register_schedule("copy", schedule_broadcast) register_broadcast_schedule("power")
register_broadcast_schedule("copy")
register_schedule("add", schedule_broadcast) register_broadcast_schedule("logical_not")
register_schedule("subtract", schedule_broadcast) register_broadcast_schedule("logical_and")
register_schedule("multiply", schedule_broadcast) register_broadcast_schedule("logical_or")
register_schedule("divide", schedule_broadcast) register_broadcast_schedule("bitwise_not")
register_schedule("floor_divide", schedule_broadcast) register_broadcast_schedule("bitwise_and")
register_schedule("power", schedule_injective) register_broadcast_schedule("bitwise_or")
register_schedule("mod", schedule_broadcast) register_broadcast_schedule("bitwise_xor")
register_schedule("floor_mod", schedule_broadcast) register_broadcast_schedule("negative")
register_schedule("logical_and", schedule_broadcast) register_broadcast_schedule("mod")
register_schedule("logical_or", schedule_broadcast) register_broadcast_schedule("floor_mod")
register_schedule("bitwise_and", schedule_broadcast) register_broadcast_schedule("equal")
register_schedule("bitwise_or", schedule_broadcast) register_broadcast_schedule("not_equal")
register_schedule("bitwise_xor", schedule_broadcast) register_broadcast_schedule("less")
register_schedule("equal", schedule_broadcast) register_broadcast_schedule("less_equal")
register_schedule("not_equal", schedule_broadcast) register_broadcast_schedule("greater")
register_schedule("less", schedule_broadcast) register_broadcast_schedule("greater_equal")
register_schedule("less_equal", schedule_broadcast) register_injective_schedule("maximum")
register_schedule("greater", schedule_broadcast) register_injective_schedule("minimum")
register_schedule("greater_equal", schedule_broadcast) register_injective_schedule("right_shift")
register_schedule("maximum", schedule_injective) register_injective_schedule("left_shift")
register_schedule("minimum", schedule_injective) register_injective_schedule("shape_of")
register_schedule("right_shift", schedule_injective)
register_schedule("left_shift", schedule_injective)
register_schedule("shape_of", schedule_injective)
# zeros # zeros
@register_compute("zeros") @register_compute("zeros")
def zeros_compute(attrs, inputs, output_type, target): def zeros_compute(attrs, inputs, output_type):
assert not inputs assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 0.0)] return [topi.full(output_type.shape, output_type.dtype, 0.0)]
register_schedule("zeros", schedule_broadcast) register_broadcast_schedule("zeros")
register_pattern("zeros", OpPattern.ELEMWISE) register_pattern("zeros", OpPattern.ELEMWISE)
# zeros_like # zeros_like
@register_compute("zeros_like") @register_compute("zeros_like")
def zeros_like_compute(attrs, inputs, output_type, target): def zeros_like_compute(attrs, inputs, output_type):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.full_like(inputs[0], 0.0)] return [topi.full_like(inputs[0], 0.0)]
register_schedule("zeros_like", schedule_broadcast) register_broadcast_schedule("zeros_like")
# ones # ones
@register_compute("ones") @register_compute("ones")
def ones_compute(attrs, inputs, output_type, target): def ones_compute(attrs, inputs, output_type):
assert not inputs assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 1.0)] return [topi.full(output_type.shape, output_type.dtype, 1.0)]
register_schedule("ones", schedule_broadcast) register_broadcast_schedule("ones")
register_pattern("ones", OpPattern.ELEMWISE) register_pattern("ones", OpPattern.ELEMWISE)
# ones_like # ones_like
@register_compute("ones_like") @register_compute("ones_like")
def ones_like(attrs, inputs, output_type, target): def ones_like_compute(attrs, inputs, output_type):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.full_like(inputs[0], 1.0)] return [topi.full_like(inputs[0], 1.0)]
register_schedule("ones_like", schedule_broadcast) register_broadcast_schedule("ones_like")
# clip # clip
@register_compute("clip") @register_compute("clip")
def clip_compute(attrs, inputs, output_type, target): def clip_compute(attrs, inputs, output_type):
assert len(inputs) == 1 assert len(inputs) == 1
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise) register_injective_schedule("clip")
@script @script
def _cast_shape_function(x): def _cast_shape_function(x):
...@@ -198,6 +196,7 @@ register_shape_func("mod", False, broadcast_shape_func) ...@@ -198,6 +196,7 @@ register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func) register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func) register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func) register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("bitwise_not", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func) register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func) register_shape_func("bitwise_or", False, broadcast_shape_func)
register_shape_func("bitwise_xor", False, broadcast_shape_func) register_shape_func("bitwise_xor", False, broadcast_shape_func)
......
...@@ -21,52 +21,74 @@ import tvm ...@@ -21,52 +21,74 @@ import tvm
import topi import topi
from topi.util import get_const_int, get_const_tuple from topi.util import get_const_int, get_const_tuple
from . import op as _reg from . import op as _reg
from ._reduce import _schedule_reduce from . import strategy
from .op import OpPattern from .op import OpPattern
from ...hybrid import script from ...hybrid import script
from ...api import convert from ...api import convert
schedule_injective = _reg.schedule_injective _reg.register_broadcast_schedule("broadcast_to")
schedule_broadcast = _reg.schedule_injective _reg.register_broadcast_schedule("broadcast_to_like")
schedule_concatenate = _reg.schedule_concatenate _reg.register_broadcast_schedule("expand_dims")
_reg.register_broadcast_schedule("repeat")
_reg.register_broadcast_schedule("tile")
_reg.register_schedule("collapse_sum_like", _schedule_reduce) _reg.register_broadcast_schedule("where")
_reg.register_schedule("broadcast_to", schedule_broadcast) _reg.register_injective_schedule("squeeze")
_reg.register_schedule("broadcast_to_like", schedule_broadcast) _reg.register_injective_schedule("reshape")
_reg.register_schedule("expand_dims", schedule_broadcast) _reg.register_injective_schedule("reshape_like")
_reg.register_schedule("squeeze", schedule_injective) _reg.register_injective_schedule("full")
_reg.register_schedule("reshape", schedule_injective) _reg.register_injective_schedule("full_like")
_reg.register_schedule("reshape_like", schedule_injective) _reg.register_injective_schedule("arange")
_reg.register_schedule("full", schedule_injective) _reg.register_injective_schedule("reverse")
_reg.register_schedule("full_like", schedule_injective) _reg.register_injective_schedule("cast")
_reg.register_schedule("arange", schedule_injective) _reg.register_injective_schedule("cast_like")
_reg.register_schedule("reverse", schedule_injective) _reg.register_injective_schedule("reinterpret")
_reg.register_schedule("repeat", schedule_broadcast) _reg.register_injective_schedule("strided_slice")
_reg.register_schedule("tile", schedule_broadcast) _reg.register_injective_schedule("slice_like")
_reg.register_schedule("cast", schedule_injective) _reg.register_injective_schedule("split")
_reg.register_schedule("cast_like", schedule_injective) _reg.register_injective_schedule("take")
_reg.register_schedule("reinterpret", schedule_injective) _reg.register_injective_schedule("transpose")
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_injective_schedule("stack")
_reg.register_schedule("strided_set", schedule_injective) _reg.register_injective_schedule("_contrib_reverse_reshape")
_reg.register_schedule("slice_like", schedule_injective) _reg.register_injective_schedule("gather_nd")
_reg.register_schedule("split", schedule_injective) _reg.register_injective_schedule("sequence_mask")
_reg.register_schedule("take", schedule_injective) _reg.register_injective_schedule("one_hot")
_reg.register_schedule("transpose", schedule_injective) _reg.register_reduce_schedule("collapse_sum_like")
_reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("stack", schedule_injective) # concatenate
_reg.register_schedule("concatenate", schedule_concatenate) _reg.register_schedule("concatenate", strategy.schedule_concatenate)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective) # strided_set
_reg.register_schedule("sequence_mask", schedule_injective) @_reg.register_compute("strided_set")
_reg.register_schedule("one_hot", schedule_injective) def compute_strided_set(attrs, inputs, output_type):
"""Compute definition of strided_set"""
return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
_reg.register_injective_schedule("strided_set")
# layout_transform # layout_transform
_reg.register_schedule("layout_transform", schedule_injective) _reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE) _reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
# shape func # argwhere
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# if the dimension is Any, replace it with a var
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
_reg.register_schedule("argwhere", strategy.schedule_argwhere)
#####################
# Shape functions #
#####################
@script @script
def _arange_shape_func(start, stop, step): def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64") out = output_tensor((1,), "int64")
...@@ -284,31 +306,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims): ...@@ -284,31 +306,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return [_argwhere_shape_func_5d(inputs[0])] return [_argwhere_shape_func_5d(inputs[0])]
return ValueError("Does not support rank higher than 5 in argwhere") return ValueError("Does not support rank higher than 5 in argwhere")
@_reg.register_schedule("argwhere")
def schedule_argwhere(_, outs, target):
"""Schedule definition of argwhere"""
with target:
return topi.generic.schedule_argwhere(outs)
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type, _):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
@_reg.register_compute("strided_set")
def compute_strided_set(attrs, inputs, output_type, _):
"""Compute definition of strided_set"""
return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
@script @script
def _layout_transform_shape_func(data_shape, def _layout_transform_shape_func(data_shape,
out_layout_len, out_layout_len,
......
...@@ -19,7 +19,7 @@ from tvm.runtime import ndarray as _nd ...@@ -19,7 +19,7 @@ from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext from tvm.runtime import TVMContext as _TVMContext
from . import _make from . import _make
from ..op import register_schedule, schedule_injective from .. import op as reg
def on_device(data, device): def on_device(data, device):
...@@ -79,7 +79,7 @@ def checkpoint(data): ...@@ -79,7 +79,7 @@ def checkpoint(data):
""" """
return _make.checkpoint(data) return _make.checkpoint(data)
register_schedule("annotation.checkpoint", schedule_injective) reg.register_injective_schedule("annotation.checkpoint")
def compiler_begin(data, compiler): def compiler_begin(data, compiler):
......
...@@ -18,29 +18,19 @@ ...@@ -18,29 +18,19 @@
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
import topi
from .. import op as reg from .. import op as reg
from ..op import schedule_injective, OpPattern from .. import strategy
from ..op import OpPattern
# adaptive_max_pool2d # adaptive_max_pool2d
@reg.register_schedule("contrib.adaptive_max_pool2d") reg.register_schedule("contrib.adaptive_max_pool2d", strategy.schedule_adaptive_pool)
def schedule_adaptive_max_pool2d(_, outs, target):
"""Schedule definition of adaptive_max_pool2d"""
with target:
return topi.generic.schedule_adaptive_pool(outs)
reg.register_pattern("contrib.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("contrib.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# adaptive_avg_pool2d # adaptive_avg_pool2d
@reg.register_schedule("contrib.adaptive_avg_pool2d") reg.register_schedule("contrib.adaptive_avg_pool2d", strategy.schedule_adaptive_pool)
def schedule_adaptive_avg_pool2d(_, outs, target):
"""Schedule definition of adaptive_avg_pool2d"""
with target:
return topi.generic.schedule_adaptive_pool(outs)
reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# relay.contrib.ndarray_size # relay.contrib.ndarray_size
reg.register_schedule("contrib.ndarray_size", schedule_injective) reg.register_injective_schedule("contrib.ndarray_size")
...@@ -20,13 +20,10 @@ from __future__ import absolute_import ...@@ -20,13 +20,10 @@ from __future__ import absolute_import
import topi import topi
from .. import op as reg from .. import op as reg
from ..op import schedule_injective
# resize # resize
reg.register_schedule("image.resize", schedule_injective)
@reg.register_compute("image.resize") @reg.register_compute("image.resize")
def compute_resize(attrs, inputs, out_type, target): def compute_resize(attrs, inputs, out_type):
size = attrs.size size = attrs.size
layout = attrs.layout layout = attrs.layout
method = attrs.method method = attrs.method
...@@ -34,12 +31,12 @@ def compute_resize(attrs, inputs, out_type, target): ...@@ -34,12 +31,12 @@ def compute_resize(attrs, inputs, out_type, target):
out_dtype = attrs.out_dtype out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)] return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
reg.register_injective_schedule("image.resize")
# crop and resize
reg.register_schedule("image.crop_and_resize", schedule_injective)
# crop and resize
@reg.register_compute("image.crop_and_resize") @reg.register_compute("image.crop_and_resize")
def compute_crop_and_resize(attrs, inputs, out_type, target): def compute_crop_and_resize(attrs, inputs, out_type):
crop_size = attrs.crop_size crop_size = attrs.crop_size
layout = attrs.layout layout = attrs.layout
method = attrs.method method = attrs.method
...@@ -48,3 +45,5 @@ def compute_crop_and_resize(attrs, inputs, out_type, target): ...@@ -48,3 +45,5 @@ def compute_crop_and_resize(attrs, inputs, out_type, target):
return [topi.image.crop_and_resize(inputs[0], inputs[1], inputs[2], return [topi.image.crop_and_resize(inputs[0], inputs[1], inputs[2],
crop_size, layout, method, crop_size, layout, method,
extrapolation_value, out_dtype)] extrapolation_value, out_dtype)]
reg.register_injective_schedule("image.crop_and_resize")
...@@ -204,7 +204,6 @@ def conv2d(data, ...@@ -204,7 +204,6 @@ def conv2d(data,
# TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
# convert 2-way padding to 4-way padding # convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding) padding = get_pad_tuple2d(padding)
return _make.conv2d(data, weight, strides, padding, dilation, return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype) kernel_layout, out_layout, out_dtype)
...@@ -298,7 +297,6 @@ def conv3d(data, ...@@ -298,7 +297,6 @@ def conv3d(data,
dilation = (dilation, dilation, dilation) dilation = (dilation, dilation, dilation)
if isinstance(padding, int): if isinstance(padding, int):
padding = (padding, padding, padding) padding = (padding, padding, padding)
return _make.conv3d(data, weight, strides, padding, dilation, return _make.conv3d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype) kernel_layout, out_layout, out_dtype)
...@@ -1772,74 +1770,6 @@ def contrib_conv2d_winograd_without_weight_transform(data, ...@@ -1772,74 +1770,6 @@ def contrib_conv2d_winograd_without_weight_transform(data,
kernel_layout, out_layout, out_dtype) kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_nnpack_without_weight_transform(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""2D convolution with the NNPACK implementation of winograd algorithm.
The basic parameters are the same as the ones in vanilla conv2d.
It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_nnpack_weight_transform
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.contrib_conv2d_winograd_nnpack_without_weight_transform(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_nchwc(data, def contrib_conv2d_nchwc(data,
kernel, kernel,
strides=(1, 1), strides=(1, 1),
...@@ -1974,73 +1904,6 @@ def contrib_depthwise_conv2d_nchwc(data, ...@@ -1974,73 +1904,6 @@ def contrib_depthwise_conv2d_nchwc(data,
groups, channels, kernel_size, data_layout, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype) kernel_layout, out_layout, out_dtype)
def contrib_conv2d_nchwc_int8(data,
kernel,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW8c",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""Variant of 2D convolution. It deals with only int8 inputs.
This operator takes the weight as the convolution kernel
and convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.contrib_conv2d_NCHWc_int8(data, kernel, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_weight_transform(weight, def contrib_conv2d_winograd_weight_transform(weight,
tile_size): tile_size):
......
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
#pylint: disable=unused-argument #pylint: disable=unused-argument,invalid-name
"""The base node types for the Relay language.""" """The base node types for the Relay language."""
import topi
import tvm._ffi import tvm._ffi
from tvm.driver import lower, build from tvm.driver import lower, build
from ..base import register_relay_node from ..base import register_relay_node
from ..expr import RelayExpr from ..expr import RelayExpr
from ...api import register_func from ...api import register_func
from ...target import get_native_generic_func, GenericFunc
from ...runtime import Object
from . import _make from . import _make
@register_relay_node @register_relay_node
...@@ -143,21 +144,105 @@ class OpPattern(object): ...@@ -143,21 +144,105 @@ class OpPattern(object):
OPAQUE = 8 OPAQUE = 8
def register_schedule(op_name, schedule=None, level=10): @tvm._ffi.register_object("relay.OpImplementation")
"""Register schedule function for an op class OpImplementation(Object):
"""Operator implementation"""
def compute(self, attrs, inputs, out_type):
"""Call compute function.
Parameters Parameters
---------- ----------
op_name : str attrs : Attrs
The name of the op. Op attributes.
inputs : list[tvm.tensor.Tensor]
The input tensors.
out_type : relay.Type
The output type.
Returns
-------
outs : list[tvm.tensor.Tensor]
The output tensors.
"""
return _OpImplementationCompute(self, attrs, inputs, out_type)
def schedule(self, attrs, outs, target):
"""Call schedule function.
Parameters
----------
attrs : Attrs
Op attributes.
outs : list[tvm.tensor.Tensor]
The output tensors.
schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule target : tvm.target.Target
The target to schedule the op.
Returns
-------
schedule : tvm.Schedule
The schedule.
"""
return _OpImplementationSchedule(self, attrs, outs, target)
@tvm._ffi.register_object("relay.OpSpecialization")
class OpSpecialization(Object):
"""Operator specialization"""
@tvm._ffi.register_object("relay.OpStrategy")
class OpStrategy(Object):
"""Operator strategy"""
def __init__(self):
self.__init_handle_by_constructor__(_make.OpStrategy)
def add_implementation(self, compute, schedule, name="default", plevel=10):
"""Add an implementation to the strategy
Parameters
----------
compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type)
-> List[Tensor]
The compute function.
schedule : function (attrs: Attrs, outs: List[Tensor], target:Target) -> Schedule
The schedule function. The schedule function.
level : int name : str
The priority level The name of the implementation.
plevel : int
The priority level of the implementation.
""" """
return register(op_name, "FTVMSchedule", schedule, level) _OpStrategyAddImplementation(self, compute, schedule, name, plevel)
def _wrap_default_fstrategy(compute, schedule, name):
def _fstrategy(attrs, inputs, out_type, target):
strategy = OpStrategy()
strategy.add_implementation(compute, schedule, name=name)
return strategy
return _fstrategy
def _create_fstrategy_from_schedule(op_name, schedule):
assert hasattr(schedule, "dispatch_dict")
compute = get(op_name).get_attr("FTVMCompute")
assert compute is not None, "FTVMCompute is not registered for op %s" % op_name
fstrategy = get_native_generic_func("{}_strategy".format(op_name))
name_pfx = schedule.__name__
name_pfx = name_pfx[name_pfx.index('_')+1:]
fstrategy.set_default(
_wrap_default_fstrategy(compute, schedule.fdefault, "%s.generic" % name_pfx))
for key, sch in schedule.dispatch_dict.items():
fstrategy.register(
_wrap_default_fstrategy(compute, sch, "%s.%s" % (name_pfx, key)), [key])
return fstrategy
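The name-prefix derivation above, spelled out: everything up to and including the first underscore is dropped, so schedule_injective yields implementation names such as injective.generic or injective.cuda.

name_pfx = "schedule_injective"
name_pfx = name_pfx[name_pfx.index('_') + 1:]
print(name_pfx)  # -> 'injective'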
def register_compute(op_name, compute=None, level=10): def register_compute(op_name, compute=None, level=10):
...@@ -168,7 +253,7 @@ def register_compute(op_name, compute=None, level=10): ...@@ -168,7 +253,7 @@ def register_compute(op_name, compute=None, level=10):
op_name : str op_name : str
The name of the op. The name of the op.
compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target) compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type)
-> List[Tensor] -> List[Tensor]
The compute function. The compute function.
...@@ -178,6 +263,91 @@ def register_compute(op_name, compute=None, level=10): ...@@ -178,6 +263,91 @@ def register_compute(op_name, compute=None, level=10):
return register(op_name, "FTVMCompute", compute, level) return register(op_name, "FTVMCompute", compute, level)
def register_strategy(op_name, fstrategy=None, level=10):
"""Register strategy function for an op.
Parameters
----------
op_name : str
The name of the op.
fstrategy : function (attrs: Attrs, inputs: List[Tensor], out_type: Type,
target:Target) -> OpStrategy
The strategy function. It must be a native GenericFunc.
level : int
The priority level
"""
if not isinstance(fstrategy, GenericFunc):
assert hasattr(fstrategy, "generic_func_node")
fstrategy = fstrategy.generic_func_node
return register(op_name, "FTVMStrategy", fstrategy, level)
def register_schedule(op_name, schedule, level=10):
"""Register schedule function for an op.
This is used when the compute function is the same for all targets, and
only the schedule differs. It requires that FTVMCompute is already
registered for the op.
Parameters
----------
op_name : str
The name of the op.
schedule : function (attrs: Attrs, outs: List[Tensor], target:Target) -> Schedule
The schedule function. It must be created by target.generic_func.
level : int
The priority level
"""
fstrategy = _create_fstrategy_from_schedule(op_name, schedule)
return register_strategy(op_name, fstrategy, level)
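For example, concatenate keeps a single compute while dispatching its schedule per target, exactly as the _transform.py hunk above does (inside relay/op modules, _reg is the relay.op.op module):

from tvm.relay.op import op as _reg
from tvm.relay.op import strategy
_reg.register_schedule("concatenate", strategy.schedule_concatenate)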
def register_injective_schedule(op_name, level=10):
"""Register injective schedule function for an op.
Parameters
----------
op_name : str
The name of the op.
level : int
The priority level
"""
return register_schedule(op_name, _schedule_injective, level)
def register_broadcast_schedule(op_name, level=10):
"""Register broadcast schedule function for an op.
Parameters
----------
op_name : str
The name of the op.
level : int
The priority level
"""
return register_schedule(op_name, _schedule_injective, level)
def register_reduce_schedule(op_name, level=10):
"""Register reduce schedule function for an op.
Parameters
----------
op_name : str
The name of the op.
level : int
The priority level
"""
return register_schedule(op_name, _schedule_reduce, level)
def register_alter_op_layout(op_name, alter_layout=None, level=10): def register_alter_op_layout(op_name, alter_layout=None, level=10):
"""Register alter op layout function for an op """Register alter op layout function for an op
...@@ -245,6 +415,7 @@ def register_pattern(op_name, pattern, level=10): ...@@ -245,6 +415,7 @@ def register_pattern(op_name, pattern, level=10):
""" """
return register(op_name, "TOpPattern", pattern, level) return register(op_name, "TOpPattern", pattern, level)
def register_gradient(op_name, fgradient=None, level=10): def register_gradient(op_name, fgradient=None, level=10):
"""Register operator pattern for an op. """Register operator pattern for an op.
...@@ -261,6 +432,7 @@ def register_gradient(op_name, fgradient=None, level=10): ...@@ -261,6 +432,7 @@ def register_gradient(op_name, fgradient=None, level=10):
""" """
return register(op_name, "FPrimalGradient", fgradient, level) return register(op_name, "FPrimalGradient", fgradient, level)
def register_shape_func(op_name, data_dependant, shape_func=None, level=10): def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
"""Register operator shape function for an op. """Register operator shape function for an op.
...@@ -290,18 +462,8 @@ def _lower(name, schedule, inputs, outputs): ...@@ -290,18 +462,8 @@ def _lower(name, schedule, inputs, outputs):
def _build(lowered_funcs): def _build(lowered_funcs):
return build(lowered_funcs, target="llvm") return build(lowered_funcs, target="llvm")
_schedule_injective = None
def schedule_injective(attrs, outputs, target): _schedule_reduce = None
"""Generic schedule for binary broadcast."""
with target:
return topi.generic.schedule_injective(outputs)
def schedule_concatenate(attrs, outputs, target):
"""Generic schedule for concatinate."""
with target:
return topi.generic.schedule_concatenate(outputs)
__DEBUG_COUNTER__ = 0 __DEBUG_COUNTER__ = 0
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Relay op strategies."""
from __future__ import absolute_import as _abs
from .generic import *
from . import x86
from . import arm_cpu
from . import cuda
from . import hls
from . import mali
from . import bifrost
from . import opengl
from . import rocm
from . import intel_graphics
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of bifrost operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re
import topi
from .generic import *
from .. import op as _op
@conv2d_strategy.register("bifrost")
def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
"""conv2d mali(bifrost) strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
stride_h, stride_w = attrs.get_int_tuple("strides")
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
if kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.bifrost")
_, _, kh, kw = get_const_tuple(kernel.shape)
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.bifrost",
plevel=15)
elif re.match(r"OIHW\d*o", kernel_layout):
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.bifrost")
else:
raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".
format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.bifrost.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.bifrost")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".
format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for Mali(Bifrost)")
return strategy
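How the two candidates above compete for a 3x3, stride-1, dilation-1 NCHW conv2d when no tuning record exists: the winograd implementation wins purely on plevel (15 over the default 10).

candidates = {"conv2d_nchw_spatial_pack.bifrost": 10,
              "conv2d_nchw_winograd.bifrost": 15}
print(max(candidates, key=candidates.get))  # conv2d_nchw_winograd.bifrost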
@conv2d_winograd_without_weight_transfrom_strategy.register("bifrost")
def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out_type, target):
"""conv2d_winograd_without_weight_transfrom mali(bifrost) strategy"""
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.data_layout
strides = attrs.get_int_tuple("strides")
assert dilation == (1, 1), "Do not support dilate now"
assert strides == (1, 1), "Do not support strides now"
assert groups == 1, "Do not supoort arbitrary group number"
strategy = _op.OpStrategy()
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.bifrost")
else:
raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
format(layout))
return strategy
@dense_strategy.register("bifrost")
def dense_strategy_bifrost(attrs, inputs, out_type, target):
"""dense mali(bifrost) strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.bifrost.dense),
wrap_topi_schedule(topi.bifrost.schedule_dense),
name="dense.bifrost")
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of HLS operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
from .generic import *
from .. import op as _op
@schedule_injective.register("hls")
def schedule_injective_hls(attrs, outs, target):
"""schedule injective ops for hls"""
with target:
return topi.hls.schedule_injective(outs)
@schedule_reduce.register("hls")
def schedule_reduce_hls(attrs, outs, target):
"""schedule reduction ops for hls"""
with target:
return topi.hls.schedule_reduce(outs)
@schedule_concatenate.register("hls")
def schedule_concatenate_hls(attrs, outs, target):
"""schedule concatenate for hls"""
with target:
return topi.hls.schedule_injective(outs)
@schedule_pool.register("hls")
def schedule_pool_hls(attrs, outs, target):
"""schedule pooling ops for hls"""
with target:
return topi.hls.schedule_pool(outs, attrs.layout)
@schedule_adaptive_pool.register("hls")
def schedule_adaptive_pool_hls(attrs, outs, target):
"""schedule adaptive pooling ops for hls"""
with target:
return topi.hls.schedule_adaptive_pool(outs)
@schedule_softmax.register("hls")
def schedule_softmax_hls(attrs, outs, target):
"""schedule softmax for hls"""
with target:
return topi.hls.schedule_softmax(outs)
@override_native_generic_func("conv2d_strategy")
def conv2d_strategy_hls(attrs, inputs, out_type, target):
"""conv2d hls strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nchw),
wrap_topi_schedule(topi.hls.schedule_conv2d_nchw),
name="conv2d_nchw.hls")
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_topi_schedule(topi.hls.schedule_conv2d_nhwc),
name="conv2d_nhwc.hls")
else:
raise RuntimeError("Unsupported conv2d layout {}".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.hls")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nhwc),
name="depthwise_nhwc.hls")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for hls")
return strategy
@override_native_generic_func("conv2d_NCHWc_strategy")
def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
"""conv2d_NCHWc hls strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.hls")
return strategy
@conv2d_transpose_strategy.register("hls")
def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):
"""conv2d_transpose hls strategy"""
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
    assert dilation == (1, 1), "Do not support dilation now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.hls.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.hls")
return strategy
@dense_strategy.register("hls")
def dense_strategy_hls(attrs, inputs, out_type, target):
"""dense hls strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
wrap_topi_schedule(topi.hls.schedule_dense),
name="dense.hls")
return strategy
@bitserial_conv2d_strategy.register("hls")
def bitserial_conv2d_strategy_hls(attrs, inputs, out_type, target):
"""bitserial_conv2d hls strategy"""
strategy = _op.OpStrategy()
layout = attrs.data_layout
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw),
wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nchw),
name="bitserial_conv2d_nchw.hls")
elif layout == "NHWC":
strategy.add_implementation(
wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc),
wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nhwc),
name="bitserial_conv2d_nhwc.hls")
else:
raise ValueError("Data layout {} not supported.".format(layout))
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of x86 operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
from .generic import *
from .. import op as _op
@conv2d_strategy.register("intel_graphics")
def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
"""conv2d intel graphics strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_nchw),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_nchw),
name="conv2d_nchw.intel_graphics")
# conv2d_NCHWc won't work without alter op layout pass
# TODO(@Laurawly): fix this
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics",
plevel=5)
else:
raise RuntimeError("Unsupported conv2d layout {} for intel graphics".
format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.intel_graphics.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.intel_graphics")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for intel graphics")
return strategy
@conv2d_NCHWc_strategy.register("intel_graphics")
def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
"""conv2d_NCHWc intel_graphics strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics")
return strategy
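The plevel values used above (5 for the NCHWc fallback here, 15 for the winograd and miopen implementations elsewhere in this patch) rank implementations within a strategy. A minimal sketch of that ordering, assuming the default plevel of 10 and hypothetical compute/schedule wrappers:

# compute_a/schedule_a and compute_b/schedule_b are placeholders standing in
# for wrapped TOPI functions such as wrap_compute_conv2d(...) and
# wrap_topi_schedule(...).
strategy = _op.OpStrategy()
strategy.add_implementation(compute_a, schedule_a, name="a", plevel=10)
strategy.add_implementation(compute_b, schedule_b, name="b", plevel=15)
# With no applicable AutoTVM records, the highest-plevel implementation
# ("b") is selected; tuned records can pick a lower-plevel one instead.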
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of mali operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re
import topi
from .generic import *
from .. import op as _op
@conv2d_strategy.register("mali")
def conv2d_strategy_mali(attrs, inputs, out_type, target):
"""conv2d mali strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
stride_h, stride_w = attrs.get_int_tuple("strides")
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
if kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.mali")
# check if winograd algorithm is applicable
_, _, kh, kw = get_const_tuple(kernel.shape)
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.mali",
plevel=15)
elif re.match(r"OIHW\d*o", kernel_layout):
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.mali")
else:
raise RuntimeError("Unsupported weight layout {} for conv2d NCHW".
format(kernel_layout))
else:
raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.mali")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for mali")
return strategy
@conv2d_winograd_without_weight_transfrom_strategy.register("mali")
def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_type, target):
"""conv2d_winograd_without_weight_transfrom mali strategy"""
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.data_layout
strides = attrs.get_int_tuple("strides")
kernel = inputs[1]
    assert dilation == (1, 1), "Do not support dilation now"
    assert strides == (1, 1), "Do not support strides now"
    assert groups == 1, "Do not support arbitrary group number"
strategy = _op.OpStrategy()
if layout == "NCHW":
assert len(kernel.shape) == 5, "Kernel must be packed into 5-dim"
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.mali")
else:
raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
format(layout))
return strategy
@dense_strategy.register("mali")
def dense_strategy_mali(attrs, inputs, out_type, target):
"""dense mali strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.mali.dense),
wrap_topi_schedule(topi.mali.schedule_dense),
name="dense.mali")
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of OpenGL operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
from .generic import *
from .. import op as _op
@schedule_injective.register("opengl")
def schedule_injective_opengl(attrs, outs, target):
"""schedule injective ops for opengl"""
with target:
return topi.opengl.schedule_injective(outs)
@schedule_concatenate.register("opengl")
def schedule_concatenate_opengl(attrs, outs, target):
"""schedule concatenate for opengl"""
with target:
return topi.opengl.schedule_injective(outs)
@schedule_pool.register("opengl")
def schedule_pool_opengl(attrs, outs, target):
"""schedule pooling ops for opengl"""
with target:
return topi.opengl.schedule_pool(outs, attrs.layout)
@schedule_adaptive_pool.register("opengl")
def schedule_adaptive_pool_opengl(attrs, outs, target):
"""schedule adative pooling ops for opengl"""
with target:
return topi.opengl.schedule_adaptive_pool(outs)
@schedule_softmax.register("opengl")
def schedule_softmax_opengl(attrs, outs, target):
"""schedule softmax for opengl"""
with target:
return topi.opengl.schedule_softmax(outs)
@conv2d_strategy.register("opengl")
def conv2d_strategy_opengl(attrs, inputs, out_type, target):
"""conv2d opengl strategy"""
strategy = _op.OpStrategy()
groups = attrs.groups
layout = attrs.data_layout
assert groups == 1, "Don't support group conv2d on OpenGL"
assert layout == "NCHW", "Only support conv2d layout NCHW for OpenGL"
strategy.add_implementation(wrap_compute_conv2d(topi.nn.conv2d),
wrap_topi_schedule(topi.opengl.schedule_conv2d_nchw),
name="conv2d_nchw.opengl")
return strategy
@dense_strategy.register("opengl")
def dense_strategy_opengl(attrs, inputs, out_type, target):
"""dense opengl strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
wrap_topi_schedule(topi.opengl.schedule_dense),
name="dense.opengl")
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
import topi
from .generic import *
from .. import op as _op
@schedule_lrn.register("rocm")
def schedule_lrn_rocm(attrs, outs, target):
"""schedule LRN for rocm"""
with target:
return topi.rocm.schedule_lrn(outs)
@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
"""conv2d rocm strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
groups = attrs.groups
layout = attrs.data_layout
stride_h, stride_w = attrs.get_int_tuple("strides")
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda")
_, _, kh, kw = get_const_tuple(kernel.shape)
if kh <= 7 and kw <= 7 and kh == kw and stride_h == 1 and stride_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.cuda",
plevel=15)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda")
# TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
# elif layout == "NHWC":
# assert kernel_layout == "HWIO"
# strategy.add_implementation(
# wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
# wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
# name="conv2d_nhwc.cuda")
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda")
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add miopen implementation
if "miopen" in target.libs:
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=15)
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.cuda")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.cuda")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
if layout == 'NCHW':
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.cuda")
elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
name="group_conv2d_NCHWc_int8.cuda")
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy
@dense_strategy.register("rocm")
def dense_strategy_rocm(attrs, inputs, out_type, target):
"""Dense strategy for ROCM"""
strategy = _op.OpStrategy()
assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense),
wrap_topi_schedule(topi.rocm.schedule_dense),
name="dense.rocm")
if target.target_name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense_rocblas),
            wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
name="dense_rocblas.rocm",
plevel=5)
return strategy
@@ -17,65 +17,27 @@
 # pylint: disable=invalid-name, unused-argument
 """Faster R-CNN and Mask R-CNN operations."""
 import topi
-from topi.util import get_const_tuple, get_float_tuple, get_const_int
+from topi.util import get_const_tuple
 from .. import op as reg
+from .. import strategy
 from ..op import OpPattern
 # roi_align
-@reg.register_compute("vision.roi_align")
-def compute_roi_align(attrs, inputs, _, target):
-    """Compute definition of roi_align"""
-    assert attrs.layout == "NCHW"
-    return [topi.vision.rcnn.roi_align_nchw(
-        inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
-        spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio)]
-@reg.register_schedule("vision.roi_align")
-def schedule_roi_align(_, outs, target):
-    """Schedule definition of roi_align"""
-    with target:
-        return topi.generic.vision.schedule_roi_align(outs)
+reg.register_strategy("vision.roi_align", strategy.roi_align_strategy)
 reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
 # roi_pool
 @reg.register_compute("vision.roi_pool")
-def compute_roi_pool(attrs, inputs, _, target):
+def compute_roi_pool(attrs, inputs, _):
     """Compute definition of roi_pool"""
     assert attrs.layout == "NCHW"
     return [topi.vision.rcnn.roi_pool_nchw(
         inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
         spatial_scale=attrs.spatial_scale)]
-@reg.register_schedule("vision.roi_pool")
-def schedule_roi_pool(_, outs, target):
-    """Schedule definition of roi_pool"""
-    with target:
-        return topi.generic.vision.schedule_roi_pool(outs)
+reg.register_schedule("vision.roi_pool", strategy.schedule_roi_pool)
 reg.register_pattern("vision.roi_pool", OpPattern.OUT_ELEMWISE_FUSABLE)
-@reg.register_compute("vision.proposal")
-def compute_proposal(attrs, inputs, _, target):
-    """Compute definition of proposal"""
-    scales = get_float_tuple(attrs.scales)
-    ratios = get_float_tuple(attrs.ratios)
-    feature_stride = attrs.feature_stride
-    threshold = attrs.threshold
-    rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n
-    rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
-    rpn_min_size = attrs.rpn_min_size
-    iou_loss = bool(get_const_int(attrs.iou_loss))
-    with target:
-        return [
-            topi.vision.rcnn.proposal(inputs[0], inputs[1], inputs[2], scales, ratios,
-                                      feature_stride, threshold, rpn_pre_nms_top_n,
-                                      rpn_post_nms_top_n, rpn_min_size, iou_loss)
-        ]
-@reg.register_schedule("vision.proposal")
-def schedule_proposal(_, outs, target):
-    """Schedule definition of proposal"""
-    with target:
-        return topi.generic.schedule_proposal(outs)
+# proposal
+reg.register_strategy("vision.proposal", strategy.proposal_strategy)
 reg.register_pattern("vision.proposal", OpPattern.OPAQUE)
@@ -18,104 +18,25 @@
 """Definition of vision ops"""
 from __future__ import absolute_import
-import topi
-from topi.util import get_const_int, get_const_float, get_float_tuple
 from .. import op as reg
+from .. import strategy
 from ..op import OpPattern
 # multibox_prior
-@reg.register_schedule("vision.multibox_prior")
-def schedule_multibox_prior(_, outs, target):
-    """Schedule definition of multibox_prior"""
-    with target:
-        return topi.generic.schedule_multibox_prior(outs)
-@reg.register_compute("vision.multibox_prior")
-def compute_multibox_prior(attrs, inputs, _, target):
-    """Compute definition of multibox_prior"""
-    sizes = get_float_tuple(attrs.sizes)
-    ratios = get_float_tuple(attrs.ratios)
-    steps = get_float_tuple(attrs.steps)
-    offsets = get_float_tuple(attrs.offsets)
-    clip = bool(get_const_int(attrs.clip))
-    return [
-        topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios, steps,
-                                       offsets, clip)
-    ]
+reg.register_strategy("vision.multibox_prior", strategy.multibox_prior_strategy)
 reg.register_pattern("vision.multibox_prior", OpPattern.OPAQUE)
 # multibox_transform_loc
-@reg.register_schedule("vision.multibox_transform_loc")
-def schedule_multibox_transform_loc(_, outs, target):
-    """Schedule definition of multibox_detection"""
-    with target:
-        return topi.generic.schedule_multibox_transform_loc(outs)
-@reg.register_compute("vision.multibox_transform_loc")
-def compute_multibox_transform_loc(attrs, inputs, _, target):
-    """Compute definition of multibox_detection"""
-    clip = bool(get_const_int(attrs.clip))
-    threshold = get_const_float(attrs.threshold)
-    variances = get_float_tuple(attrs.variances)
-    return topi.vision.ssd.multibox_transform_loc(
-        inputs[0], inputs[1], inputs[2], clip, threshold, variances)
+reg.register_strategy("vision.multibox_transform_loc", strategy.multibox_transform_loc_strategy)
 reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
 reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)
 # Get counts of valid boxes
-@reg.register_schedule("vision.get_valid_counts")
-def schedule_get_valid_counts(_, outs, target):
-    """Schedule definition of get_valid_counts"""
-    with target:
-        return topi.generic.schedule_get_valid_counts(outs)
-@reg.register_compute("vision.get_valid_counts")
-def compute_get_valid_counts(attrs, inputs, _, target):
-    """Compute definition of get_valid_counts"""
-    score_threshold = get_const_float(attrs.score_threshold)
-    id_index = get_const_int(attrs.id_index)
-    score_index = get_const_int(attrs.score_index)
-    return topi.vision.get_valid_counts(inputs[0], score_threshold,
-                                        id_index, score_index)
+reg.register_strategy("vision.get_valid_counts", strategy.get_valid_counts_strategy)
 reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
 # non-maximum suppression
-@reg.register_schedule("vision.non_max_suppression")
-def schedule_nms(_, outs, target):
-    """Schedule definition of nms"""
-    with target:
-        return topi.generic.schedule_nms(outs)
-@reg.register_compute("vision.non_max_suppression")
-def compute_nms(attrs, inputs, _, target):
-    """Compute definition of nms"""
-    return_indices = bool(get_const_int(attrs.return_indices))
-    max_output_size = get_const_int(attrs.max_output_size)
-    iou_threshold = get_const_float(attrs.iou_threshold)
-    force_suppress = bool(get_const_int(attrs.force_suppress))
-    top_k = get_const_int(attrs.top_k)
-    coord_start = get_const_int(attrs.coord_start)
-    score_index = get_const_int(attrs.score_index)
-    id_index = get_const_int(attrs.id_index)
-    invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
-    return [
-        topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
-                                        iou_threshold, force_suppress, top_k,
-                                        coord_start, score_index, id_index,
-                                        return_indices, invalid_to_bottom)
-    ]
+reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy)
 reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)
@@ -17,9 +17,9 @@
 #pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
-from ..op import register_schedule, register_pattern
-from ..op import schedule_injective, OpPattern
+from ..op import register_pattern, OpPattern
+from ..op import register_injective_schedule
 # reorg
 register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE)
-register_schedule("vision.yolo_reorg", schedule_injective)
+register_injective_schedule("vision.yolo_reorg")
@@ -31,7 +31,7 @@ from .quantize import _forward_op
 @_reg.register_compute("relay.op.annotation.simulated_quantize")
-def simulated_quantize_compute(attrs, inputs, out_type, target):
+def simulated_quantize_compute(attrs, inputs, out_type):
     """Compiler for simulated_quantize."""
     assert len(inputs) == 4
     assert attrs.sign
@@ -52,11 +52,10 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
     return [rdata]
-_reg.register_schedule("relay.op.annotation.simulated_quantize",
-                       _reg.schedule_injective)
+_reg.register_injective_schedule("relay.op.annotation.simulated_quantize")
 _reg.register_pattern("relay.op.annotation.simulated_quantize",
                       _reg.OpPattern.ELEMWISE)
-_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective)
+_reg.register_injective_schedule("annotation.cast_hint")
 @register_relay_node
...
@@ -44,15 +44,18 @@ def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
 def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
                          kernel_size=(3, 3), downsample=False, padding=(1, 1),
-                         epsilon=1e-5, layout='NCHW'):
+                         epsilon=1e-5, layout='NCHW', dtype="float32"):
     """Helper function to get a separable conv block"""
     if downsample:
         strides = (2, 2)
     else:
         strides = (1, 1)
     # depthwise convolution + bn + relu
+    wshape = (depthwise_channels, 1) + kernel_size
+    weight = relay.var(name + "_weight", shape=wshape, dtype=dtype)
     conv1 = layers.conv2d(
         data=data,
+        weight=weight,
         channels=depthwise_channels,
         groups=depthwise_channels,
         kernel_size=kernel_size,
@@ -85,38 +88,41 @@ def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224),
     body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2, 2),
                       layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_1',
-                                int(32*alpha), int(64*alpha), layout=layout)
+                                int(32*alpha), int(64*alpha), layout=layout,
+                                dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_2',
                                 int(64*alpha), int(128*alpha), downsample=True,
-                                layout=layout)
+                                layout=layout, dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_3',
-                                int(128*alpha), int(128*alpha), layout=layout)
+                                int(128*alpha), int(128*alpha), layout=layout,
+                                dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_4',
                                 int(128*alpha), int(256*alpha), downsample=True,
-                                layout=layout)
+                                layout=layout, dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_5',
-                                int(256*alpha), int(256*alpha), layout=layout)
+                                int(256*alpha), int(256*alpha), layout=layout,
+                                dtype=dtype)
     body = separable_conv_block(body, 'separable_conv_block_6',
                                 int(256*alpha), int(512*alpha), downsample=True,
-                                layout=layout)
+                                layout=layout, dtype=dtype)
     if is_shallow:
         body = separable_conv_block(body, 'separable_conv_block_7',
                                     int(512*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout)
+                                    downsample=True, layout=layout, dtype=dtype)
         body = separable_conv_block(body, 'separable_conv_block_8',
                                     int(1024*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout)
+                                    downsample=True, layout=layout, dtype=dtype)
     else:
         for i in range(7, 12):
             body = separable_conv_block(body, 'separable_conv_block_%d' % i,
                                         int(512*alpha), int(512*alpha),
-                                        layout=layout)
+                                        layout=layout, dtype=dtype)
         body = separable_conv_block(body, 'separable_conv_block_12',
                                     int(512*alpha), int(1024*alpha),
-                                    downsample=True, layout=layout)
+                                    downsample=True, layout=layout, dtype=dtype)
         body = separable_conv_block(body, 'separable_conv_block_13',
                                     int(1024*alpha), int(1024*alpha),
-                                    layout=layout)
+                                    layout=layout, dtype=dtype)
     pool = relay.nn.global_avg_pool2d(data=body, layout=layout)
     flatten = relay.nn.batch_flatten(data=pool)
     weight = relay.var('fc_weight')
...
@@ -184,6 +184,7 @@ def override_native_generic_func(func_name):
         fresult = decorate(fdefault, dispatch_func)
         fresult.fdefault = fdefault
         fresult.register = register
+        fresult.generic_func_node = generic_func_node
         return fresult
     return fdecorate
@@ -268,4 +269,5 @@ def generic_func(fdefault):
     fdecorate = decorate(fdefault, dispatch_func)
     fdecorate.register = register
     fdecorate.fdefault = fdefault
+    fdecorate.dispatch_dict = dispatch_dict
     return fdecorate
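Exposing generic_func_node and dispatch_dict makes per-target overrides inspectable from Python, which the strategy machinery and its tests rely on. A minimal sketch with a hypothetical generic function:

from tvm.target import generic_func

@generic_func
def my_schedule(outs):
    return None  # generic fallback

@my_schedule.register("cuda")
def my_schedule_cuda(outs):
    return None  # cuda-specific override

# dispatch_dict maps a target key to its registered override.
assert "cuda" in my_schedule.dispatch_dict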
@@ -23,8 +23,8 @@ from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod,
 from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from tvm.tir import comm_reducer, min, max, sum
-from .schedule import Schedule, create_schedule
-from .tensor import Tensor
+from .schedule import Schedule, create_schedule, SpecializedCondition
+from .tensor import TensorSlice, Tensor
 from .tensor_intrin import decl_tensor_intrin
 from .tag import tag_scope
 from .operation import placeholder, compute, scan, extern, var, size_var
...
@@ -517,4 +517,39 @@ class Stage(Object):
         _ffi_api.StageOpenGL(self)
+
+@tvm._ffi.register_object
+class SpecializedCondition(Object):
+    """Specialized condition to enable op specialization."""
+    def __init__(self, conditions):
+        """Create a specialized condition.
+
+        .. note::
+            Conditions are represented in conjunctive normal form (CNF).
+            Each condition should be a simple expression, e.g., n > 16,
+            m % 8 == 0, etc., where n, m are tvm.Var representing a
+            dimension in the tensor shape.
+
+        Parameters
+        ----------
+        conditions : List of tvm.Expr
+            List of conditions in conjunctive normal form (CNF).
+        """
+        if not isinstance(conditions, (list, _container.Array)):
+            conditions = [conditions]
+        self.__init_handle_by_constructor__(
+            _ffi_api.CreateSpecializedCondition, conditions)
+
+    @staticmethod
+    def current():
+        """Returns the current specialized condition."""
+        return _ffi_api.GetCurrentSpecialization()
+
+    def __enter__(self):
+        _ffi_api.EnterSpecializationScope(self)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        _ffi_api.ExitSpecializationScope(self)
+
 tvm._ffi._init_api("schedule", __name__)
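A minimal usage sketch of the new scope (grounded in this patch: SpecializedCondition is re-exported from tvm.te, and OpStrategy.AddImplementation groups implementations under whatever condition is active when it is called; the shape variable below is illustrative):

from tvm import te

m = te.var("m")
# Implementations added while this scope is active are attached to an
# OpSpecialization guarded by the condition m % 8 == 0.
with te.SpecializedCondition(m % 8 == 0):
    assert te.SpecializedCondition.current() is not None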
@@ -964,3 +964,11 @@ class Let(PrimExprWithOp):
     def __init__(self, var, value, body):
         self.__init_handle_by_constructor__(
             _ffi_api.Let, var, value, body)
+
+@tvm._ffi.register_object
+class Any(PrimExpr):
+    """Any node."""
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.Any)
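A minimal sketch of constructing the new node (imported from the tir expr module where this hunk defines it; whether it is also re-exported at the tvm.tir package level is not shown here):

from tvm.tir.expr import Any

# Any stands for a dimension whose extent is unknown at compile time,
# as used by shape functions over dynamic shapes.
dim = Any()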
@@ -47,11 +47,19 @@
 namespace tvm {
 namespace relay {
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
 TVM_REGISTER_NODE_TYPE(CachedFuncNode);
 TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
 TVM_REGISTER_NODE_TYPE(CCacheValueNode);
 TVM_REGISTER_OBJECT_TYPE(CompileEngineNode);
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
 CCacheKey CCacheKeyNode::make(Function source_func, Target target) {
   auto n = make_object<CCacheKeyNode>();
   n->source_func = std::move(source_func);
@@ -108,9 +116,7 @@ class ScheduleGetter :
   explicit ScheduleGetter(Target target)
       : target_(target), device_copy_op_(Op::Get("device_copy")) {}
-  std::pair<te::Schedule, CachedFunc> Create(const Function& prim_func) {
-    static auto fschedule =
-        Op::GetAttr<FTVMSchedule>("FTVMSchedule");
+  CachedFunc Create(const Function& prim_func) {
     auto cache_node = make_object<CachedFuncNode>();
     cache_node->target = target_;
     for (Var param : prim_func->params) {
@@ -147,7 +153,6 @@ class ScheduleGetter :
     }
     cache_node->func_name = candidate_name;
-    CachedFunc cfunc(cache_node);
     CHECK(master_op_.defined());
     // Fusion over tupled results may leave identity relationships
     // between inputs and outputs, and those should not be scheduled.
@@ -161,15 +166,16 @@ class ScheduleGetter :
     te::Schedule schedule;
     // No need to register schedule for device copy op.
     if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
-      schedule =
-          fschedule[master_op_](master_attrs_, tensor_outs, target_);
+      CHECK(master_implementation_.defined());
+      schedule = master_implementation_.Schedule(master_attrs_, tensor_outs, target_);
       for (const auto& scalar : scalars_) {
         if (schedule->Contain(scalar)) {
           schedule[scalar].compute_inline();
         }
       }
     }
-    return std::make_pair(schedule, cfunc);
+    cache_node->schedule = std::move(schedule);
+    return CachedFunc(cache_node);
   }
   Array<te::Tensor> VisitExpr(const Expr& expr) {
@@ -214,10 +220,10 @@ class ScheduleGetter :
   }
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fcompute =
-        Op::GetAttr<FTVMCompute>("FTVMCompute");
     static auto fpattern =
         Op::GetAttr<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    CHECK(flower_call) << "relay.backend.lower_call is not registered.";
     Array<te::Tensor> inputs;
     int count_tuple = 0;
@@ -234,36 +240,21 @@ class ScheduleGetter :
           << "Only allow function with a single tuple input";
     }
-    // Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is
-    // Int32. Following code ensures the same for the output as well.
-    // TODO(@icemelon): Support recursive tuple
-    Type call_node_type = call_node->checked_type();
-    if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
-      call_node_type = TensorType(GetShape(tt->shape), tt->dtype);
-    } else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
-      std::vector<Type> new_fields;
-      for (auto field : tuple_t->fields) {
-        if (const auto* tt = field.as<TensorTypeNode>()) {
-          new_fields.push_back(TensorType(GetShape(tt->shape), tt->dtype));
-        } else {
-          new_fields.push_back(field);
-        }
-      }
-      call_node_type = TupleType(new_fields);
-    }
     CHECK(call_node->op.as<OpNode>())
         << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
     Array<te::Tensor> outputs;
+    OpImplementation impl;
     // Skip fcompute for device copy operators as it is not registered.
     if (op == device_copy_op_) {
       const auto* copy_input = inputs[0].operator->();
       outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype,
                                              te::Operation(), 0));
     } else {
-      outputs = fcompute[op](call_node->attrs, inputs,
-                             call_node_type, target_);
+      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
+      outputs = lowered_out->outputs;
+      impl = lowered_out->implementation;
     }
     int op_pattern = fpattern[op];
@@ -276,6 +267,7 @@ class ScheduleGetter :
       master_op_ = op;
       master_attrs_ = call_node->attrs;
      master_op_pattern_ = op_pattern;
+      master_implementation_ = impl;
     }
     if (outputs.size() != 1) {
       const auto* tuple_type =
@@ -332,6 +324,7 @@ class ScheduleGetter :
   Op master_op_;
   Attrs master_attrs_;
   int master_op_pattern_{0};
+  OpImplementation master_implementation_;
   std::ostringstream readable_name_stream_;
   std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
   Array<te::Operation> scalars_;
@@ -677,8 +670,7 @@ class CompileEngineImpl : public CompileEngineNode {
    * \return Pair of schedule and cache.
    *  The funcs field in cache is not yet populated.
    */
-  std::pair<te::Schedule, CachedFunc> CreateSchedule(
-      const Function& source_func, const Target& target) {
+  CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
     return ScheduleGetter(target).Create(source_func);
  }
@@ -713,9 +705,9 @@ class CompileEngineImpl : public CompileEngineNode {
     With<Target> target_scope(key->target);
     CHECK(!value->cached_func.defined());
-    auto spair = CreateSchedule(key->source_func, key->target);
+    auto cfunc = CreateSchedule(key->source_func, key->target);
     auto cache_node = make_object<CachedFuncNode>(
-        *(spair.second.operator->()));
+        *(cfunc.operator->()));
     // Skip lowering for device copy node.
     const Expr body = (key->source_func)->body;
@@ -735,11 +727,12 @@ class CompileEngineImpl : public CompileEngineNode {
     // lower the function
     if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
       cache_node->funcs = (*f)(
-          spair.first, all_args, cache_node->func_name, key->source_func);
+          cfunc->schedule, all_args, cache_node->func_name, key->source_func);
     } else {
       tvm::BuildConfig bcfg = BuildConfig::Create();
       std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
+      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name,
+                                     binds, bcfg);
     }
     value->cached_func = CachedFunc(cache_node);
     return value;
@@ -820,6 +813,11 @@ const CompileEngine& CompileEngine::Global() {
   return *inst;
 }
+TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
+.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  return LoweredOutput(outputs, impl);
+});
 TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
 .set_body_typed(CCacheKeyNode::make);
...
@@ -30,6 +30,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/transform.h>
+#include <tvm/relay/op_strategy.h>
 #include <string>
 #include <functional>
@@ -44,6 +45,28 @@ enum ShapeFuncParamState {
   kNeedBoth = 3,
 };
+struct LoweredOutputNode : public Object {
+  /*! \brief The outputs to the function */
+  tvm::Array<te::Tensor> outputs;
+  /*! \brief The implementation used to compute the output */
+  OpImplementation implementation;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("outputs", &outputs);
+    v->Visit("implementation", &implementation);
+  }
+
+  static constexpr const char* _type_key = "relay.LoweredOutput";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object);
+};
+
+class LoweredOutput : public ObjectRef {
+ public:
+  TVM_DLL LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode);
+};
+
 /*! \brief Node container to represent a cached function. */
 struct CachedFuncNode : public Object {
   /* \brief compiled target */
@@ -54,6 +77,8 @@ struct CachedFuncNode : public Object {
   tvm::Array<te::Tensor> inputs;
   /* \brief The outputs to the function */
   tvm::Array<te::Tensor> outputs;
+  /*! \brief The schedule to the function */
+  te::Schedule schedule;
   /*! \brief The lowered functions to support the function. */
   tvm::Array<tir::LoweredFunc> funcs;
   /*! \brief Parameter usage states in the shape function. */
@@ -64,6 +89,7 @@ struct CachedFuncNode : public Object {
     v->Visit("func_name", &func_name);
     v->Visit("inputs", &inputs);
     v->Visit("outputs", &outputs);
+    v->Visit("schedule", &schedule);
     v->Visit("funcs", &funcs);
     v->Visit("shape_func_param_states", &shape_func_param_states);
   }
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/relay/ir/op_strategy.cc
* \brief The Relay operator Strategy and related data structure.
*/
#include <tvm/relay/op_strategy.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(OpImplementationNode);
TVM_REGISTER_NODE_TYPE(OpSpecializationNode);
TVM_REGISTER_NODE_TYPE(OpStrategyNode);
Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) {
return (*this)->fcompute(attrs, inputs, out_type);
}
te::Schedule OpImplementation::Schedule(const Attrs& attrs,
                                        const Array<te::Tensor>& outs,
const Target& target) {
return (*this)->fschedule(attrs, outs, target);
}
void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
tvm::relay::FTVMSchedule fschedule,
std::string name,
int plevel) {
auto n = make_object<OpImplementationNode>();
n->fcompute = fcompute;
n->fschedule = fschedule;
n->name = std::move(name);
n->plevel = plevel;
(*this)->implementations.push_back(OpImplementation(n));
}
void OpStrategy::AddImplementation(FTVMCompute fcompute,
                                   FTVMSchedule fschedule,
                                   std::string name,
                                   int plevel) {
  auto curr_cond = te::SpecializedCondition::Current();
  auto self = this->operator->();
  // Reuse the specialization for the current condition if one already exists.
  for (OpSpecialization op_spec : self->specializations) {
    if (op_spec->condition == curr_cond) {
      op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
      return;
    }
  }
  // Otherwise, create a new specialization guarded by the current condition.
  ObjectPtr<OpSpecializationNode> n = make_object<OpSpecializationNode>();
  n->condition = curr_cond;
  OpSpecialization op_spec(n);
  op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
  self->specializations.push_back(op_spec);
}
TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute")
.set_body([](TVMArgs args, TVMRetValue* rv) {
OpImplementation imp = args[0];
Attrs attrs = args[1];
Array<te::Tensor> inputs = args[2];
Type out_type = args[3];
*rv = imp.Compute(attrs, inputs, out_type);
});
TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
OpImplementation imp = args[0];
Attrs attrs = args[1];
Array<te::Tensor> outs = args[2];
Target target = args[3];
*rv = imp.Schedule(attrs, outs, target);
});
TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectPtr<OpStrategyNode> n = make_object<OpStrategyNode>();
*rv = OpStrategy(n);
});
TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation")
.set_body([](TVMArgs args, TVMRetValue* rv) {
OpStrategy strategy = args[0];
FTVMCompute compute = args[1];
FTVMSchedule schedule = args[2];
std::string name = args[3];
int plevel = args[4];
strategy.AddImplementation(compute, schedule, name, plevel);
});
} // namespace relay
} // namespace tvm
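These globals back the OpStrategy API used by the Python strategy files above. A minimal round-trip sketch (the identity compute and schedule are illustrative placeholders, not part of this patch):

import topi
from tvm import te
from tvm.relay.op import OpStrategy

def identity_compute(attrs, inputs, out_type):
    return [topi.identity(inputs[0])]

def identity_schedule(attrs, outs, target):
    return te.create_schedule([o.op for o in outs])

strategy = OpStrategy()
# Dispatches through relay.op._OpStrategyAddImplementation registered above.
strategy.add_implementation(identity_compute, identity_schedule,
                            name="identity.example", plevel=10)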
@@ -79,7 +79,7 @@ TVM_ADD_FILELINE)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
@@ -105,7 +105,7 @@ TVM_ADD_FILELINE)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
@@ -123,7 +123,7 @@ Mark the start of bitpacking.
                        ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
@@ -140,7 +140,7 @@ Mark the end of bitpacking.
                        ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
@@ -163,7 +163,7 @@ Mark a checkpoint for checkpointing memory optimization.
                        ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          Array<te::Tensor> outputs;
                          for (size_t i = 0; i < inputs.size(); ++i) {
                            outputs.push_back(topi::identity(inputs[i]));
@@ -184,7 +184,7 @@ Beginning of a region that is handled by a given compiler.
                        ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
@@ -209,7 +209,7 @@ End of a region that is handled by a given compiler.
                        ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
                          return {topi::identity(inputs[0])};
                        });
...
@@ -37,8 +37,7 @@ TVM_REGISTER_NODE_TYPE(DebugAttrs);
 Array<te::Tensor> DebugCompute(const Attrs& attrs,
                                const Array<te::Tensor>& inputs,
-                               const Type& out_type,
-                               const Target& target) {
+                               const Type& out_type) {
   return Array<te::Tensor>{ topi::identity(inputs[0]) };
 }
...
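The pattern repeated across these hunks is a single API change: FTVMCompute loses its `Target` argument, since target dispatch now happens one level up in the op strategy. The same shape is visible from Python; a minimal sketch, assuming the `relay.op.register_compute` decorator and a hypothetical op name:

```python
import topi
from tvm import relay

# "my.identity" is a hypothetical op name, for illustration only; real ops
# register their compute the same way, now without a `target` parameter.
@relay.op.register_compute("my.identity")
def _identity_compute(attrs, inputs, out_type):
    return [topi.identity(inputs[0])]
```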
@@ -83,7 +83,7 @@ RELAY_REGISTER_OP("memory.alloc_storage")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
     return {topi::identity(inputs[0])};
 });
@@ -179,7 +179,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
     return {topi::identity(inputs[0])};
 });
@@ -228,7 +228,7 @@ RELAY_REGISTER_OP("memory.invoke_tvm_op")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
     return {topi::identity(inputs[0])};
 });
@@ -252,7 +252,7 @@ RELAY_REGISTER_OP("memory.kill")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
     return {topi::identity(inputs[0])};
 });
@@ -340,7 +340,7 @@ RELAY_REGISTER_OP("memory.shape_func")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute",
                        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                          const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
+                          const Type& out_dtype) -> Array<te::Tensor> {
     return {topi::identity(inputs[0])};
 });
...
@@ -735,58 +735,6 @@ weight transformation in advance.
 .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
-
-// Positional relay function to create conv2d winograd nnpack operator
-// used by frontend FFI.
-Expr MakeConv2DWinogradNNPACK(Expr data,
-                              Expr weight,
-                              Array<IndexExpr> strides,
-                              Array<IndexExpr> padding,
-                              Array<IndexExpr> dilation,
-                              int groups,
-                              IndexExpr channels,
-                              Array<IndexExpr> kernel_size,
-                              std::string data_layout,
-                              std::string kernel_layout,
-                              std::string out_layout,
-                              DataType out_dtype) {
-  auto attrs = make_object<Conv2DAttrs>();
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_without_weight_transform");
-  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
-}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
-.set_body_typed(MakeConv2DWinogradNNPACK);
-
-RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
-.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
-This operator assumes the weight tensor is already pre-transformed by
-nn.contrib_conv2d_winograd_nnpack_weight_transform.
-
-- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
-- **weight**: Any shape
-We do not check the shape for this input tensor. Since different backend
-has different layout strategy.
-
-- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
-)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
-
 // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
 TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
@@ -850,55 +798,6 @@ weight transformation in advance.
 // Positional relay function to create conv2d NCHWc operator
 // used by frontend FFI.
-Expr MakeConv2DNCHWcInt8(Expr data,
-                         Expr kernel,
-                         Array<IndexExpr> strides,
-                         Array<IndexExpr> padding,
-                         Array<IndexExpr> dilation,
-                         int groups,
-                         IndexExpr channels,
-                         Array<IndexExpr> kernel_size,
-                         std::string data_layout,
-                         std::string kernel_layout,
-                         std::string out_layout,
-                         DataType out_dtype) {
-  auto attrs = make_object<Conv2DAttrs>();
-  attrs->strides = std::move(strides);
-  attrs->padding = std::move(padding);
-  attrs->dilation = std::move(dilation);
-  attrs->groups = groups;
-  attrs->channels = channels;
-  attrs->kernel_size = std::move(kernel_size);
-  attrs->data_layout = std::move(data_layout);
-  attrs->kernel_layout = std::move(kernel_layout);
-  attrs->out_layout = std::move(out_layout);
-  attrs->out_dtype = std::move(out_dtype);
-  static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc_int8");
-  return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
-}
-
-TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc_int8")
-.set_body_typed(MakeConv2DNCHWcInt8);
-
-RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
-.describe(R"code(Compute conv2d with NCHWc data layout with int8 inputs.
-- **data**: Input is 5D packed tensor.
-- **weight**: 7D packed tensor.
-- **out**: Output is 5D packed tensor
-)code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DAttrs>()
-.set_num_inputs(2)
-.add_argument("data", "Tensor", "The input tensor.")
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
-                               ConvInferCorrectLayout<Conv2DAttrs>);
-
-// Positional relay function to create conv2d NCHWc operator
-// used by frontend FFI.
 Expr MakeConv2DNCHWc(Expr data,
                      Expr kernel,
                      Array<IndexExpr> strides,
...
@@ -153,6 +153,16 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << " But got " << out_layout;
   Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+  bool is_depthwise = false;
+  if (param->groups > 1) {
+    CHECK(weight && weight->shape.defined()) <<
+        "Weight shape must be specified when groups is greater than 1.";
+    Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape);
+    if (tvm::tir::Equal(param->groups, dshape_nchw[1]) &&
+        tvm::tir::Equal(param->groups, wshape_oihw[0])) {
+      is_depthwise = true;
+    }
+  }
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
   // infer weight if the kernel_size and channels are defined
@@ -161,9 +171,9 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     CHECK_EQ(param->dilation.size(), 2);
     Array<IndexExpr> wshape;
-    if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) {
+    if (is_depthwise) {
       // infer weight's shape for depthwise convolution
-      wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0],
+      wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0],
                  param->kernel_size[1]}};
     } else {
       wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
...
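The new `is_depthwise` flag replaces the old `channels == groups` heuristic with a shape-based test, which also fixes weight-shape inference for depthwise layers with a channel multiplier (where `channels != groups`). In Python terms, the condition reduces to the following sketch (not the actual type-relation code):

```python
def is_depthwise(groups, in_channels, weight_out_channels):
    # A grouped conv2d is depthwise when every input channel forms its own
    # group and the kernel's output-channel (O) dim equals the group count.
    return groups > 1 and groups == in_channels == weight_out_channels
```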
@@ -93,8 +93,9 @@ RELAY_REGISTER_OP("nn.bias_add")
 .add_argument("bias", "1D Tensor", "Bias.")
 .set_support_level(1)
 .add_type_rel("BiasAdd", BiasAddRel)
-.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                         const Type& out_type, const Target& target) {
+.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
+                                         const Array<te::Tensor>& inputs,
+                                         const Type& out_type) {
     const auto* param = attrs.as<BiasAddAttrs>();
     return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
 });
@@ -234,8 +235,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target) {
+                    const Type& out_type) {
     const auto* param = attrs.as<LeakyReluAttrs>();
     return Array<te::Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
 });
@@ -315,8 +315,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target) {
+                    const Type& out_type) {
     const auto* param = attrs.as<PReluAttrs>();
     return Array<te::Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
 });
@@ -351,8 +350,7 @@ RELAY_REGISTER_OP("nn.softmax")
 .add_type_rel("Identity", IdentityRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
     const auto* param = attrs.as<SoftmaxAttrs>();
     CHECK(param != nullptr);
    return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
@@ -385,8 +383,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
 .add_type_rel("Identity", IdentityRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
     const auto* param = attrs.as<SoftmaxAttrs>();
     CHECK(param != nullptr);
     CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
@@ -462,8 +459,7 @@ Example::
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target) {
+                    const Type& out_type) {
     return Array<te::Tensor>{ topi::nn::flatten(inputs[0]) };
 });
@@ -489,8 +485,7 @@ RELAY_REGISTER_OP("nn.relu")
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
     return Array<te::Tensor>{ topi::relu(inputs[0], 0.0f) };
 });
...
@@ -162,8 +162,7 @@ bool PadRel(const Array<Type>& types,
 Array<te::Tensor> PadCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
+                             const Type& out_type) {
   const auto* param = attrs.as<PadAttrs>();
   CHECK(param != nullptr);
...
@@ -165,8 +165,7 @@ bool Pool2DRel(const Array<Type>& types,
 template<typename AttrType, topi::nn::PoolType mode>
 Array<te::Tensor> Pool2DCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
-                                const Type& out_type,
-                                const Target& target) {
+                                const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
@@ -332,8 +331,7 @@ bool GlobalPool2DRel(const Array<Type>& types,
 template<topi::nn::PoolType mode>
 Array<te::Tensor> GlobalPool2DCompute(const Attrs& attrs,
                                       const Array<te::Tensor>& inputs,
-                                      const Type& out_type,
-                                      const Target& target) {
+                                      const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<GlobalPool2DAttrs>();
   CHECK(param != nullptr);
@@ -466,8 +464,7 @@ bool AdaptivePool2DRel(const Array<Type>& types,
 template<topi::nn::PoolType mode>
 Array<te::Tensor> AdaptivePool2DCompute(const Attrs& attrs,
                                         const Array<te::Tensor>& inputs,
-                                        const Type& out_type,
-                                        const Target& target) {
+                                        const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AdaptivePool2DAttrs>();
   CHECK(param != nullptr);
@@ -593,8 +590,9 @@ bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 }
 template <typename AttrType, topi::nn::PoolType mode>
-Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
-                                    const Type& out_type, const Target& target) {
+Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs,
+                                    const Array<te::Tensor>& inputs,
+                                    const Type& out_type) {
   static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
@@ -794,8 +792,7 @@ bool Pool1DRel(const Array<Type>& types,
 template<typename AttrType, topi::nn::PoolType mode>
 Array<te::Tensor> Pool1DCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
-                                const Type& out_type,
-                                const Target& target) {
+                                const Type& out_type) {
   static const Layout kNCW("NCW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
@@ -986,8 +983,7 @@ bool Pool3DRel(const Array<Type>& types,
 template<typename AttrType, topi::nn::PoolType mode>
 Array<te::Tensor> Pool3DCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
-                                const Type& out_type,
-                                const Target& target) {
+                                const Type& out_type) {
   static const Layout kNCDHW("NCDHW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
...
@@ -33,8 +33,7 @@ namespace relay {
 #define RELAY_BINARY_COMPUTE(FTOPI)                        \
   [] (const Attrs& attrs,                                  \
       const Array<te::Tensor>& inputs,                     \
-      const Type& out_type,                                \
-      const Target& target) -> Array<te::Tensor> {         \
+      const Type& out_type) -> Array<te::Tensor> {         \
     CHECK_EQ(inputs.size(), 2U);                           \
     return {FTOPI(inputs[0], inputs[1])};                  \
   }                                                        \
...
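`RELAY_BINARY_COMPUTE(f)` expands to exactly this callback shape for every broadcast binary op. Its Python equivalent for, say, addition would be the following sketch:

```python
import topi

def binary_add_compute(attrs, inputs, out_type):
    # What RELAY_BINARY_COMPUTE(topi::add) produces: two inputs, one
    # broadcast result, and no target argument anymore.
    assert len(inputs) == 2
    return [topi.add(inputs[0], inputs[1])]
```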
@@ -176,7 +176,6 @@ template<typename F>
 Array<te::Tensor> ReduceCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
                                 const Type& out_type,
-                                const Target& target,
                                 F f) {
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
   CHECK(param != nullptr);
@@ -322,9 +321,8 @@ bool ReduceRel(const Array<Type>& types,
 Array<te::Tensor> ArgMaxCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
-                                const Type& out_type,
-                                const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::argmax);
+                                const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::argmax);
 }
@@ -342,9 +340,8 @@ values over a given axis.
 Array<te::Tensor> ArgMinCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
-                                const Type& out_type,
-                                const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::argmin);
+                                const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::argmin);
 }
 RELAY_REGISTER_REDUCE_OP("argmin")
@@ -360,9 +357,8 @@ values over a given axis.
 Array<te::Tensor> SumCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::sum);
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::sum);
 }
@@ -394,9 +390,8 @@ Example::
 Array<te::Tensor> AllCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::all);
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::all);
 }
@@ -431,9 +426,8 @@ Example::
 Array<te::Tensor> AnyCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::any);
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::any);
 }
@@ -468,9 +462,8 @@ Example::
 Array<te::Tensor> MaxCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::max);
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::max);
 }
 RELAY_REGISTER_REDUCE_OP("max")
@@ -486,9 +479,8 @@ RELAY_REGISTER_REDUCE_OP("max")
 Array<te::Tensor> MinCompute(const Attrs& attrs,
                              const Array<te::Tensor>& inputs,
-                             const Type& out_type,
-                             const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::min);
+                             const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::min);
 }
@@ -505,9 +497,8 @@ RELAY_REGISTER_REDUCE_OP("min")
 Array<te::Tensor> ProdCompute(const Attrs& attrs,
                               const Array<te::Tensor>& inputs,
-                              const Type& out_type,
-                              const Target& target) {
-  return ReduceCompute(attrs, inputs, out_type, target, topi::prod);
+                              const Type& out_type) {
+  return ReduceCompute(attrs, inputs, out_type, topi::prod);
 }
 RELAY_REGISTER_REDUCE_OP("prod")
@@ -535,8 +526,7 @@ Example::
 Array<te::Tensor> MeanCompute(const Attrs& attrs,
                               const Array<te::Tensor>& inputs,
-                              const Type& out_type,
-                              const Target& target) {
+                              const Type& out_type) {
   IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
   CHECK(param != nullptr);
@@ -546,7 +536,7 @@ Array<te::Tensor> MeanCompute(const Attrs& attrs,
                     param->exclude)) {
     count *= inputs[0]->shape[i];
   }
-  auto res = ReduceCompute(attrs, inputs, out_type, target, topi::sum);
+  auto res = ReduceCompute(attrs, inputs, out_type, topi::sum);
   return {topi::divide(res[0], count)};
 }
@@ -600,8 +590,7 @@ bool VarianceRel(const Array<Type>& types,
 Array<te::Tensor> VarianceCompute(const Attrs& attrs,
                                   const Array<te::Tensor>& inputs,
-                                  const Type& out_type,
-                                  const Target& target) {
+                                  const Type& out_type) {
   IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
   CHECK(param != nullptr);
@@ -615,7 +604,7 @@ Array<te::Tensor> VarianceCompute(const Attrs& attrs,
   }
   std::vector<Integer> expand_shape;
   auto sq_diff = topi::power(topi::subtract(data, mean), 2);
-  auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, target, topi::sum)[0], count);
+  auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, topi::sum)[0], count);
   return {var};
 }
...
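`MeanCompute` and `VarianceCompute` are both built on `ReduceCompute(..., topi::sum)`: mean divides the summed values by the product of the reduced extents, and variance applies the same reduction to the squared difference from the mean. Numerically, what the lowered code computes is captured by this numpy sketch:

```python
import numpy as np

x = np.arange(6, dtype="float32").reshape(2, 3)
count = x.shape[1]                                    # product of reduced extents
mean = x.sum(axis=1) / count                          # MeanCompute
var = ((x - mean[:, None]) ** 2).sum(axis=1) / count  # VarianceCompute
```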
@@ -35,8 +35,7 @@ namespace relay {
 #define RELAY_UNARY_COMPUTE(FTOPI)                         \
   [] (const Attrs& attrs,                                  \
       const Array<te::Tensor>& inputs,                     \
-      const Type& out_type,                                \
-      const Target& target) -> Array<te::Tensor> {         \
+      const Type& out_type) -> Array<te::Tensor> {         \
     return {FTOPI(inputs[0])};                             \
   }                                                        \
@@ -303,8 +302,7 @@ bool ShapeOfRel(const Array<Type>& types,
 Array<te::Tensor> ShapeOfCompute(const Attrs& attrs,
                                  const Array<te::Tensor>& inputs,
-                                 const Type& out_type,
-                                 const Target& target) {
+                                 const Type& out_type) {
   CHECK_EQ(inputs.size(), 1);
   const auto* param = attrs.as<ShapeOfAttrs>();
   CHECK(param != nullptr);
@@ -354,8 +352,7 @@ bool NdarraySizeRel(const Array<Type>& types,
 Array<te::Tensor> NdarraySizeCompute(const Attrs& attrs,
                                      const Array<te::Tensor>& inputs,
-                                     const Type& out_type,
-                                     const Target& target) {
+                                     const Type& out_type) {
   CHECK_EQ(inputs.size(), 1);
   const auto* param = attrs.as<NdarraySizeAttrs>();
   CHECK(param != nullptr);
...
@@ -83,8 +83,7 @@ Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
 .add_type_rel("YoloReorg", YoloReorgRel)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<te::Tensor>& inputs,
-                                         const Type& out_type,
-                                         const Target& target) {
+                                         const Type& out_type) {
     const auto* params = attrs.as<YoloReorgAttrs>();
     CHECK(params != nullptr);
     return Array<te::Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
...
@@ -83,7 +83,10 @@ class AlterTransformMemorizer : public TransformMemorizer {
       auto ttype = expr->type_as<TensorTypeNode>();
       tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
     }
-    Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
+    // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes.
+    // Probably we need to disable the AlterOpLayout when compiling dynamic models.
+    Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos,
+                                           ref_call->checked_type());
     if (altered_value.defined()) {
       new_e = altered_value;
       modified = true;
...
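With this change, FTVMAlterOpLayout callbacks receive the call's checked type as a fourth argument. A hedged sketch of the corresponding Python hook (the `level` value and the no-op body are illustrative):

```python
from tvm import relay

@relay.op.register_alter_op_layout("nn.conv2d", level=11)
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
    # out_type carries ref_call->checked_type() from the pass above.
    # Returning None keeps the original call unchanged.
    return None
```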
@@ -20,9 +20,11 @@
 /*!
  * \file schedule_lang.cc
  */
+#include <dmlc/thread_local.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/schedule.h>
 #include <tvm/te/operation.h>
+#include <stack>
 #include <unordered_set>
 #include "graph.h"
@@ -787,6 +789,53 @@ IterVarRelation SingletonNode::make(IterVar iter) {
   return IterVarRelation(n);
 }
+
+SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) {
+  ObjectPtr<SpecializedConditionNode> n = make_object<SpecializedConditionNode>();
+  n->clauses = std::move(conditions);
+  data_ = std::move(n);
+}
+
+/*! \brief Entry to hold the SpecializedCondition context stack. */
+struct TVMSpecializationThreadLocalEntry {
+  /*! \brief The current specialized condition */
+  std::stack<SpecializedCondition> condition_stack;
+};
+
+/*! \brief Thread local store to hold the SpecializedCondition context stack. */
+typedef dmlc::ThreadLocalStore<TVMSpecializationThreadLocalEntry> TVMSpecializationThreadLocalStore;
+
+void SpecializedCondition::EnterWithScope() {
+  TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+  entry->condition_stack.push(*this);
+}
+
+void SpecializedCondition::ExitWithScope() {
+  TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+  CHECK(!entry->condition_stack.empty());
+  CHECK(entry->condition_stack.top().same_as(*this));
+  entry->condition_stack.pop();
+}
+
+SpecializedCondition SpecializedCondition::Current() {
+  TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
+  SpecializedCondition cond;
+  if (entry->condition_stack.size() > 0) {
+    cond = entry->condition_stack.top();
+  }
+  return cond;
+}
+
+class SpecializedCondition::Internal {
+ public:
+  static void EnterScope(SpecializedCondition cond) {
+    cond.EnterWithScope();
+  }
+  static void ExitScope(SpecializedCondition cond) {
+    cond.ExitWithScope();
+  }
+};
+
 TVM_REGISTER_NODE_TYPE(StageNode);
 TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
 TVM_REGISTER_NODE_TYPE(SplitNode);
@@ -794,6 +843,7 @@ TVM_REGISTER_NODE_TYPE(FuseNode);
 TVM_REGISTER_NODE_TYPE(RebaseNode);
 TVM_REGISTER_NODE_TYPE(SingletonNode);
 TVM_REGISTER_NODE_TYPE(ScheduleNode);
+TVM_REGISTER_NODE_TYPE(SpecializedConditionNode);
 // Printer
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -848,7 +898,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const ScheduleNode*>(node.get());
     p->stream << "schedule(" << op << ")";
-  });
+  })
+.set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) {
+    auto* op = static_cast<const SpecializedConditionNode*>(node.get());
+    p->stream << "specialized_condition(";
+    p->Print(op->clauses);
+    p->stream << ')';
+  });
 TVM_REGISTER_GLOBAL("te.CreateSchedule")
@@ -962,5 +1018,22 @@ TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
 TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
 .set_body_method(&Schedule::rfactor);
+
+TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition")
+.set_body_typed([](Array<PrimExpr> condition) {
+    return SpecializedCondition(condition);
+  });
+
+TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = SpecializedCondition::Current();
+  });
+
+TVM_REGISTER_GLOBAL("te.EnterSpecializationScope")
+.set_body_typed(SpecializedCondition::Internal::EnterScope);
+
+TVM_REGISTER_GLOBAL("te.ExitSpecializationScope")
+.set_body_typed(SpecializedCondition::Internal::ExitScope);
+
 }  // namespace te
 }  // namespace tvm
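These four globals are enough to drive the condition stack from Python; the Python API presumably wraps them in a context manager, but they can also be exercised directly. A sketch using only the packed functions registered above:

```python
import tvm
from tvm import te

n = te.var("n")
create = tvm.get_global_func("te.CreateSpecializedCondition")
enter = tvm.get_global_func("te.EnterSpecializationScope")
leave = tvm.get_global_func("te.ExitSpecializationScope")
current = tvm.get_global_func("te.GetCurrentSpecialization")

cond = create([n >= 64])        # Array<PrimExpr> of clauses
enter(cond)                     # push onto the thread-local stack
assert current().same_as(cond)  # te.GetCurrentSpecialization
leave(cond)                     # pop; CHECKs it is the top of the stack
```

This is what lets a strategy record, for example, "this schedule is only valid when n >= 64" while building specialized schedules.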
@@ -24,18 +24,56 @@
 #include <tvm/relay/type.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/transform.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/op_attr_types.h>
+#include <topi/broadcast.h>
 #include <topi/generic/injective.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/registry.h>
-TVM_REGISTER_GLOBAL("test.sch")
-.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
-    *rv = topi::generic::schedule_injective(args[0], args[1]);
-  });
+using namespace tvm;
+using namespace tvm::relay;
+
+TVM_REGISTER_GLOBAL("test.strategy")
+.set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs,
+                   const Type& out_type, const Target& target) {
+    FTVMCompute fcompute = [](const Attrs& attrs,
+                              const Array<te::Tensor>& inputs,
+                              const Type& out_type) -> Array<te::Tensor> {
+      CHECK_EQ(inputs.size(), 2U);
+      return {topi::add(inputs[0], inputs[1])};
+    };
+    FTVMSchedule fschedule = [](const Attrs& attrs,
+                                const Array<te::Tensor>& outs,
+                                const Target& target) {
+      With<Target> target_scope(target);
+      return topi::generic::schedule_injective(target, outs);
+    };
+
+    auto n = make_object<OpStrategyNode>();
+    auto strategy = tvm::relay::OpStrategy(std::move(n));
+    strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10);
+    return strategy;
+  });
+
+TVM_REGISTER_GLOBAL("relay.backend.lower_call")
+.set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
+                   const Target& target) {
+    static auto fstrategy = Op::GetAttr<relay::FTVMStrategy>("FTVMStrategy");
+    Op op = Downcast<Op>(call->op);
+    auto out_type = call->checked_type();
+    OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);
+    auto impl = strategy->specializations[0]->implementations[0];
+    auto outs = impl.Compute(call->attrs, inputs, out_type);
+    auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput");
+    if (!f) {
+      LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered";
+    }
+    return (*f)(outs, impl);
+  });
 TEST(Relay, BuildModule) {
-  using namespace tvm;
   auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
   auto a = relay::VarNode::make("a", tensor_type);
   auto b = relay::VarNode::make("b", tensor_type);
@@ -59,14 +97,15 @@ TEST(Relay, BuildModule) {
 }
 // get schedule
 auto reg = tvm::runtime::Registry::Get("relay.op._Register");
-auto s_i = tvm::runtime::Registry::Get("test.sch");
 if (!reg) {
   LOG(FATAL) << "no _Register";
 }
-if (!s_i) {
-  LOG(FATAL) << "no _Register";
+auto fs = tvm::runtime::Registry::Get("test.strategy");
+if (!fs) {
+  LOG(FATAL) << "No test_strategy registered.";
 }
-(*reg)("add", "FTVMSchedule", *s_i, 10);
+auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
+(*reg)("add", "FTVMStrategy", fgeneric, 10);
 // build
 auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
 tvm::runtime::Module build_mod = (*pfb)();
...
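The C++ test above mirrors what strategy functions look like in Python, where `OpStrategy` and `add_implementation` wrap the same objects. A sketch, assuming the Python wrappers added in this PR (the lambda-based compute/schedule and the implementation name are illustrative):

```python
import topi
from tvm import relay

def add_example_strategy(attrs, inputs, out_type, target):
    # Python counterpart of the "test.strategy" packed function above.
    strategy = relay.op.OpStrategy()
    strategy.add_implementation(
        lambda attrs, inputs, out_type: [topi.add(inputs[0], inputs[1])],
        lambda attrs, outs, target: topi.generic.schedule_injective(outs),
        name="add.example",
        plevel=10)
    return strategy
```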
@@ -852,17 +852,22 @@ def test_forward_slice():
 def test_forward_convolution():
-    def verify(data_shape, kernel_size, stride, pad, num_filter):
-        weight_shape=(num_filter, data_shape[1],) + kernel_size
+    def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False):
+        if is_depthwise:
+            groups = data_shape[1]
+            weight_shape=(data_shape[1], num_filter // groups,) + kernel_size
+        else:
+            groups = 1
+            weight_shape=(num_filter, data_shape[1],) + kernel_size
         x = np.random.uniform(size=data_shape).astype("float32")
         weight = np.random.uniform(size=weight_shape).astype("float32")
         bias = np.random.uniform(size=num_filter).astype("float32")
         ref_res = mx.nd.Convolution(data=mx.nd.array(x), weight=mx.nd.array(weight),
                                     bias=mx.nd.array(bias), kernel=kernel_size, stride=stride,
-                                    pad=pad, num_filter=num_filter)
+                                    pad=pad, num_filter=num_filter, num_group=groups)
         mx_sym = mx.sym.Convolution(mx.sym.var("x"), mx.sym.var("weight"), mx.sym.var("bias"),
                                     kernel=kernel_size, stride=stride,
-                                    pad=pad, num_filter=num_filter)
+                                    pad=pad, num_filter=num_filter, num_group=groups)
         shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape}
         mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
         for target, ctx in ctx_list():
@@ -879,6 +884,8 @@ def test_forward_convolution():
     verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
+    verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8,
+           is_depthwise=True)
 def test_forward_deconvolution():
     def verify(data_shape, kernel_size, stride, pad, num_filter):
...
@@ -25,7 +25,7 @@ import tvm
 from tvm import autotvm
 from tvm.autotvm.tuner import RandomTuner
-@autotvm.template
+@autotvm.register_customized_task("testing/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     """An example template for testing"""
     assert N == 1, "Only consider batch_size = 1 in this template"
@@ -114,7 +114,7 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
 def get_sample_task(target=tvm.target.cuda(), target_host=None):
     """return a sample task for testing"""
-    task = autotvm.task.create(conv2d_no_batching,
+    task = autotvm.task.create("testing/conv2d_no_batching",
                                args=(1, 7, 7, 512, 512, 3, 3),
                                target=target, target_host=target_host)
     return task, target
...
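Templates are now registered and looked up by string name rather than by function object, which is what lets AutoTVM treat customized templates and relay-integrated ops uniformly. A self-contained sketch of the new flow (the template name and body are illustrative):

```python
import tvm
from tvm import autotvm, te

@autotvm.register_customized_task("testing/vector_scale")  # illustrative name
def vector_scale(n):
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)
    cfg = autotvm.get_config()
    cfg.define_knob("factor", [8, 16, 32])  # tunable split factor
    xo, xi = s[B].split(B.op.axis[0], cfg["factor"].val)
    return s, [A, B]

# creation now refers to the registered name, not the function object
task = autotvm.task.create("testing/vector_scale", args=(1024,), target="llvm")
```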
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 import tvm
 from tvm import relay
@@ -384,6 +385,8 @@ def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation
     assert result.asnumpy().shape == ref_out_shape, \
         "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
+# TODO(@kevinthesun): Need to fix the compute in conv2d_NCHWc to support any
+@pytest.mark.skip
 def test_any_conv2d_NCHWc():
     verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (1, 1),
                             "NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 224, 224, 8))
...
@@ -39,25 +39,28 @@ def test_task_extraction():
     target = 'llvm'
     mod_list = []
     params_list = []
+    conv2d = relay.op.get("nn.conv2d")
+    conv2d_transpose = relay.op.get("nn.conv2d_transpose")
+    dense = relay.op.get("nn.dense")
     mod, params, _ = get_network('resnet-18', batch_size=1)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
+                                              ops=(conv2d,))
     assert len(tasks) == 12
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d,))
+                                              ops=(conv2d,))
     assert len(tasks) == 12
     mod, params, _ = get_network('resnet-18', batch_size=1)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
-                                              ops=(relay.op.nn.dense,))
+                                              ops=(dense,))
     assert len(tasks) == 1
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.dense,))
+                                              ops=(dense,))
     assert len(tasks) == 1
     mod, params, _ = get_network('resnet-18', batch_size=1)
@@ -65,11 +68,14 @@ def test_task_extraction():
     params_list.append(params)
     tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+                                              ops=(conv2d, dense))
     assert len(tasks) == 13
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+                                              ops=(conv2d, dense))
+    assert len(tasks) == 13
+    tasks = autotvm.task.extract_from_program(mod, target=target,
+                                              params=params)
     assert len(tasks) == 13
     mod, params, _ = get_network('mobilenet', batch_size=1)
@@ -77,65 +83,19 @@ def test_task_extraction():
     params_list.append(params)
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+                                              ops=(conv2d, dense))
     assert len(tasks) == 20
     mod, params, _ = get_network('dcgan', batch_size=1)
     tasks = autotvm.task.extract_from_program(mod, target=target,
                                               params=params,
-                                              ops=(relay.op.nn.conv2d_transpose,))
+                                              ops=(conv2d_transpose,))
     assert len(tasks) == 4
     tasks = autotvm.task.extract_from_multiple_program(mod_list, params_list,
                                                        target=target,
-                                                       ops=(relay.op.nn.conv2d,))
+                                                       ops=(conv2d,))
     assert len(tasks) == 31
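Two details of the updated extraction API are worth calling out: `ops` now takes relay `Op` objects (obtained via `relay.op.get`), and omitting `ops` extracts tasks for every tunable op, which is why the count stays at 13 above. For example:

```python
from tvm import autotvm, relay

# equivalent to the explicit-op extraction above
tasks = autotvm.task.extract_from_program(
    mod["main"], target="llvm", params=params,
    ops=(relay.op.get("nn.conv2d"), relay.op.get("nn.dense")))

# ops=None (the default) now means "all ops with a tunable strategy"
all_tasks = autotvm.task.extract_from_program(
    mod["main"], target="llvm", params=params)
```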
-def test_template_key_provided():
-    """test task extraction using non-'direct' template_key"""
-    target = 'llvm'
-
-    import topi
-    template_keys = {
-        # topi.nn.conv2d - is left blank to test fallback logic
-        topi.nn.dense: 'direct_nopack',
-        topi.nn.depthwise_conv2d_nchw: 'direct',
-    }
-
-    mod, params, _ = get_network('mobilenet', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
-                                              params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense),
-                                              template_keys=template_keys)
-    for task in tasks:
-        if 'dense' in task.name:
-            assert task.config_space.template_key == 'direct_nopack'
-        else:
-            assert task.config_space.template_key == 'direct'
-
-def test_template_key_empty():
-    """test task extraction using empty template_key"""
-    target = 'llvm'
-    mod, params, _ = get_network('mobilenet', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
-                                              params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense),
-                                              template_keys=None)
-    for task in tasks:
-        assert task.config_space.template_key == 'direct'
-
-def test_template_key_default():
-    """test task extraction without template_key"""
-    target = 'llvm'
-    mod, params, _ = get_network('mobilenet', batch_size=1)
-    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
-                                              params=params,
-                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
-    for task in tasks:
-        assert task.config_space.template_key == 'direct'
-
 if __name__ == '__main__':
     test_task_extraction()
-    test_template_key_provided()
-    test_template_key_empty()
-    test_template_key_default()