Unverified commit 623dd208 by Haichen Shen, committed by GitHub

[Relay][AutoTVM] Relay op strategy (#4644)

* relay op strategy

fix lint

bitpack strategy

bitserial_dense (#6)

* update strategy

* address comments

fix a few topi tests

Dense strategy (#5)

* dense

* add bifrost; remove comments

* address comment

Refactor x86 conv2d_NCHWc (#4)

* Refactor x86 conv2d

* Add x86 depthwise_conv2d_NCHWc

* Add back topi x86 conv2d_nchw

* Merge x86 conv2d_nchw and conv2d_NCHWc

* Minor fix for x86 conv2d

fix more strategy

Add x86 conv2d_NCHWc_int8 strategy (#8)

* Add x86 conv2d_NCHWc_int8 strategy

* Remove contrib_conv2d_nchwc_int8

* Fix generic conv2d_NCHWc for int8

* Fix topi arm_cpu conv2d_NCHWc_int8

update x86 conv2d

enable specifying relay ops to be tuned for autotvm

add cuda conv2d strategy

add conv2d strategy for rocm

add conv2d strategy for hls

add conv2d strategy for arm cpu

add conv2d strategy for mali

add conv2d strategy for bifrost

add conv2d strategy for intel graphics

clean up and fix lint

remove template keys from autotvm

remove 2 in the func name

address comments

fix

* fix bugs

* lint

* address comments

* add name to op implement

* Modify topi tests (#9)

* Add pooling, reorg, softmax and vision

* Add lrn

* fix topi tests

* fix more topi tests

* lint

* address comments

* x

* fix more tests & bugs

* Modify more tests (#10)

* Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn

* Minor fix

* More minor fix

* fix more tests

* try to update vta using strategy

* fix cpptest

* x

* fix rebase err

* Fix two tests (#11)

* change autotvm log format

* lint

* minor fix

* try fix vta test

* fix rebase err

* tweak

* tmp hack for vta pass

* fix tutorial

* fix

* fix more tutorials

* fix vta tutorial

* minor

* address comments

* fix

* address comments

* fix cpptest

* fix docs

* change data structure name and api

* address comments

* lint

* fix rebase err

* updates

* fix winograd test

* fix doc

* rebase

* upgrade tophub version number

* fix bug

* re-enable vta tsim test after tophub is upgraded

* fix vta test to use the correct args so the config can be found in tophub

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
parent c4c61cb7
......@@ -29,6 +29,7 @@
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <tvm/target/target.h>
#include <tvm/target/generic_func.h>
#include <tvm/tir/data_layout.h>
#include <string>
......@@ -105,9 +106,8 @@ using TShapeDataDependant = bool;
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<te::Tensor>(const Attrs& attrs,
-const Array<te::Tensor>& inputs,
-const Type& out_type,
-const Target& target)>;
+const Array<te::Tensor>& inputs,
+const Type& out_type)>;
/*!
* \brief Build the computation schedule for
......@@ -120,8 +120,18 @@ using FTVMCompute = runtime::TypedPackedFunc<
*/
using FTVMSchedule = runtime::TypedPackedFunc<
te::Schedule(const Attrs& attrs,
const Array<te::Tensor>& outs,
const Target& target)>;
/*!
* \brief Generate the strategy of operators. This function is a generic
* function and can be re-defined for different targets.
*
* The function signature of the generic function is:
* OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
* const Type& out_type, const Target& target)
*/
using FTVMStrategy = GenericFunc;
/*!
* \brief Alternate the layout of operators or replace the
......@@ -136,7 +146,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
-const Array<te::Tensor>& tinfos)>;
+const Array<te::Tensor>& tinfos,
+const Type& out_type)>;
/*!
* \brief Convert the layout of operators or replace the
......@@ -191,9 +202,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
* \brief Gradient for a specific op.
*
* \param orig_call the original Expr.
*
* \param output_grad the gradient of the Expr.
*
* \return the gradient for each parameters.
*/
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
......@@ -207,13 +216,13 @@ enum AnyCodegenStrategy {
kVariableDimensions
};
-/* \brief A runtime representation of shape. */
+/*! \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;
using FShapeFunc = runtime::TypedPackedFunc<
Array<te::Tensor>(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Array<IndexExpr>& out_ndims)>;
} // namespace relay
} // namespace tvm
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/op_strategy.h
* \brief The Relay operator Strategy and related data structure.
*/
#ifndef TVM_RELAY_OP_STRATEGY_H_
#define TVM_RELAY_OP_STRATEGY_H_
#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/target.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Operator implementation that includes compute and schedule function.
*/
class OpImplementationNode : public Object {
public:
/*! \brief Compute function */
FTVMCompute fcompute;
/*! \brief Schedule function */
FTVMSchedule fschedule;
/*! \brief Name of the implementation */
std::string name;
/*! \brief Priority level */
int plevel;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("plevel", &plevel);
}
static constexpr const char* _type_key = "relay.OpImplementation";
TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
};
/*!
* \brief Operator implementation class.
*/
class OpImplementation : public ObjectRef {
public:
/*!
* \brief Invoke the operator compute function.
* \param attrs The attribute of the primitive
* \param inputs The input tensors.
* \param out_type The output type information.
* \return The output compute description of the operator.
*/
TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type);
/*!
* \brief Build the computation schedule.
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return The computation schedule.
*/
TVM_DLL te::Schedule Schedule(const Attrs& attrs,
const Array<te::Tensor>& outs,
const Target& target);
TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
};
/*!
* \brief Specialized implementations for operators under certain conditions.
*/
class OpSpecializationNode : public Object {
public:
/*! \brief List of implementations. */
Array<OpImplementation> implementations;
/*! \brief Condition to enable the specialization.
* It can be undefined to represent the generic case. */
te::SpecializedCondition condition;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("implementations", &implementations);
}
static constexpr const char* _type_key = "relay.OpSpecialization";
TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, Object);
};
/*!
* \brief Operator specialization class.
*/
class OpSpecialization : public ObjectRef {
public:
/*!
* \brief Add an implementation.
* \param fcompute Compute function
* \param fschedule Schedule function
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
std::string name, int plevel);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
};
/*!
* \brief Operator strategy to choose implementation.
*/
class OpStrategyNode : public Object {
public:
/*! \brief List of operator specializations. */
Array<OpSpecialization> specializations;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("specializations", &specializations);
}
static constexpr const char* _type_key = "relay.OpStrategy";
TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, Object);
};
/*!
* \brief Operator strategy class.
*/
class OpStrategy : public ObjectRef {
public:
/*!
* \brief Add an implementation.
* \param fcompute Compute function
* \param fschedule Schedule function
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
std::string name, int plevel);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_STRATEGY_H_
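
To ground the data structures above, here is a minimal sketch of how a target-specific strategy is written against this interface from Python. It mirrors the pattern of the python/tvm/relay/op/strategy modules this PR introduces; the helper names (conv2d_strategy, wrap_compute_conv2d, wrap_topi_schedule) and the exact topi symbols are illustrative rather than verbatim from the diff:

import topi
from tvm.relay.op import OpStrategy
from tvm.relay.op.strategy.generic import (conv2d_strategy,
                                           wrap_compute_conv2d,
                                           wrap_topi_schedule)

@conv2d_strategy.register("cpu")
def conv2d_strategy_cpu(attrs, inputs, out_type, target):
    # One OpStrategy holds one or more implementations; each pairs an
    # FTVMCompute with an FTVMSchedule, and plevel breaks ties among
    # implementations that are valid for the same workload.
    strategy = OpStrategy()
    strategy.add_implementation(
        wrap_compute_conv2d(topi.x86.conv2d_nchw),
        wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
        name="conv2d_nchw.x86",
        plevel=10)
    return strategy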
......@@ -28,6 +28,7 @@
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/te/tensor_intrin.h>
#include <tvm/support/with.h>
#include <string>
#include <unordered_map>
......@@ -742,6 +743,53 @@ class SingletonNode : public IterVarRelationNode {
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
};
/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
public:
/*!
* \brief List of conditions in conjunctive normal form (CNF).
* Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
* where n, m are tvm::Var that represents a dimension in the tensor shape.
*/
Array<PrimExpr> clauses;
void VisitAttrs(AttrVisitor* v) {
v->Visit("clauses", &clauses);
}
static constexpr const char* _type_key = "SpecializedCondition";
TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
};
/*!
* \brief Specialized condition to enable op specialization
*/
class SpecializedCondition : public ObjectRef {
public:
/*!
* \brief construct from conditions
* \param conditions The clauses in the specialized condition.
*/
TVM_DLL SpecializedCondition(Array<PrimExpr> conditions); // NOLINT(*)
/*!
* \brief Get the current specialized condition.
* \return the current specialized condition.
*/
TVM_DLL static SpecializedCondition Current();
TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
class Internal;
private:
// enable with syntax.
friend class Internal;
friend class With<SpecializedCondition>;
/*! \brief Push a new specialized condition onto the thread local stack. */
TVM_DLL void EnterWithScope();
/*! \brief Pop a specialized condition off the thread local context stack. */
TVM_DLL void ExitWithScope();
};
// implementations
inline const StageNode* Stage::operator->() const {
......@@ -765,6 +813,7 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(get());
}
} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_H_
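
As a rough Python illustration of the With<SpecializedCondition> scope above (the binding name tvm.schedule.SpecializedCondition is an assumption for this revision of the codebase; it moved under tvm.te later):

import tvm

m = tvm.var("m")  # a dimension in a tensor shape
# The clauses below become SpecializedConditionNode::clauses (CNF terms).
with tvm.schedule.SpecializedCondition([m > 16, m % 8 == 0]):
    # Strategy implementations registered in this scope carry the
    # condition and are only considered when every clause holds.
    pass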
......@@ -41,8 +41,8 @@ from . import tophub
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
-from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
-register_topi_compute, register_topi_schedule, \
+from .task import get_config, create, ConfigSpace, ConfigEntity, \
+register_topi_compute, register_topi_schedule, register_customized_task, \
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
ApplyGraphBest as apply_graph_best
from .env import GLOBAL_SCOPE
......@@ -125,7 +125,7 @@ class RedisDatabase(Database):
current = self.get(measure_str_key(inp))
if current is not None:
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
-results = [rec[1] for rec in records]
+results = [rec[1] for rec in records if rec is not None]
if get_all:
return results
return max(results, key=lambda result: result.timestamp)
......@@ -167,9 +167,12 @@ class RedisDatabase(Database):
current = self.get(key)
try:
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
records = [rec for rec in records if rec is not None]
except TypeError: # got a badly formatted/old format record
continue
if not records:
continue
inps, results = zip(*records)
inp = inps[0]
if not func(inp, results):
......
......@@ -153,7 +153,10 @@ def get_flatten_name(fea):
from .record import decode
# flatten line to feature
line = fea
-inp, _ = decode(line)
+ret = decode(line)
+if ret is None:
+raise ValueError("Unsupported AutoTVM log format")
+inp, _ = ret
target = _target.create(inp.target)
with target:
s, args = inp.template.instantiate(inp.config)
......
......@@ -25,7 +25,6 @@ import topi
import tvm
from tvm import autotvm, relay
from tvm.autotvm.task import get_config
-from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args
from tvm.autotvm.record import encode, load_from_file
from tvm.autotvm.measure import MeasureResult, MeasureInput
......@@ -35,18 +34,16 @@ from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_i
from ._base import INVALID_LAYOUT_TIME
-# Setup topi_op_name -> layout function
-# NOTE: To add more ops, change the following dictionary.
-OP2LAYOUT = {
-"topi_nn_conv2d": topi.nn.conv2d_infer_layout,
-"topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
-}
def get_infer_layout(task_name):
if task_name.startswith("conv2d"):
return topi.nn.conv2d_infer_layout
if task_name.startswith("depthwise_conv2d"):
return topi.nn.depthwise_conv2d_infer_layout
raise ValueError("Cannot find infer layout for task %s" % task_name)
-@autotvm.template
+@autotvm.register_customized_task("layout_transform")
def layout_transform(*args):
"""Autotvm layout transform template."""
-args = deserialize_args(args)
cfg = get_config()
cfg.add_flop(-1)
data = args[0]
......@@ -82,7 +79,7 @@ class BaseGraphTuner(object):
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
-target_ops : List of str
+target_ops : List of relay.op.Op
Target tuning operators.
target : str or tvm.target
......@@ -104,7 +101,7 @@ class BaseGraphTuner(object):
self._layout_transform_perf_records = {}
self._layout_transform_interlayer_cost = {}
self._input_shapes = input_shapes
-self._target_ops = [op.__name__ for op in target_ops]
+self._target_ops = target_ops
self._name = name
self._max_sch_num = max_sch_num
......@@ -179,7 +176,7 @@ class BaseGraphTuner(object):
dtype = first_tensor[-1]
new_shape = tuple([val.value for val in node_entry["types"][0].shape])
actual_workload = (input_workload[0],) + \
-((new_shape + (dtype,)),) + input_workload[2:]
+(("TENSOR", new_shape, dtype),) + input_workload[2:]
node_entry["workloads"].append(actual_workload)
if "record_candidates" not in node_entry:
node_entry["record_candidates"] = input_node["record_candidates"]
......@@ -212,7 +209,7 @@ class BaseGraphTuner(object):
node_entry["record_candidates"] = cache_dict[workload]
continue
record_candidates = []
infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
layout_tracking_dict = {}
for record in cfg_dict[workload]:
in_measure, out_measure = record
......@@ -264,7 +261,7 @@ class BaseGraphTuner(object):
if node_entry["op"] in self._target_ops:
o_idx = key
o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
o_wkl = node_entry["workloads"][0]
i_topi_op = in_node_entry["topi_op"][0]
i_wkl = in_node_entry["workloads"][0]
......@@ -273,14 +270,14 @@ class BaseGraphTuner(object):
pivot += 1
i_topi_op = in_node_entry["topi_op"][pivot]
i_wkl = in_node_entry["workloads"][pivot]
-i_infer_layout_func = OP2LAYOUT[i_topi_op]
+i_infer_layout_func = get_infer_layout(i_topi_op)
else:
o_idx = target_input_idx
if i <= target_input_pos:
continue
o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
o_wkl = node_entry["workloads"][target_input_pos]
i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]]
i_infer_layout_func = get_infer_layout(node_entry["topi_op"][i])
i_wkl = node_entry["workloads"][i]
if (i_idx, o_idx) in pair_tracker:
......@@ -314,9 +311,8 @@ class BaseGraphTuner(object):
to_sch_idx, args):
"""Create dictionary containing matrix format of layout transformation
between nodes."""
-sargs = serialize_args(args)
in_layout, out_layout = args[1], args[2]
-ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(sargs)
+ltf_workload = autotvm.task.args_to_workload(args, 'layout_transform')
idx_pair_key = (from_node_idx, to_node_idx)
if in_layout == out_layout:
......@@ -449,9 +445,8 @@ class BaseGraphTuner(object):
measure_option = autotvm.measure_option(builder=builder, runner=runner)
for args in args_list:
data, in_layout, out_layout = args
-args = serialize_args(args)
-ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
+ltf_workload = autotvm.task.args_to_workload(args, 'layout_transform')
if ltf_workload in self._layout_transform_perf_records:
continue
if infer_layout:
......@@ -478,9 +473,8 @@ class BaseGraphTuner(object):
continue
records = []
-task = autotvm.task.create(layout_transform, args=args, target=self._target,
+task = autotvm.task.create("layout_transform", args=args, target=self._target,
target_host=target_host)
-task.workload = ltf_workload
tuner = autotvm.tuner.GridSearchTuner(task)
tuner.tune(n_trial=1, measure_option=measure_option,
callbacks=[_log_to_list(records)])
......
......@@ -18,8 +18,6 @@
"""API for graph traversing."""
import threading
-import topi
import tvm
from tvm import relay, autotvm
from tvm.relay import transform
......@@ -30,13 +28,6 @@ from tvm.autotvm.task import TaskExtractEnv
from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
-# Setup relay op base name -> topi compute functions
-# NOTE: To add more ops, change the following dictionary.
-OP2COMPUTE = {
-"conv2d" : [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
-}
def expr2graph(expr, target_ops, node_dict, node_list):
"""Convert relay expr to graph data structure
and fetch workloads of target operators.
......@@ -46,8 +37,8 @@ def expr2graph(expr, target_ops, node_dict, node_list):
expr : tvm.relay.Expr.Function
Input relay function expression.
-target_ops: List of str
-List of target relay base op name
+target_ops: List of relay.op.Op
+List of target relay ops
node_dict : dictionary from tvm.relay.Expr to int
Dictionary to record node index
......@@ -58,14 +49,11 @@ def expr2graph(expr, target_ops, node_dict, node_list):
{"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
"name": str, "workloads": [tuple], "topi_op": [function]}
"""
+# TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
+# that # autotvm tasks == # ops. But this won't be true after having relay op
+# strategy. We need to find a solution to fix this.
env = TaskExtractEnv.get(allow_duplicate=True)
-topi_funcs = []
-for op_name in target_ops:
-if op_name not in OP2COMPUTE:
-raise RuntimeError("Not supported relay op in graph tuner: %s"
-% op_name)
-topi_funcs += OP2COMPUTE[op_name]
-env.reset(topi_funcs)
+env.reset(target_ops)
# pylint: disable=not-context-manager
with env:
_expr2graph_impl(expr, target_ops, node_dict, node_list)
......@@ -75,8 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args,
target="llvm",
-target_host=None,
-template_key='direct')
+target_host=None)
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
......@@ -98,11 +85,11 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
return
node_index = len(node_list)
node_entry = {"node": node, "inputs": [], "types": [],
"op": "null", "name": None}
"op": None, "name": None}
if isinstance(node, Call):
op_name = node.op.name.split(".")[-1]
node_entry["op"] = op_name
op = node.op
node_entry["op"] = node.op
for arg in node.args:
in_node_idx = node_dict[arg]
if isinstance(arg, (Tuple, TupleGetItem)):
......@@ -118,12 +105,12 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
node_entry["types"].append(tupe_type)
else:
raise RuntimeError("Unsupported output type %s in operator %s"
-% (type(out_type), op_name))
+% (type(out_type), op.name))
# Utilize tracing target to fetch workload with topo-order.
# Since we only need workload, dummy target can be used to
# create task.
-if op_name in target_ops:
+if op in target_ops:
params = []
for i, input_idx in enumerate(node_entry["inputs"]):
input_node_entry = node_list[input_idx[0]]
......@@ -133,7 +120,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
"operators with input node of type "
"relay.expr.Var/Constant/Call. Now "
"find a target op %s with input type %s"
% (op_name, str(type(input_node_entry["node"]))))
% (op, str(type(input_node_entry["node"]))))
free_var = relay.Var("var_%d" % i, input_type)
params.append(free_var)
call = relay.Call(node.op, params, node.attrs)
......@@ -155,11 +142,9 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
_expr2graph_impl(node, target_ops, node_dict, node_list)
return
elif isinstance(node, TupleGetItem):
node_entry["op"] = "TupleGetItem"
in_node_idx = node_dict[node.tuple_value]
node_entry["inputs"].append([in_node_idx, node.index, 0])
elif isinstance(node, Tuple):
node_entry["op"] = "Tuple"
for tuple_item in node:
in_node_idx = node_dict[tuple_item]
if isinstance(tuple_item, TupleGetItem):
......
......@@ -47,7 +47,7 @@ def has_multiple_inputs(node_list, node_idx, input_names):
in_idx = in_idx[0]
in_node = node_list[in_idx]
# Exclude parameter nodes
if in_node["op"] != "null" or \
if in_node["op"] is not None or \
("name" in in_node and in_node["name"] in input_names):
num_inputs += 1
return num_inputs > 1
......@@ -72,9 +72,10 @@ def is_boundary_node(node_entry, input_names):
whether node is a boundary node.
"""
# Operators dependent on original layouts.
_LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
"multibox_prior", "multibox_transform_loc", "where",
"non_max_suppression", "strided_slice"]
_LAYOUT_FIXED_OP = [relay.op.get(name) for name in (
"nn.batch_flatten", "transpose", "reshape", "vision.multibox_prior",
"vision.multibox_transform_loc", "where", "vision.non_max_suppression",
"strided_slice")]
out = node_entry["op"] in _LAYOUT_FIXED_OP or \
("name" in node_entry and node_entry["name"] in input_names)
......@@ -95,9 +96,7 @@ def is_skipped_node(node_entry):
whether node is skipped.
"""
# Operators not counted in graph tuner.
_SKIPPED_OP = ["Tuple"]
return node_entry["op"] in _SKIPPED_OP
return isinstance(node_entry["node"], relay.Tuple)
def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
......
......@@ -28,14 +28,16 @@ import time
import os
import itertools
from collections import OrderedDict
import numpy as np
from .. import build, lower, target as _target
from .. import __version__
from . import task
from .task import ConfigEntity, ApplyHistoryBest
from .measure import MeasureInput, MeasureResult
-AUTOTVM_LOG_VERSION = 0.1
+AUTOTVM_LOG_VERSION = 0.2
+_old_version_warning = True
logger = logging.getLogger('autotvm')
try: # convert unicode to str for python2
......@@ -88,27 +90,30 @@ def encode(inp, result, protocol='json'):
if protocol == 'json':
json_dict = {
"i": (str(inp.target),
inp.task.name, inp.task.args, inp.task.kwargs,
inp.task.workload,
inp.config.to_json_dict()),
"input": (str(inp.target),
inp.task.name, inp.task.args, inp.task.kwargs),
"config": inp.config.to_json_dict(),
"result": (result.costs if result.error_no == 0 else (1e9,),
result.error_no,
result.all_cost,
result.timestamp),
"r": (result.costs if result.error_no == 0 else (1e9,),
result.error_no,
result.all_cost,
result.timestamp),
"version": AUTOTVM_LOG_VERSION,
"v": AUTOTVM_LOG_VERSION
"tvm_version": __version__
}
return json.dumps(json_dict)
if protocol == 'pickle':
row = (str(inp.target),
str(base64.b64encode(pickle.dumps([inp.task.name,
inp.task.args,
-inp.task.kwargs,
-inp.task.workload])).decode()),
+inp.task.kwargs])).decode()),
str(base64.b64encode(pickle.dumps(inp.config)).decode()),
-str(base64.b64encode(pickle.dumps(tuple(result))).decode()))
+str(base64.b64encode(pickle.dumps(tuple(result))).decode()),
+str(AUTOTVM_LOG_VERSION),
+str(__version__))
return '\t'.join(row)
raise RuntimeError("Invalid log protocol: " + protocol)
......@@ -119,20 +124,29 @@ def decode(row, protocol='json'):
Parameters
----------
-row: str
+row : str
a row in the logger file
-protocol: str
+protocol : str
log protocol, json or pickle
Returns
-------
-input: autotvm.tuner.MeasureInput
-result: autotvm.tuner.MeasureResult
+ret : tuple(autotvm.tuner.MeasureInput, autotvm.tuner.MeasureResult), or None
+The tuple of input and result, or None if input uses old version log format.
"""
# pylint: disable=unused-variable
global _old_version_warning
if protocol == 'json':
row = json.loads(row)
-tgt, task_name, task_args, task_kwargs, workload, config = row['i']
+if 'v' in row and row['v'] == 0.1:
+if _old_version_warning:
+logger.warning("AutoTVM log version 0.1 is no longer supported.")
+_old_version_warning = False
+return None
+tgt, task_name, task_args, task_kwargs = row["input"]
tgt = _target.create(str(tgt))
def clean_json_to_python(x):
......@@ -148,22 +162,27 @@ def decode(row, protocol='json'):
return x
tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
-tsk.workload = clean_json_to_python(workload)
-config = ConfigEntity.from_json_dict(config)
+config = ConfigEntity.from_json_dict(row["config"])
inp = MeasureInput(tgt, tsk, config)
-result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["r"]])
+result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["result"]])
+config.cost = np.mean(result.costs)
return inp, result
if protocol == 'pickle':
items = row.split("\t")
+if len(items) == 4:
+if _old_version_warning:
+logger.warning("AutoTVM log version 0.1 is no longer supported.")
+_old_version_warning = False
+return None
tgt = _target.create(items[0])
task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
config = pickle.loads(base64.b64decode(items[2].encode()))
-result = pickle.loads(base64.b64decode(items[3].encode()))
+result = MeasureResult(*pickle.loads(base64.b64decode(items[3].encode())))
+config.cost = np.mean(result.costs)
tsk = task.Task(task_tuple[0], task_tuple[1])
-tsk.workload = task_tuple[3]
-return MeasureInput(tgt, tsk, config), MeasureResult(*result)
+return MeasureInput(tgt, tsk, config), result
raise RuntimeError("Invalid log protocol: " + protocol)
......@@ -183,7 +202,10 @@ def load_from_file(filename):
"""
for row in open(filename):
if row and not row.startswith('#'):
-inp, res = decode(row)
+ret = decode(row)
+if ret is None:
+continue
+inp, res = ret
# Avoid loading the record with an empty config. The TOPI schedule with no entities
# will result in an empty entity map (e.g., depthwise_conv2d_nchw on x86).
# Using an empty config will cause problems when applying alter op like NCHW to NCHWc.
......@@ -208,7 +230,7 @@ def split_workload(in_file, clean=True):
logger.info("start converting...")
pool = multiprocessing.Pool()
-lines = pool.map(decode, lines)
+lines = [rec for rec in pool.map(decode, lines) if rec is not None]
logger.info("map done %.2f", time.time() - tic)
wkl_dict = OrderedDict()
......
......@@ -22,12 +22,13 @@ This module defines the task data structure, as well as a collection(zoo)
of typical tasks of interest.
"""
-from .task import Task, create, register, template, get_config, args_to_workload
+from .task import Task, create, get_config, args_to_workload, \
+register_customized_task
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
-from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
+from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
FallbackContext, clear_fallback_cache, ApplyGraphBest
from .topi_integration import register_topi_compute, register_topi_schedule, \
-TaskExtractEnv
+TaskExtractEnv, get_workload
from .relay_integration import extract_from_program, extract_from_multiple_program
......@@ -33,9 +33,6 @@ from __future__ import absolute_import as _abs
import logging
import numpy as np
-from decorator import decorate
-from tvm import target as _target
from .space import FallbackConfigEntity
......@@ -152,79 +149,6 @@ class DispatchContext(object):
DispatchContext.current = self._old_ctx
-def dispatcher(fworkload):
-"""Wrap a workload dispatcher function.
-Parameters
-----------
-fworkload : function
-The workload extraction function from arguments.
-Returns
--------
-fdispatcher : function
-A wrapped dispatcher function, which will
-dispatch based on DispatchContext and
-the current workload.
-"""
-dispatch_dict = {}
-func_name = fworkload.__name__
-def register(key, func=None, override=False):
-"""Register template function.
-Parameters
-----------
-key : str or List of str
-The template key to identify the template
-under this dispatcher.
-func : function
-The function to be registered.
-The first argument of the function is always
-cfg returned by DispatchContext,
-the rest arguments are the same as the fworkload.
-override : bool
-Whether override existing registration.
-Returns
--------
-The register function if necessary.
-"""
-if isinstance(key, str):
-key = [key]
-def _do_reg(myf):
-for x in key:
-if x in dispatch_dict and not override:
-raise ValueError(
-"Key %s is already registered for %s" % (x, func_name))
-dispatch_dict[x] = myf
-return myf
-if func:
-return _do_reg(func)
-return _do_reg
-def dispatch_func(func, *args, **kwargs):
-"""The wrapped dispatch function"""
-tgt = _target.Target.current()
-workload = func(*args, **kwargs)
-cfg = DispatchContext.current.query(tgt, workload)
-if cfg.is_fallback and not cfg.template_key:
-# first try 'direct' template
-if 'direct' in dispatch_dict:
-return dispatch_dict['direct'](cfg, *args, **kwargs)
-# otherwise pick a random template
-for v in dispatch_dict.values():
-return v(cfg, *args, **kwargs)
-else:
-return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
-fdecorate = decorate(fworkload, dispatch_func)
-fdecorate.register = register
-return fdecorate
class ApplyConfig(DispatchContext):
"""Apply a deterministic config entity for all queries.
......@@ -334,7 +258,8 @@ class ApplyHistoryBest(DispatchContext):
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_model:
-return self.best_by_model[key][0].config
+inp, _ = self.best_by_model[key]
+return inp.config
# then try matching by target key
for k in target.keys:
......@@ -342,13 +267,16 @@ class ApplyHistoryBest(DispatchContext):
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_targetkey:
-return self.best_by_targetkey[key][0].config
+inp, _ = self.best_by_targetkey[key]
+return inp.config
return None
def update(self, target, workload, cfg):
model = target.model
key = (model, workload)
# assume user provided config is the best
cfg.cost = 0
self._best_user_defined[key] = cfg
for k in target.keys:
......@@ -481,8 +409,12 @@ class ApplyGraphBest(DispatchContext):
"""
if self._counter < len(self._records):
cfg = self._records[self._counter][0].config
+wkl = self._records[self._counter][0].task.workload
+if workload is not None:
+assert wkl == workload
self._counter += 1
-self.update(target, workload, cfg)
+self.update(target, wkl, cfg)
+cfg.workload = wkl
return cfg
key = (str(target), workload)
if key not in self._global_cfg_dict:
......
......@@ -21,10 +21,9 @@ Decorator and utilities for the integration with TOPI and Relay
"""
import threading
-import warnings
import logging
import tvm
from .task import create
from .topi_integration import TaskExtractEnv
......@@ -55,8 +54,7 @@ def _lower(mod,
compiler.lower(mod, target=target)
def extract_from_program(mod, params, ops, target, target_host=None,
template_keys=None):
def extract_from_program(mod, params, target, target_host=None, ops=None):
""" Extract tuning tasks from a relay program.
This function is the single program version of extract_from_multiple_program.
......@@ -67,27 +65,22 @@ def extract_from_program(mod, params, ops, target, target_host=None,
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
-ops: List of relay op
-List of relay ops to be tuned
target: tvm.target.Target
The compilation target
target_host: tvm.target.Target
The host compilation target
-template_keys: dict of topi op to str
-The tuning template keys map for schedules, default to None.
-Example: {topi.nn.conv2d: 'direct'}
+ops: List[relay.op.Op] or None
+List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
-return extract_from_multiple_program([mod], [params], ops, target, target_host,
-template_keys)
+return extract_from_multiple_program([mod], [params], target, target_host, ops)
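
A sketch of the updated extraction entry point, assuming a tiny hand-built conv2d module (tvm.IRModule and the relay conv2d signature as of this era of the codebase):

import tvm
from tvm import autotvm, relay

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(16, 3, 3, 3))
conv = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Relay ops replace the old topi-function/template_keys arguments;
# ops=None would extract every tunable op instead.
tasks = autotvm.task.extract_from_program(
    mod, params={}, target="llvm", ops=(relay.op.get("nn.conv2d"),))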
-def extract_from_multiple_program(mods, params, ops, target, target_host=None,
-template_keys=None):
+def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
""" Extract tuning tasks from multiple relay programs.
This function collects tuning tasks by building a list of programs
......@@ -99,15 +92,12 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
-ops: List of relay op
-List of relay ops to be tuned
target: tvm.target.Target
The compilation target
target_host: tvm.target.Target
The host compilation target
-template_keys: dict of topi op to str
-The tuning template keys map for schedules, default to None.
-Example: {topi.nn.conv2d: 'direct'}
+ops: List[relay.op.Op] or None
+List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
Returns
-------
......@@ -115,36 +105,13 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
collected tasks
"""
# pylint: disable=import-outside-toplevel
import tvm.relay.op
from tvm import relay
import topi
env = TaskExtractEnv.get()
-# NOTE: To add more ops, you only need to change the following lists
-# relay op -> topi compute
-OP2TOPI = {
-tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
-topi.nn.group_conv2d_nchw,
-topi.nn.conv2d_NCHWc,
-topi.nn.conv2d_NCHWc_int8],
-tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
-tvm.relay.op.nn.dense: [topi.nn.dense],
-tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
-tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
-tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
-tvm.relay.op.nn.conv3d: [topi.nn.conv3d],
-}
-topi_funcs = []
-for op_name in ops:
-if op_name in OP2TOPI:
-topi_funcs.extend(OP2TOPI[op_name])
-else:
-warnings.warn("Op %s is not tunable, ignored" % op_name)
# run compiler to collect all TOPI calls during compilation
-env.reset(topi_funcs)
+env.reset(ops)
with env:
# disable logger temporarily
old_state = logger.disabled
......@@ -164,24 +131,12 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
logger.disabled = old_state
-# convert *topi op to template key* map to *task name to template key* map
-task_name_to_keys = {}
-if template_keys is not None:
-for op in template_keys.keys():
-if op in env.topi_to_task:
-task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
-else:
-logger.warning("Invalid template key, fallback to direct")
-task_name_to_keys[env.topi_to_task[op]] = 'direct'
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
try:
-key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
tsk = create(task_name, args,
-target=target, target_host=target_host,
-template_key=key)
+target=target, target_host=target_host)
tasks.append(tsk)
except topi.InvalidShapeError:
logger.warning("Invalid shape during AutoTVM task creation")
......
......@@ -613,9 +613,9 @@ class ConfigSpace(object):
self._entity_map = OrderedDict() # name -> entity
self._constraints = []
self.errors = []
-self.template_key = None
self.code_hash = None
self.flop = 0
+self.cost = None
self.is_fallback = False
@staticmethod
......@@ -796,7 +796,7 @@ class ConfigSpace(object):
for name, space in self.space_map.items():
entities[name] = space[t % len(space)]
t //= len(space)
-ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints)
+ret = ConfigEntity(index, self.code_hash, entities, self._constraints)
return ret
def __iter__(self):
......@@ -836,17 +836,14 @@ class ConfigEntity(ConfigSpace):
index of this config in space
code_hash: str
hash of schedule code
-template_key : str
-The specific template key
entity_map: dict
map name to transform entity
constraints : list
List of constraints
"""
-def __init__(self, index, code_hash, template_key, entity_map, constraints):
+def __init__(self, index, code_hash, entity_map, constraints):
super(ConfigEntity, self).__init__()
self.index = index
-self.template_key = template_key
self._collect = False
self._entity_map = entity_map
self._space_map = None
......@@ -896,9 +893,8 @@ class ConfigEntity(ConfigSpace):
a json serializable dictionary
"""
ret = {}
-ret['i'] = int(self.index)
-ret['t'] = self.template_key
-ret['c'] = self.code_hash
+ret['index'] = int(self.index)
+ret['code_hash'] = self.code_hash
entity_map = []
for k, v in self._entity_map.items():
if isinstance(v, SplitEntity):
......@@ -911,7 +907,7 @@ class ConfigEntity(ConfigSpace):
entity_map.append((k, 'ot', v.val))
else:
raise RuntimeError("Invalid entity instance: " + v)
-ret['e'] = entity_map
+ret['entity'] = entity_map
return ret
@staticmethod
......@@ -930,13 +926,12 @@ class ConfigEntity(ConfigSpace):
The corresponding config object
"""
index = json_dict["i"]
code_hash = json_dict["c"]
template_key = json_dict["t"]
index = json_dict["index"]
code_hash = json_dict["code_hash"]
constraints = []
entity_map = OrderedDict()
for item in json_dict["e"]:
for item in json_dict["entity"]:
key, knob_type, knob_args = item
if knob_type == 'sp':
entity = SplitEntity(knob_args)
......@@ -950,11 +945,10 @@ class ConfigEntity(ConfigSpace):
raise RuntimeError("Invalid config knob type: " + knob_type)
entity_map[str(key)] = entity
-return ConfigEntity(index, code_hash, template_key, entity_map, constraints)
+return ConfigEntity(index, code_hash, entity_map, constraints)
def __repr__(self):
return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
self.code_hash, self.index)
return "%s,%s,%d" % (str(self._entity_map)[12:-1], self.code_hash, self.index)
class FallbackConfigEntity(ConfigSpace):
......@@ -1068,4 +1062,4 @@ class FallbackConfigEntity(ConfigSpace):
self._entity_map[name] = entity
def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
return "%s,%s" % (str(self._entity_map)[12:-1], self.code_hash)
......@@ -46,16 +46,16 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub
# the version of each package
PACKAGE_VERSION = {
'arm_cpu': "v0.04",
'llvm': "v0.03",
'arm_cpu': "v0.06",
'llvm': "v0.04",
'cuda': "v0.06",
'rocm': "v0.03",
'opencl': "v0.03",
'mali': "v0.05",
'intel_graphics': "v0.01",
'cuda': "v0.08",
'rocm': "v0.04",
'opencl': "v0.04",
'mali': "v0.06",
'intel_graphics': "v0.02",
'vta': "v0.06",
'vta': "v0.08",
}
logger = logging.getLogger('autotvm')
......@@ -189,7 +189,7 @@ def download_package(tophub_location, package_name):
# global cache for load_reference_log
REFERENCE_LOG_CACHE = {}
-def load_reference_log(backend, model, workload_name, template_key):
+def load_reference_log(backend, model, workload_name):
""" Load reference log from TopHub to support fallback in template.
Template will use these reference logs to choose fallback config.
......@@ -201,8 +201,6 @@ def load_reference_log(backend, model, workload_name, template_key):
The name of the device model
workload_name: str
The name of the workload. (The first item in the workload tuple)
-template_key: str
-The template key
"""
backend = _alias(backend)
......@@ -211,7 +209,7 @@ def load_reference_log(backend, model, workload_name, template_key):
filename = os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)
global REFERENCE_LOG_CACHE
-key = (backend, model, workload_name, template_key)
+key = (backend, model, workload_name)
if key not in REFERENCE_LOG_CACHE:
tmp = []
......@@ -233,8 +231,7 @@ def load_reference_log(backend, model, workload_name, template_key):
model = max(counts.items(), key=lambda k: k[1])[0]
for inp, res in load_from_file(filename):
-if (model == inp.target.model and inp.task.workload[0] == workload_name and
-inp.config.template_key == template_key):
+if model == inp.target.model and inp.task.workload[0] == workload_name:
tmp.append((inp, res))
REFERENCE_LOG_CACHE[key] = tmp
......
......@@ -219,8 +219,7 @@ class XGBoostCostModel(CostModel):
# filter data, only pick the data with a same task
data = []
for inp, res in records:
-if inp.task.name == self.task.name and \
-inp.config.template_key == self.task.config_space.template_key:
+if inp.task.name == self.task.name:
data.append((inp, res))
logger.debug("XGB load %d entries from history log file", len(data))
......
......@@ -14,18 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=len-as-condition,no-else-return,invalid-name
"""Backend code generation engine."""
from __future__ import absolute_import
import logging
import numpy as np
import tvm
from ..base import register_relay_node, Object
from ... import target as _target
from ... import autotvm
from .. import expr as _expr
from .. import op as _op
from .. import ty as _ty
from . import _backend
logger = logging.getLogger('compile_engine')
+@register_relay_node
+class LoweredOutput(Object):
+"""Lowered output"""
+def __init__(self, outputs, implement):
+self.__init_handle_by_constructor__(
+_backend._make_LoweredOutput, outputs, implement)
@register_relay_node
class CachedFunc(Object):
"""Low-level tensor function to back a relay primitive function.
"""
......@@ -63,6 +75,191 @@ def _get_cache_key(source_func, target):
return source_func
def get_shape(shape):
"""Convert the shape to correct dtype and vars."""
ret = []
for dim in shape:
if isinstance(dim, tvm.expr.IntImm):
val = int(dim)
assert val <= np.iinfo(np.int32).max
ret.append(tvm.expr.IntImm("int32", val))
elif isinstance(dim, tvm.expr.Any):
ret.append(tvm.var("any_dim", "int32"))
else:
ret.append(dim)
return ret
def get_valid_implementations(op, attrs, inputs, out_type, target):
"""Get all valid implementations from the op strategy.
Note that this function doesn't support ops with symbolic input shapes.
Parameters
----------
op : relay.op.Op
Relay operator.
attrs : object
The op attribute.
inputs : List[tvm.Tensor]
Input tensors to the op.
out_type : relay.Type
The output type.
target : tvm.target.Target
The target to compile the op.
Returns
-------
ret : List[relay.op.OpImplementation]
The list of all valid op implementations.
"""
fstrategy = op.get_attr("FTVMStrategy")
assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
with target:
strategy = fstrategy(attrs, inputs, out_type, target)
analyzer = tvm.arith.Analyzer()
ret = []
for spec in strategy.specializations:
if spec.condition:
# check if all the clauses in the specialized condition are true
flag = True
for clause in spec.condition.clauses:
clause = analyzer.canonical_simplify(clause)
if isinstance(clause, tvm.expr.IntImm) and clause.value:
continue
flag = False
break
if flag:
for impl in spec.implementations:
ret.append(impl)
else:
for impl in spec.implementations:
ret.append(impl)
return ret
def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True):
"""Select the best implementation from the op strategy.
If use_autotvm is True, it'll first try to find the best implementation
based on AutoTVM profile results. If no AutoTVM profile result is found,
it'll choose the implementation with highest plevel.
If use_autotvm is False, it'll directly choose the implementation with
highest plevel.
Note that this function doesn't support ops with symbolic input shapes.
Parameters
----------
op : relay.op.Op
Relay operator.
attrs : object
The op attribute.
inputs : List[tvm.Tensor]
Input tensors to the op.
out_type : relay.Type
The output type.
target : tvm.target.Target
The target to compile the op.
use_autotvm : bool
Whether to query AutoTVM to pick the best implementation.
Returns
-------
ret : tuple(relay.op.OpImplementation, List[tvm.Tensor])
The best op implementation and the corresponding output tensors.
"""
all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
best_plevel_impl = None
for impl in all_impls:
if best_plevel_impl is None or impl.plevel > best_plevel_impl.plevel:
best_plevel_impl = impl
if not use_autotvm:
outs = best_plevel_impl.compute(attrs, inputs, out_type)
return best_plevel_impl, outs
outputs = {}
best_autotvm_impl = None
best_cfg = None
dispatch_ctx = autotvm.task.DispatchContext.current
for impl in all_impls:
outs = impl.compute(attrs, inputs, out_type)
outputs[impl] = outs
workload = autotvm.task.get_workload(outs)
if workload is None:
continue
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
# It's a fallback config
continue
if best_cfg is None or best_cfg.cost > cfg.cost:
best_autotvm_impl = impl
best_cfg = cfg
if best_autotvm_impl:
return best_autotvm_impl, outputs[best_autotvm_impl]
return best_plevel_impl, outputs[best_plevel_impl]
@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target):
"""Lower the call expression to op implementation and tensor outputs."""
assert isinstance(call.op, _op.Op)
op = call.op
# Prepare the call_node->checked_type(). For the call node inputs, we ensure that
# the shape is Int32. The following code ensures the same for the output as well.
# TODO(@icemelon9): Support recursive tuple
ret_type = call.checked_type
if isinstance(ret_type, _ty.TensorType):
ret_type = _ty.TensorType(get_shape(ret_type.shape), ret_type.dtype)
elif isinstance(ret_type, _ty.TupleType):
new_fields = []
for field in ret_type.fields:
if isinstance(field, _ty.TensorType):
new_fields.append(_ty.TensorType(get_shape(field.shape), field.dtype))
else:
new_fields.append(field)
ret_type = _ty.TupleType(new_fields)
is_dyn = _ty.type_has_any(call.checked_type)
for arg in call.args:
is_dyn = is_dyn or _ty.type_has_any(arg.checked_type)
# check if in the AutoTVM tracing mode, and disable if op is not in wanted list
env = autotvm.task.TaskExtractEnv.current
reenable_tracing = False
if env is not None and env.tracing:
if env.wanted_relay_ops is not None and op not in env.wanted_relay_ops:
env.tracing = False
reenable_tracing = True
if not is_dyn:
best_impl, outputs = select_implementation(
op, call.attrs, inputs, ret_type, target)
logger.info("Use implementation %s for op %s", best_impl.name, op.name)
else:
# TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes.
# Currently, we just use the implementation with highest plevel
best_impl, outputs = select_implementation(
op, call.attrs, inputs, ret_type, target, use_autotvm=False)
# re-enable AutoTVM tracing
if reenable_tracing:
env.tracing = True
return LoweredOutput(outputs, best_impl)
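
End to end, lower_call consults select_implementation for every primitive call during compilation and logs the winner. A hedged sketch of observing this (API spellings such as tvm.IRModule and relay.build_config match this era of TVM and may differ in later versions):

import logging
import tvm
from tvm import relay

logging.getLogger("compile_engine").setLevel(logging.INFO)

data = relay.var("data", shape=(1, 16, 32, 32))
weight = relay.var("weight", shape=(32, 16, 3, 3))
conv = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="llvm")
# expect a log line like: "Use implementation conv2d_nchw.x86 for op nn.conv2d"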
@register_relay_node
class CompileEngine(Object):
"""CompileEngine to get lowered code.
......
......@@ -131,22 +131,22 @@ class ExprVisitor(ExprFunctor):
The default behavior recursively traverses the AST.
"""
-def visit_tuple(self, t):
-for x in t.fields:
+def visit_tuple(self, tup):
+for x in tup.fields:
self.visit(x)
-def visit_call(self, c):
-self.visit(c.op)
-for a in c.args:
+def visit_call(self, call):
+self.visit(call.op)
+for a in call.args:
self.visit(a)
-def visit_var(self, v):
+def visit_var(self, var):
pass
-def visit_let(self, l):
-self.visit(l.var)
-self.visit(l.value)
-self.visit(l.body)
+def visit_let(self, let):
+self.visit(let.var)
+self.visit(let.value)
+self.visit(let.body)
def visit_function(self, f):
self.visit(f.body)
......
......@@ -311,6 +311,7 @@ def _conv(opname):
flip_layout = True
if attr['data_format'] == 'NHWC':
in_channels = input_shape[3]
kernel_h, kernel_w, _, depth_mult = weights_shape
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv':
......@@ -324,6 +325,7 @@ def _conv(opname):
attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW':
in_channels = input_shape[1]
_, depth_mult, kernel_h, kernel_w = weights_shape
attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv':
......@@ -344,7 +346,7 @@ def _conv(opname):
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if opname == 'depthwise':
-attr['groups'] = attr['channels']
+attr['groups'] = in_channels
# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")
......
......@@ -1156,7 +1156,7 @@ class OperatorConverter(object):
if is_depthwise_conv:
params['channels'] = int(in_channels)
-params['groups'] = int(in_channels)
+params['groups'] = int(input_c)
params['kernel_layout'] = 'HWOI'
else:
params['channels'] = int(output_channels)
......
......@@ -28,8 +28,8 @@ from .backend import compile_engine
def is_primitive(call):
-return hasattr(call.op, 'attrs') and hasattr(call.op.attrs, 'Primitive') and \
-int(call.op.attrs.Primitive) == 1
+return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
+hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1
# TODO(@jroesch): port to c++ and unify with existing code
class LinearizeRetType:
......
......@@ -17,9 +17,10 @@
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
# operator defs
-from .op import get, register, register_schedule, register_compute, register_gradient, \
+from .op import get, register, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \
-schedule_injective, Op, OpPattern, debug
+Op, OpPattern, OpStrategy, debug
+from . import strategy
# Operators
from .reduce import *
......
......@@ -18,48 +18,14 @@
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
-import topi
-from topi.util import get_const_int
-from ..op import OpPattern, register_compute, register_schedule, register_pattern
-@register_schedule("argsort")
-def schedule_argsort(_, outs, target):
-"""Schedule definition of argsort"""
-with target:
-return topi.generic.schedule_argsort(outs)
-@register_compute("argsort")
-def compute_argsort(attrs, inputs, _, target):
-"""Compute definition of argsort"""
-axis = get_const_int(attrs.axis)
-is_ascend = bool(get_const_int(attrs.is_ascend))
-dtype = attrs.dtype
-return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
+from . import strategy
+from .op import OpPattern, register_pattern
+from .op import register_strategy
+# argsort
+register_strategy("argsort", strategy.argsort_strategy)
register_pattern("argsort", OpPattern.OPAQUE)
@register_schedule("topk")
def schedule_topk(_, outs, target):
"""Schedule definition of argsort"""
with target:
return topi.generic.schedule_topk(outs)
@register_compute("topk")
def compute_topk(attrs, inputs, _, target):
"""Compute definition of argsort"""
k = get_const_int(attrs.k)
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
out = out if isinstance(out, list) else [out]
return out
# topk
register_strategy("topk", strategy.topk_strategy)
register_pattern("topk", OpPattern.OPAQUE)
......@@ -17,33 +17,21 @@
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ...api import convert
from ...hybrid import script
-def _schedule_reduce(_, outs, target):
-"""Generic schedule for reduce"""
-with target:
-return topi.generic.schedule_reduce(outs)
-_reg.register_schedule("argmax", _schedule_reduce)
-_reg.register_schedule("argmin", _schedule_reduce)
-_reg.register_schedule("sum", _schedule_reduce)
-_reg.register_schedule("all", _schedule_reduce)
-_reg.register_schedule("any", _schedule_reduce)
-_reg.register_schedule("max", _schedule_reduce)
-_reg.register_schedule("min", _schedule_reduce)
-_reg.register_schedule("prod", _schedule_reduce)
-_reg.register_schedule("mean", _schedule_reduce)
-_reg.register_schedule("variance", _schedule_reduce)
-_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
-_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
+_reg.register_reduce_schedule("argmax")
+_reg.register_reduce_schedule("argmin")
+_reg.register_reduce_schedule("sum")
+_reg.register_reduce_schedule("all")
+_reg.register_reduce_schedule("any")
+_reg.register_reduce_schedule("max")
+_reg.register_reduce_schedule("min")
+_reg.register_reduce_schedule("prod")
+_reg.register_reduce_schedule("mean")
+_reg.register_reduce_schedule("variance")
def _create_axis_record(attrs, inputs):
axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
......
......@@ -19,101 +19,99 @@
from __future__ import absolute_import
import topi
from topi.util import get_const_tuple
-from .op import register_compute, register_schedule, register_pattern, register_shape_func
-from .op import schedule_injective, OpPattern
+from .op import register_compute, register_shape_func
+from .op import register_broadcast_schedule, register_injective_schedule
+from .op import register_pattern, OpPattern
from ...hybrid import script
from ...api import convert
-schedule_broadcast = schedule_injective
-schedule_elemwise = schedule_injective
register_schedule("log", schedule_broadcast)
register_schedule("cos", schedule_broadcast)
register_schedule("sin", schedule_broadcast)
register_schedule("atan", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("erf", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
register_schedule("sigmoid", schedule_broadcast)
register_schedule("floor", schedule_broadcast)
register_schedule("ceil", schedule_broadcast)
register_schedule("trunc", schedule_broadcast)
register_schedule("round", schedule_broadcast)
register_schedule("sign", schedule_broadcast)
register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast)
register_schedule("bitwise_not", schedule_broadcast)
register_schedule("negative", schedule_broadcast)
register_schedule("copy", schedule_broadcast)
register_schedule("add", schedule_broadcast)
register_schedule("subtract", schedule_broadcast)
register_schedule("multiply", schedule_broadcast)
register_schedule("divide", schedule_broadcast)
register_schedule("floor_divide", schedule_broadcast)
register_schedule("power", schedule_injective)
register_schedule("mod", schedule_broadcast)
register_schedule("floor_mod", schedule_broadcast)
register_schedule("logical_and", schedule_broadcast)
register_schedule("logical_or", schedule_broadcast)
register_schedule("bitwise_and", schedule_broadcast)
register_schedule("bitwise_or", schedule_broadcast)
register_schedule("bitwise_xor", schedule_broadcast)
register_schedule("equal", schedule_broadcast)
register_schedule("not_equal", schedule_broadcast)
register_schedule("less", schedule_broadcast)
register_schedule("less_equal", schedule_broadcast)
register_schedule("greater", schedule_broadcast)
register_schedule("greater_equal", schedule_broadcast)
register_schedule("maximum", schedule_injective)
register_schedule("minimum", schedule_injective)
register_schedule("right_shift", schedule_injective)
register_schedule("left_shift", schedule_injective)
register_schedule("shape_of", schedule_injective)
register_broadcast_schedule("log")
register_broadcast_schedule("cos")
register_broadcast_schedule("sin")
register_broadcast_schedule("atan")
register_broadcast_schedule("exp")
register_broadcast_schedule("erf")
register_broadcast_schedule("sqrt")
register_broadcast_schedule("rsqrt")
register_broadcast_schedule("sigmoid")
register_broadcast_schedule("floor")
register_broadcast_schedule("ceil")
register_broadcast_schedule("trunc")
register_broadcast_schedule("round")
register_broadcast_schedule("sign")
register_broadcast_schedule("abs")
register_broadcast_schedule("tanh")
register_broadcast_schedule("add")
register_broadcast_schedule("subtract")
register_broadcast_schedule("multiply")
register_broadcast_schedule("divide")
register_broadcast_schedule("floor_divide")
register_broadcast_schedule("power")
register_broadcast_schedule("copy")
register_broadcast_schedule("logical_not")
register_broadcast_schedule("logical_and")
register_broadcast_schedule("logical_or")
register_broadcast_schedule("bitwise_not")
register_broadcast_schedule("bitwise_and")
register_broadcast_schedule("bitwise_or")
register_broadcast_schedule("bitwise_xor")
register_broadcast_schedule("negative")
register_broadcast_schedule("mod")
register_broadcast_schedule("floor_mod")
register_broadcast_schedule("equal")
register_broadcast_schedule("not_equal")
register_broadcast_schedule("less")
register_broadcast_schedule("less_equal")
register_broadcast_schedule("greater")
register_broadcast_schedule("greater_equal")
register_injective_schedule("maximum")
register_injective_schedule("minimum")
register_injective_schedule("right_shift")
register_injective_schedule("left_shift")
register_injective_schedule("shape_of")
# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type, target):
def zeros_compute(attrs, inputs, output_type):
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 0.0)]
register_schedule("zeros", schedule_broadcast)
register_broadcast_schedule("zeros")
register_pattern("zeros", OpPattern.ELEMWISE)
# zeros_like
@register_compute("zeros_like")
def zeros_like_compute(attrs, inputs, output_type, target):
def zeros_like_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.full_like(inputs[0], 0.0)]
register_schedule("zeros_like", schedule_broadcast)
register_broadcast_schedule("zeros_like")
# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type, target):
def ones_compute(attrs, inputs, output_type):
assert not inputs
return [topi.full(output_type.shape, output_type.dtype, 1.0)]
register_schedule("ones", schedule_broadcast)
register_broadcast_schedule("ones")
register_pattern("ones", OpPattern.ELEMWISE)
# ones_like
@register_compute("ones_like")
def ones_like(attrs, inputs, output_type, target):
def ones_like_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.full_like(inputs[0], 1.0)]
register_schedule("ones_like", schedule_broadcast)
register_broadcast_schedule("ones_like")
# clip
@register_compute("clip")
def clip_compute(attrs, inputs, output_type, target):
def clip_compute(attrs, inputs, output_type):
assert len(inputs) == 1
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise)
register_injective_schedule("clip")
@script
def _cast_shape_function(x):
@@ -198,6 +196,7 @@ register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("bitwise_not", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
register_shape_func("bitwise_xor", False, broadcast_shape_func)
@@ -21,52 +21,74 @@ import tvm
import topi
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ._reduce import _schedule_reduce
from . import strategy
from .op import OpPattern
from ...hybrid import script
from ...api import convert
schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
schedule_concatenate = _reg.schedule_concatenate
_reg.register_schedule("collapse_sum_like", _schedule_reduce)
_reg.register_schedule("broadcast_to", schedule_broadcast)
_reg.register_schedule("broadcast_to_like", schedule_broadcast)
_reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_schedule("squeeze", schedule_injective)
_reg.register_schedule("reshape", schedule_injective)
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective)
_reg.register_schedule("reverse", schedule_injective)
_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("cast_like", schedule_injective)
_reg.register_schedule("reinterpret", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("strided_set", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("stack", schedule_injective)
_reg.register_schedule("concatenate", schedule_concatenate)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)
_reg.register_schedule("sequence_mask", schedule_injective)
_reg.register_schedule("one_hot", schedule_injective)
_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
_reg.register_broadcast_schedule("expand_dims")
_reg.register_broadcast_schedule("repeat")
_reg.register_broadcast_schedule("tile")
_reg.register_broadcast_schedule("where")
_reg.register_injective_schedule("squeeze")
_reg.register_injective_schedule("reshape")
_reg.register_injective_schedule("reshape_like")
_reg.register_injective_schedule("full")
_reg.register_injective_schedule("full_like")
_reg.register_injective_schedule("arange")
_reg.register_injective_schedule("reverse")
_reg.register_injective_schedule("cast")
_reg.register_injective_schedule("cast_like")
_reg.register_injective_schedule("reinterpret")
_reg.register_injective_schedule("strided_slice")
_reg.register_injective_schedule("slice_like")
_reg.register_injective_schedule("split")
_reg.register_injective_schedule("take")
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("_contrib_reverse_reshape")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
# strided_set
@_reg.register_compute("strided_set")
def compute_strided_set(attrs, inputs, output_type):
"""Compute definition of strided_set"""
return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
_reg.register_injective_schedule("strided_set")
# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
# shape func
# argwhere
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
_reg.register_schedule("argwhere", strategy.schedule_argwhere)
#####################
# Shape functions #
#####################
@script
def _arange_shape_func(start, stop, step):
out = output_tensor((1,), "int64")
@@ -284,31 +306,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return [_argwhere_shape_func_5d(inputs[0])]
raise ValueError("Does not support rank higher than 5 in argwhere")
@_reg.register_schedule("argwhere")
def schedule_argwhere(_, outs, target):
"""Schedule definition of argwhere"""
with target:
return topi.generic.schedule_argwhere(outs)
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type, _):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
@_reg.register_compute("strided_set")
def compute_strided_set(attrs, inputs, output_type, _):
"""Compute definition of strided_set"""
return [topi.strided_set(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4])]
@script
def _layout_transform_shape_func(data_shape,
out_layout_len,
@@ -19,7 +19,7 @@ from tvm.runtime import ndarray as _nd
from tvm.runtime import TVMContext as _TVMContext
from . import _make
from ..op import register_schedule, schedule_injective
from .. import op as reg
def on_device(data, device):
@@ -79,7 +79,7 @@ def checkpoint(data):
"""
return _make.checkpoint(data)
register_schedule("annotation.checkpoint", schedule_injective)
reg.register_injective_schedule("annotation.checkpoint")
def compiler_begin(data, compiler):
@@ -18,29 +18,19 @@
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
from .. import op as reg
from ..op import schedule_injective, OpPattern
from .. import strategy
from ..op import OpPattern
# adaptive_max_pool2d
@reg.register_schedule("contrib.adaptive_max_pool2d")
def schedule_adaptive_max_pool2d(_, outs, target):
"""Schedule definition of adaptive_max_pool2d"""
with target:
return topi.generic.schedule_adaptive_pool(outs)
reg.register_schedule("contrib.adaptive_max_pool2d", strategy.schedule_adaptive_pool)
reg.register_pattern("contrib.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# adaptive_avg_pool2d
@reg.register_schedule("contrib.adaptive_avg_pool2d")
def schedule_adaptive_avg_pool2d(_, outs, target):
"""Schedule definition of adaptive_avg_pool2d"""
with target:
return topi.generic.schedule_adaptive_pool(outs)
reg.register_schedule("contrib.adaptive_avg_pool2d", strategy.schedule_adaptive_pool)
reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# relay.contrib.ndarray_size
reg.register_schedule("contrib.ndarray_size", schedule_injective)
reg.register_injective_schedule("contrib.ndarray_size")
@@ -20,13 +20,10 @@ from __future__ import absolute_import
import topi
from .. import op as reg
from ..op import schedule_injective
# resize
reg.register_schedule("image.resize", schedule_injective)
@reg.register_compute("image.resize")
def compute_resize(attrs, inputs, out_type, target):
def compute_resize(attrs, inputs, out_type):
size = attrs.size
layout = attrs.layout
method = attrs.method
@@ -34,12 +31,12 @@ def compute_resize(attrs, inputs, out_type, target):
out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
reg.register_injective_schedule("image.resize")
# crop and resize
reg.register_schedule("image.crop_and_resize", schedule_injective)
# crop and resize
@reg.register_compute("image.crop_and_resize")
def compute_crop_and_resize(attrs, inputs, out_type, target):
def compute_crop_and_resize(attrs, inputs, out_type):
crop_size = attrs.crop_size
layout = attrs.layout
method = attrs.method
@@ -48,3 +45,5 @@ def compute_crop_and_resize(attrs, inputs, out_type, target):
return [topi.image.crop_and_resize(inputs[0], inputs[1], inputs[2],
crop_size, layout, method,
extrapolation_value, out_dtype)]
reg.register_injective_schedule("image.crop_and_resize")
@@ -204,7 +204,6 @@ def conv2d(data,
# TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
@@ -298,7 +297,6 @@ def conv3d(data,
dilation = (dilation, dilation, dilation)
if isinstance(padding, int):
padding = (padding, padding, padding)
return _make.conv3d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
@@ -1772,74 +1770,6 @@ def contrib_conv2d_winograd_without_weight_transform(data,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_nnpack_without_weight_transform(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""2D convolution with the NNPACK implementation of winograd algorithm.
The basic parameters are the same as the ones in vanilla conv2d.
It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_nnpack_weight_transform
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output; by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.contrib_conv2d_winograd_nnpack_without_weight_transform(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_nchwc(data,
kernel,
strides=(1, 1),
@@ -1974,73 +1904,6 @@ def contrib_depthwise_conv2d_nchwc(data,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_nchwc_int8(data,
kernel,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW8c",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""Variant of 2D convolution. It deals with only int8 inputs.
This operator takes the weight as the convolution kernel
and convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output; by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 2-way padding to 4-way padding
padding = get_pad_tuple2d(padding)
return _make.contrib_conv2d_NCHWc_int8(data, kernel, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_weight_transform(weight,
tile_size):
@@ -14,15 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument
#pylint: disable=unused-argument,invalid-name
"""The base node types for the Relay language."""
import topi
import tvm._ffi
from tvm.driver import lower, build
from ..base import register_relay_node
from ..expr import RelayExpr
from ...api import register_func
from ...target import get_native_generic_func, GenericFunc
from ...runtime import Object
from . import _make
@register_relay_node
@@ -143,39 +144,208 @@ class OpPattern(object):
OPAQUE = 8
def register_schedule(op_name, schedule=None, level=10):
"""Register schedule function for an op
@tvm._ffi.register_object("relay.OpImplementation")
class OpImplementation(Object):
"""Operator implementation"""
def compute(self, attrs, inputs, out_type):
"""Call compute function.
Parameters
----------
attrs : Attrs
Op attributes.
inputs : list[tvm.tensor.Tensor]
The input tensors.
out_type : relay.Type
The output type.
Returns
-------
outs : list[tvm.tensor.Tensor]
The output tensors.
"""
return _OpImplementationCompute(self, attrs, inputs, out_type)
def schedule(self, attrs, outs, target):
"""Call schedule function.
Parameters
----------
attrs : Attrs
Op attributes.
outs : list[tvm.tensor.Tensor]
The output tensors.
target : tvm.target.Target
The target to schedule the op.
Returns
-------
schedule : tvm.Schedule
The schedule.
"""
return _OpImplementationSchedule(self, attrs, outs, target)
@tvm._ffi.register_object("relay.OpSpecialization")
class OpSpecialization(Object):
"""Operator specialization"""
@tvm._ffi.register_object("relay.OpStrategy")
class OpStrategy(Object):
"""Operator strategy"""
def __init__(self):
self.__init_handle_by_constructor__(_make.OpStrategy)
def add_implementation(self, compute, schedule, name="default", plevel=10):
"""Add an implementation to the strategy
Parameters
----------
compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type)
-> List[Tensor]
The compute function.
schedule : function (attrs: Attrs, outs: List[Tensor], target:Target) -> Schedule
The schedule function.
name : str
The name of the implementation.
plevel : int
The priority level of the implementation.
"""
_OpStrategyAddImplementation(self, compute, schedule, name, plevel)
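To make the contract concrete, here is a minimal sketch of a strategy function built on this class. The toy compute and schedule are assumptions for illustration only, not implementations registered by this change.

def _toy_compute(attrs, inputs, out_type):
    # Identity compute, just enough to satisfy the FTVMCompute signature.
    return [topi.identity(inputs[0])]

def _toy_schedule(attrs, outs, target):
    # Fall back to the generic injective schedule under the target scope.
    with target:
        return topi.generic.schedule_injective(outs)

def _toy_fstrategy(attrs, inputs, out_type, target):
    strategy = OpStrategy()
    strategy.add_implementation(_toy_compute, _toy_schedule,
                                name="toy.generic")
    return strategy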
def _wrap_default_fstrategy(compute, schedule, name):
def _fstrategy(attrs, inputs, out_type, target):
strategy = OpStrategy()
strategy.add_implementation(compute, schedule, name=name)
return strategy
return _fstrategy
def _create_fstrategy_from_schedule(op_name, schedule):
assert hasattr(schedule, "dispatch_dict")
compute = get(op_name).get_attr("FTVMCompute")
assert compute is not None, "FTVMCompute is not registered for op %s" % op_name
fstrategy = get_native_generic_func("{}_strategy".format(op_name))
name_pfx = schedule.__name__
name_pfx = name_pfx[name_pfx.index('_')+1:]
fstrategy.set_default(
_wrap_default_fstrategy(compute, schedule.fdefault, "%s.generic" % name_pfx))
for key, sch in schedule.dispatch_dict.items():
fstrategy.register(
_wrap_default_fstrategy(compute, sch, "%s.%s" % (name_pfx, key)), [key])
return fstrategy
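For example, if the wrapped schedule is named schedule_injective and its dispatch_dict carries a "cuda" entry, the name slicing above yields:

name_pfx = "schedule_injective"
name_pfx = name_pfx[name_pfx.index('_') + 1:]  # -> "injective"
# default implementation name:    "injective.generic"
# per-target implementation name: "injective.cuda"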
def register_compute(op_name, compute=None, level=10):
"""Register compute function for an op.
Parameters
----------
op_name : str
The name of the op.
schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule
The schedule function.
compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type)
-> List[Tensor]
The compute function.
level : int
The priority level
"""
return register(op_name, "FTVMSchedule", schedule, level)
return register(op_name, "FTVMCompute", compute, level)
def register_compute(op_name, compute=None, level=10):
"""Register compute function for an op.
def register_strategy(op_name, fstrategy=None, level=10):
"""Register strategy function for an op.
Parameters
----------
op_name : str
The name of the op.
compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target)
-> List[Tensor]
The compute function.
fstrategy : function (attrs: Attrs, inputs: List[Tensor], out_type: Type,
target:Target) -> OpStrategy
The strategy function. Must be a native GenericFunc.
level : int
The priority level
"""
return register(op_name, "FTVMCompute", compute, level)
if not isinstance(fstrategy, GenericFunc):
assert hasattr(fstrategy, "generic_func_node")
fstrategy = fstrategy.generic_func_node
return register(op_name, "FTVMStrategy", fstrategy, level)
def register_schedule(op_name, schedule, level=10):
"""Register schedule function for an op.
This is used when the compute function is the same for all targets and only
the schedule differs. It requires FTVMCompute to already be registered for
the op.
Parameters
----------
op_name : str
The name of the op.
schedule : function (attrs: Attrs, outs: List[Tensor], target:Target) -> Schedule
The schedule function. Must be created by target.generic_func.
level : int
The priority level
"""
fstrategy = _create_fstrategy_from_schedule(op_name, schedule)
return register_strategy(op_name, fstrategy, level)
def register_injective_schedule(op_name, level=10):
"""Register injective schedule function for an op.
Parameters
----------
op_name : str
The name of the op.
level : int
The priority level
"""
return register_schedule(op_name, _schedule_injective, level)
def register_broadcast_schedule(op_name, level=10):
"""Register broadcast schedule function for an op.
Parameters
----------
op_name : str
The name of the op.
level : int
The priority level
"""
return register_schedule(op_name, _schedule_injective, level)
def register_reduce_schedule(op_name, level=10):
"""Register reduce schedule function for an op.
Parameters
----------
op_name : str
The name of the op.
level : int
The priority level
"""
return register_schedule(op_name, _schedule_reduce, level)
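Note that register_injective_schedule and register_broadcast_schedule both delegate to _schedule_injective (broadcast ops are injective), so the two names are informational for now. Typical call sites mirror the op files earlier in this diff:

register_injective_schedule("clip")
register_broadcast_schedule("zeros")
register_reduce_schedule("collapse_sum_like")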
def register_alter_op_layout(op_name, alter_layout=None, level=10):
@@ -245,6 +415,7 @@ def register_pattern(op_name, pattern, level=10):
"""
return register(op_name, "TOpPattern", pattern, level)
def register_gradient(op_name, fgradient=None, level=10):
"""Register operator pattern for an op.
@@ -261,6 +432,7 @@ def register_gradient(op_name, fgradient=None, level=10):
"""
return register(op_name, "FPrimalGradient", fgradient, level)
def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
"""Register operator shape function for an op.
@@ -290,18 +462,8 @@ def _lower(name, schedule, inputs, outputs):
def _build(lowered_funcs):
return build(lowered_funcs, target="llvm")
def schedule_injective(attrs, outputs, target):
"""Generic schedule for binary broadcast."""
with target:
return topi.generic.schedule_injective(outputs)
def schedule_concatenate(attrs, outputs, target):
"""Generic schedule for concatinate."""
with target:
return topi.generic.schedule_concatenate(outputs)
_schedule_injective = None
_schedule_reduce = None
__DEBUG_COUNTER__ = 0
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Relay op strategies."""
from __future__ import absolute_import as _abs
from .generic import *
from . import x86
from . import arm_cpu
from . import cuda
from . import hls
from . import mali
from . import bifrost
from . import opengl
from . import rocm
from . import intel_graphics
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of bifrost operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re
import topi
from .generic import *
from .. import op as _op
@conv2d_strategy.register("bifrost")
def conv2d_strategy_bifrost(attrs, inputs, out_type, target):
"""conv2d mali(bifrost) strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
stride_h, stride_w = attrs.get_int_tuple("strides")
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
if kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.bifrost")
_, _, kh, kw = get_const_tuple(kernel.shape)
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.bifrost",
plevel=15)
elif re.match(r"OIHW\d*o", kernel_layout):
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.bifrost")
else:
raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".
format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.bifrost.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.bifrost")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for Mali(Bifrost)".
format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for Mali(Bifrost)")
return strategy
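The wrap_* helpers imported above from .generic adapt plain topi functions to the strategy signatures; the following is a sketch consistent with strategy/generic.py (shown here for context, not part of this hunk):

def wrap_topi_schedule(topi_schedule):
    """Wrap a TOPI schedule that does not use attrs."""
    def wrapper(attrs, outs, target):
        with target:
            return topi_schedule(outs)
    return wrapper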
@conv2d_winograd_without_weight_transfrom_strategy.register("bifrost")
def conv2d_winograd_without_weight_transfrom_strategy_bifrost(attrs, inputs, out_type, target):
"""conv2d_winograd_without_weight_transfrom mali(bifrost) strategy"""
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.data_layout
strides = attrs.get_int_tuple("strides")
assert dilation == (1, 1), "Do not support dilation now"
assert strides == (1, 1), "Do not support strides now"
assert groups == 1, "Do not support arbitrary group number"
strategy = _op.OpStrategy()
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd),
wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.bifrost")
else:
raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
format(layout))
return strategy
@dense_strategy.register("bifrost")
def dense_strategy_bifrost(attrs, inputs, out_type, target):
"""dense mali(bifrost) strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.bifrost.dense),
wrap_topi_schedule(topi.bifrost.schedule_dense),
name="dense.bifrost")
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of HLS operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
from .generic import *
from .. import op as _op
@schedule_injective.register("hls")
def schedule_injective_hls(attrs, outs, target):
"""schedule injective ops for hls"""
with target:
return topi.hls.schedule_injective(outs)
@schedule_reduce.register("hls")
def schedule_reduce_hls(attrs, outs, target):
"""schedule reduction ops for hls"""
with target:
return topi.hls.schedule_reduce(outs)
@schedule_concatenate.register("hls")
def schedule_concatenate_hls(attrs, outs, target):
"""schedule concatenate for hls"""
with target:
return topi.hls.schedule_injective(outs)
@schedule_pool.register("hls")
def schedule_pool_hls(attrs, outs, target):
"""schedule pooling ops for hls"""
with target:
return topi.hls.schedule_pool(outs, attrs.layout)
@schedule_adaptive_pool.register("hls")
def schedule_adaptive_pool_hls(attrs, outs, target):
"""schedule adaptive pooling ops for hls"""
with target:
return topi.hls.schedule_adaptive_pool(outs)
@schedule_softmax.register("hls")
def schedule_softmax_hls(attrs, outs, target):
"""schedule softmax for hls"""
with target:
return topi.hls.schedule_softmax(outs)
@override_native_generic_func("conv2d_strategy")
def conv2d_strategy_hls(attrs, inputs, out_type, target):
"""conv2d hls strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nchw),
wrap_topi_schedule(topi.hls.schedule_conv2d_nchw),
name="conv2d_nchw.hls")
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_topi_schedule(topi.hls.schedule_conv2d_nhwc),
name="conv2d_nhwc.hls")
else:
raise RuntimeError("Unsupported conv2d layout {}".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.hls")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.hls.schedule_depthwise_conv2d_nhwc),
name="depthwise_nhwc.hls")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for hls")
return strategy
@override_native_generic_func("conv2d_NCHWc_strategy")
def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
"""conv2d_NCHWc hls strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.hls")
return strategy
@conv2d_transpose_strategy.register("hls")
def conv2d_transpose_strategy_hls(attrs, inputs, out_type, target):
"""conv2d_transpose hls strategy"""
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.hls.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.hls")
return strategy
@dense_strategy.register("hls")
def dense_strategy_hls(attrs, inputs, out_type, target):
"""dense hls strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
wrap_topi_schedule(topi.hls.schedule_dense),
name="dense.hls")
return strategy
@bitserial_conv2d_strategy.register("hls")
def bitserial_conv2d_strategy_hls(attrs, inputs, out_type, target):
"""bitserial_conv2d hls strategy"""
strategy = _op.OpStrategy()
layout = attrs.data_layout
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw),
wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nchw),
name="bitserial_conv2d_nchw.hls")
elif layout == "NHWC":
strategy.add_implementation(
wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc),
wrap_topi_schedule(topi.hls.schedule_bitserial_conv2d_nhwc),
name="bitserial_conv2d_nhwc.hls")
else:
raise ValueError("Data layout {} not supported.".format(layout))
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of x86 operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
from .generic import *
from .. import op as _op
@conv2d_strategy.register("intel_graphics")
def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
"""conv2d intel graphics strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_nchw),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_nchw),
name="conv2d_nchw.intel_graphics")
# conv2d_NCHWc won't work without alter op layout pass
# TODO(@Laurawly): fix this
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics",
plevel=5)
else:
raise RuntimeError("Unsupported conv2d layout {} for intel graphics".
format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.intel_graphics.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.intel_graphics")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for intel graphics")
return strategy
@conv2d_NCHWc_strategy.register("intel_graphics")
def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
"""conv2d_NCHWc intel_graphics strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics")
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of mali operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re
import topi
from .generic import *
from .. import op as _op
@conv2d_strategy.register("mali")
def conv2d_strategy_mali(attrs, inputs, out_type, target):
"""conv2d mali strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
stride_h, stride_w = attrs.get_int_tuple("strides")
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
if kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.mali")
# check if winograd algorithm is applicable
_, _, kh, kw = get_const_tuple(kernel.shape)
if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.mali",
plevel=15)
elif re.match(r"OIHW\d*o", kernel_layout):
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.mali")
else:
raise RuntimeError("Unsupported weight layout {} for conv2d NCHW".
format(kernel_layout))
else:
raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.mali.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.mali")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {} for mali".format(layout))
else: # group_conv2d
raise RuntimeError("group_conv2d is not supported for mali")
return strategy
@conv2d_winograd_without_weight_transfrom_strategy.register("mali")
def conv2d_winograd_without_weight_transfrom_strategy_mali(attrs, inputs, out_type, target):
"""conv2d_winograd_without_weight_transfrom mali strategy"""
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.data_layout
strides = attrs.get_int_tuple("strides")
kernel = inputs[1]
assert dilation == (1, 1), "Do not support dilation now"
assert strides == (1, 1), "Do not support strides now"
assert groups == 1, "Do not support arbitrary group number"
strategy = _op.OpStrategy()
if layout == "NCHW":
assert len(kernel.shape) == 5, "Kernel must be packed into 5-dim"
strategy.add_implementation(
wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd),
wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.mali")
else:
raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
format(layout))
return strategy
@dense_strategy.register("mali")
def dense_strategy_mali(attrs, inputs, out_type, target):
"""dense mali strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.mali.dense),
wrap_topi_schedule(topi.mali.schedule_dense),
name="dense.mali")
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of OpenGL operator strategy."""
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import topi
from .generic import *
from .. import op as _op
@schedule_injective.register("opengl")
def schedule_injective_opengl(attrs, outs, target):
"""schedule injective ops for opengl"""
with target:
return topi.opengl.schedule_injective(outs)
@schedule_concatenate.register("opengl")
def schedule_concatenate_opengl(attrs, outs, target):
"""schedule concatenate for opengl"""
with target:
return topi.opengl.schedule_injective(outs)
@schedule_pool.register("opengl")
def schedule_pool_opengl(attrs, outs, target):
"""schedule pooling ops for opengl"""
with target:
return topi.opengl.schedule_pool(outs, attrs.layout)
@schedule_adaptive_pool.register("opengl")
def schedule_adaptive_pool_opengl(attrs, outs, target):
"""schedule adative pooling ops for opengl"""
with target:
return topi.opengl.schedule_adaptive_pool(outs)
@schedule_softmax.register("opengl")
def schedule_softmax_opengl(attrs, outs, target):
"""schedule softmax for opengl"""
with target:
return topi.opengl.schedule_softmax(outs)
@conv2d_strategy.register("opengl")
def conv2d_strategy_opengl(attrs, inputs, out_type, target):
"""conv2d opengl strategy"""
strategy = _op.OpStrategy()
groups = attrs.groups
layout = attrs.data_layout
assert groups == 1, "Don't support group conv2d on OpenGL"
assert layout == "NCHW", "Only support conv2d layout NCHW for OpenGL"
strategy.add_implementation(wrap_compute_conv2d(topi.nn.conv2d),
wrap_topi_schedule(topi.opengl.schedule_conv2d_nchw),
name="conv2d_nchw.opengl")
return strategy
@dense_strategy.register("opengl")
def dense_strategy_opengl(attrs, inputs, out_type, target):
"""dense opengl strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(wrap_compute_dense(topi.nn.dense),
wrap_topi_schedule(topi.opengl.schedule_dense),
name="dense.opengl")
return strategy
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
import topi
from .generic import *
from .. import op as _op
@schedule_lrn.register("rocm")
def schedule_lrn_rocm(attrs, outs, target):
"""schedule LRN for rocm"""
with target:
return topi.rocm.schedule_lrn(outs)
@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
"""conv2d rocm strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
groups = attrs.groups
layout = attrs.data_layout
stride_h, stride_w = attrs.get_int_tuple("strides")
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda")
_, _, kh, kw = get_const_tuple(kernel.shape)
if kh <= 7 and kw <= 7 and kh == kw and stride_h == 1 and stride_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.cuda",
plevel=15)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda")
# TODO(@alexgl-github): Re-enable this after fix the conv2d_nhwc for cuda
# elif layout == "NHWC":
# assert kernel_layout == "HWIO"
# strategy.add_implementation(
# wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
# wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
# name="conv2d_nhwc.cuda")
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda")
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add miopen implementation
if "miopen" in target.libs:
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=15)
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.cuda")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.cuda")
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
if layout == 'NCHW':
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.cuda")
elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
name="group_conv2d_NCHWc_int8.cuda")
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy
@dense_strategy.register("rocm")
def dense_strategy_rocm(attrs, inputs, out_type, target):
"""Dense strategy for ROCM"""
strategy = _op.OpStrategy()
assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense),
wrap_topi_schedule(topi.rocm.schedule_dense),
name="dense.rocm")
if target.target_name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense_rocblas),
wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
name="dense_rocblas.rocm",
plevel=5)
return strategy
@@ -17,65 +17,27 @@
# pylint: disable=invalid-name, unused-argument
"""Faster R-CNN and Mask R-CNN operations."""
import topi
from topi.util import get_const_tuple, get_float_tuple, get_const_int
from topi.util import get_const_tuple
from .. import op as reg
from .. import strategy
from ..op import OpPattern
@reg.register_compute("vision.roi_align")
def compute_roi_align(attrs, inputs, _, target):
"""Compute definition of roi_align"""
assert attrs.layout == "NCHW"
return [topi.vision.rcnn.roi_align_nchw(
inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio)]
@reg.register_schedule("vision.roi_align")
def schedule_roi_align(_, outs, target):
"""Schedule definition of roi_align"""
with target:
return topi.generic.vision.schedule_roi_align(outs)
# roi_align
reg.register_strategy("vision.roi_align", strategy.roi_align_strategy)
reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
# roi_pool
@reg.register_compute("vision.roi_pool")
def compute_roi_pool(attrs, inputs, _, target):
def compute_roi_pool(attrs, inputs, _):
"""Compute definition of roi_pool"""
assert attrs.layout == "NCHW"
return [topi.vision.rcnn.roi_pool_nchw(
inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
spatial_scale=attrs.spatial_scale)]
@reg.register_schedule("vision.roi_pool")
def schedule_roi_pool(_, outs, target):
"""Schedule definition of roi_pool"""
with target:
return topi.generic.vision.schedule_roi_pool(outs)
reg.register_schedule("vision.roi_pool", strategy.schedule_roi_pool)
reg.register_pattern("vision.roi_pool", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("vision.proposal")
def compute_proposal(attrs, inputs, _, target):
"""Compute definition of proposal"""
scales = get_float_tuple(attrs.scales)
ratios = get_float_tuple(attrs.ratios)
feature_stride = attrs.feature_stride
threshold = attrs.threshold
rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n
rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
rpn_min_size = attrs.rpn_min_size
iou_loss = bool(get_const_int(attrs.iou_loss))
with target:
return [
topi.vision.rcnn.proposal(inputs[0], inputs[1], inputs[2], scales, ratios,
feature_stride, threshold, rpn_pre_nms_top_n,
rpn_post_nms_top_n, rpn_min_size, iou_loss)
]
@reg.register_schedule("vision.proposal")
def schedule_proposal(_, outs, target):
"""Schedule definition of proposal"""
with target:
return topi.generic.schedule_proposal(outs)
# proposal
reg.register_strategy("vision.proposal", strategy.proposal_strategy)
reg.register_pattern("vision.proposal", OpPattern.OPAQUE)
@@ -18,104 +18,25 @@
"""Definition of vision ops"""
from __future__ import absolute_import
import topi
from topi.util import get_const_int, get_const_float, get_float_tuple
from .. import op as reg
from .. import strategy
from ..op import OpPattern
@reg.register_schedule("vision.multibox_prior")
def schedule_multibox_prior(_, outs, target):
"""Schedule definition of multibox_prior"""
with target:
return topi.generic.schedule_multibox_prior(outs)
@reg.register_compute("vision.multibox_prior")
def compute_multibox_prior(attrs, inputs, _, target):
"""Compute definition of multibox_prior"""
sizes = get_float_tuple(attrs.sizes)
ratios = get_float_tuple(attrs.ratios)
steps = get_float_tuple(attrs.steps)
offsets = get_float_tuple(attrs.offsets)
clip = bool(get_const_int(attrs.clip))
return [
topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios, steps,
offsets, clip)
]
# multibox_prior
reg.register_strategy("vision.multibox_prior", strategy.multibox_prior_strategy)
reg.register_pattern("vision.multibox_prior", OpPattern.OPAQUE)
# multibox_transform_loc
@reg.register_schedule("vision.multibox_transform_loc")
def schedule_multibox_transform_loc(_, outs, target):
"""Schedule definition of multibox_detection"""
with target:
return topi.generic.schedule_multibox_transform_loc(outs)
@reg.register_compute("vision.multibox_transform_loc")
def compute_multibox_transform_loc(attrs, inputs, _, target):
"""Compute definition of multibox_detection"""
clip = bool(get_const_int(attrs.clip))
threshold = get_const_float(attrs.threshold)
variances = get_float_tuple(attrs.variances)
return topi.vision.ssd.multibox_transform_loc(
inputs[0], inputs[1], inputs[2], clip, threshold, variances)
reg.register_strategy("vision.multibox_transform_loc", strategy.multibox_transform_loc_strategy)
reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)
# Get counts of valid boxes
@reg.register_schedule("vision.get_valid_counts")
def schedule_get_valid_counts(_, outs, target):
"""Schedule definition of get_valid_counts"""
with target:
return topi.generic.schedule_get_valid_counts(outs)
@reg.register_compute("vision.get_valid_counts")
def compute_get_valid_counts(attrs, inputs, _, target):
"""Compute definition of get_valid_counts"""
score_threshold = get_const_float(attrs.score_threshold)
id_index = get_const_int(attrs.id_index)
score_index = get_const_int(attrs.score_index)
return topi.vision.get_valid_counts(inputs[0], score_threshold,
id_index, score_index)
reg.register_strategy("vision.get_valid_counts", strategy.get_valid_counts_strategy)
reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
# non-maximum suppression
@reg.register_schedule("vision.non_max_suppression")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
with target:
return topi.generic.schedule_nms(outs)
@reg.register_compute("vision.non_max_suppression")
def compute_nms(attrs, inputs, _, target):
"""Compute definition of nms"""
return_indices = bool(get_const_int(attrs.return_indices))
max_output_size = get_const_int(attrs.max_output_size)
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
coord_start = get_const_int(attrs.coord_start)
score_index = get_const_int(attrs.score_index)
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
return [
topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
coord_start, score_index, id_index,
return_indices, invalid_to_bottom)
]
reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy)
reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)
@@ -17,9 +17,9 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..op import register_schedule, register_pattern
from ..op import schedule_injective, OpPattern
from ..op import register_pattern, OpPattern
from ..op import register_injective_schedule
# reorg
register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE)
register_schedule("vision.yolo_reorg", schedule_injective)
register_injective_schedule("vision.yolo_reorg")
@@ -31,7 +31,7 @@ from .quantize import _forward_op
@_reg.register_compute("relay.op.annotation.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type, target):
def simulated_quantize_compute(attrs, inputs, out_type):
"""Compiler for simulated_quantize."""
assert len(inputs) == 4
assert attrs.sign
@@ -52,11 +52,10 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
return [rdata]
_reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective)
_reg.register_injective_schedule("relay.op.annotation.simulated_quantize")
_reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.ELEMWISE)
_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective)
_reg.register_injective_schedule("annotation.cast_hint")
@register_relay_node
......
......@@ -44,15 +44,18 @@ def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
kernel_size=(3, 3), downsample=False, padding=(1, 1),
epsilon=1e-5, layout='NCHW'):
epsilon=1e-5, layout='NCHW', dtype="float32"):
"""Helper function to get a separable conv block"""
if downsample:
strides = (2, 2)
else:
strides = (1, 1)
# depthwise convolution + bn + relu
wshape = (depthwise_channels, 1) + kernel_size
weight = relay.var(name + "_weight", shape=wshape, dtype=dtype)
conv1 = layers.conv2d(
data=data,
weight=weight,
channels=depthwise_channels,
groups=depthwise_channels,
kernel_size=kernel_size,
......@@ -85,38 +88,41 @@ def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224),
body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2, 2),
layout=layout)
body = separable_conv_block(body, 'separable_conv_block_1',
int(32*alpha), int(64*alpha), layout=layout)
int(32*alpha), int(64*alpha), layout=layout,
dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_2',
int(64*alpha), int(128*alpha), downsample=True,
layout=layout)
layout=layout, dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_3',
int(128*alpha), int(128*alpha), layout=layout)
int(128*alpha), int(128*alpha), layout=layout,
dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_4',
int(128*alpha), int(256*alpha), downsample=True,
layout=layout)
layout=layout, dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_5',
int(256*alpha), int(256*alpha), layout=layout)
int(256*alpha), int(256*alpha), layout=layout,
dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_6',
int(256*alpha), int(512*alpha), downsample=True,
layout=layout)
layout=layout, dtype=dtype)
if is_shallow:
body = separable_conv_block(body, 'separable_conv_block_7',
int(512*alpha), int(1024*alpha),
downsample=True, layout=layout)
downsample=True, layout=layout, dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_8',
int(1024*alpha), int(1024*alpha),
downsample=True, layout=layout)
downsample=True, layout=layout, dtype=dtype)
else:
for i in range(7, 12):
body = separable_conv_block(body, 'separable_conv_block_%d' % i,
int(512*alpha), int(512*alpha),
layout=layout)
layout=layout, dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_12',
int(512*alpha), int(1024*alpha),
downsample=True, layout=layout)
downsample=True, layout=layout, dtype=dtype)
body = separable_conv_block(body, 'separable_conv_block_13',
int(1024*alpha), int(1024*alpha),
layout=layout)
layout=layout, dtype=dtype)
pool = relay.nn.global_avg_pool2d(data=body, layout=layout)
flatten = relay.nn.batch_flatten(data=pool)
weight = relay.var('fc_weight')
......
......@@ -184,6 +184,7 @@ def override_native_generic_func(func_name):
fresult = decorate(fdefault, dispatch_func)
fresult.fdefault = fdefault
fresult.register = register
fresult.generic_func_node = generic_func_node
return fresult
return fdecorate
......@@ -268,4 +269,5 @@ def generic_func(fdefault):
fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register
fdecorate.fdefault = fdefault
fdecorate.dispatch_dict = dispatch_dict
return fdecorate
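A small sketch of the dispatch behavior these decorators provide (FTVMStrategy is such a GenericFunc): the default body runs unless the current target carries a registered key. The op name below is illustrative.

import tvm
from tvm.target import generic_func

@generic_func
def pick_impl():
    return "generic"

@pick_impl.register("cuda")
def pick_impl_cuda():
    return "cuda"

with tvm.target.create("llvm"):
    assert pick_impl() == "generic"  # no "llvm" override registered
with tvm.target.create("cuda"):
    assert pick_impl() == "cuda"     # dispatched on the target's "cuda" key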
......@@ -23,8 +23,8 @@ from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod,
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from tvm.tir import comm_reducer, min, max, sum
from .schedule import Schedule, create_schedule
from .tensor import Tensor
from .schedule import Schedule, create_schedule, SpecializedCondition
from .tensor import TensorSlice, Tensor
from .tensor_intrin import decl_tensor_intrin
from .tag import tag_scope
from .operation import placeholder, compute, scan, extern, var, size_var
......
......@@ -517,4 +517,39 @@ class Stage(Object):
_ffi_api.StageOpenGL(self)
@tvm._ffi.register_object
class SpecializedCondition(Object):
"""Specialized condition to enable op specialization."""
def __init__(self, conditions):
"""Create a specialized condition.
.. note::
Conditions are represented in conjunctive normal form (CNF).
Each condition should be a simple expression, e.g., n > 16,
m % 8 == 0, etc., where n, m are tvm.Var that represent
dimensions in the tensor shape.
Parameters
----------
conditions : List of tvm.Expr
List of conditions in conjunctive normal form (CNF).
"""
if not isinstance(conditions, (list, _container.Array)):
conditions = [conditions]
self.__init_handle_by_constructor__(
_ffi_api.CreateSpecializedCondition, conditions)
@staticmethod
def current():
"""Returns the current specialized condition"""
return _ffi_api.GetCurrentSpecialization()
def __enter__(self):
_ffi_api.EnterSpecializationScope(self)
return self
def __exit__(self, ptype, value, trace):
_ffi_api.ExitSpecializationScope(self)
tvm._ffi._init_api("schedule", __name__)
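A minimal usage sketch for the class above, assuming only the te APIs added in this diff: conditions entered via the with-block become SpecializedCondition.current() inside the scope.

import tvm
from tvm import te

n = te.var("n")
with te.SpecializedCondition([n % 16 == 0, n > 0]):
    # Strategy implementations added here (e.g. via OpStrategy) attach to
    # this condition rather than the generic one.
    assert te.SpecializedCondition.current() is not None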
......@@ -964,3 +964,11 @@ class Let(PrimExprWithOp):
def __init__(self, var, value, body):
self.__init_handle_by_constructor__(
_ffi_api.Let, var, value, body)
@tvm._ffi.register_object
class Any(PrimExpr):
"""Any node.
"""
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.Any)
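For context (a hedged aside, not code from this diff): Any stands in for an unknown dimension, and the Relay-level relay.Any() is used the same way in shapes, as in the dynamic-shape tests later in this diff.

import tvm
from tvm import relay

# A tensor type with a dynamic batch dimension.
ttype = relay.TensorType((relay.Any(), 8, 224, 224), "float32")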
......@@ -47,11 +47,19 @@
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);
TVM_REGISTER_OBJECT_TYPE(CompileEngineNode);
LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
auto n = make_object<LoweredOutputNode>();
n->outputs = std::move(outputs);
n->implementation = std::move(impl);
data_ = std::move(n);
}
CCacheKey CCacheKeyNode::make(Function source_func, Target target) {
auto n = make_object<CCacheKeyNode>();
n->source_func = std::move(source_func);
......@@ -108,9 +116,7 @@ class ScheduleGetter :
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
std::pair<te::Schedule, CachedFunc> Create(const Function& prim_func) {
static auto fschedule =
Op::GetAttr<FTVMSchedule>("FTVMSchedule");
CachedFunc Create(const Function& prim_func) {
auto cache_node = make_object<CachedFuncNode>();
cache_node->target = target_;
for (Var param : prim_func->params) {
......@@ -147,7 +153,6 @@ class ScheduleGetter :
}
cache_node->func_name = candidate_name;
CachedFunc cfunc(cache_node);
CHECK(master_op_.defined());
// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
......@@ -161,15 +166,16 @@ class ScheduleGetter :
te::Schedule schedule;
// No need to register schedule for device copy op.
if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
schedule =
fschedule[master_op_](master_attrs_, tensor_outs, target_);
CHECK(master_implementation_.defined());
schedule = master_implementation_.Schedule(master_attrs_, tensor_outs, target_);
for (const auto& scalar : scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
return std::make_pair(schedule, cfunc);
cache_node->schedule = std::move(schedule);
return CachedFunc(cache_node);
}
Array<te::Tensor> VisitExpr(const Expr& expr) {
......@@ -208,16 +214,16 @@ class ScheduleGetter :
LOG(FATAL) << "not handled";
return tvm::PrimExpr();
}
}, "compile_engine_const", topi::kBroadcast);
}, "compile_engine_const", topi::kBroadcast);
scalars_.push_back(value->op);
return {value};
}
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fcompute =
Op::GetAttr<FTVMCompute>("FTVMCompute");
static auto fpattern =
Op::GetAttr<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
CHECK(flower_call) << "relay.backend.lower_call is not registered.";
Array<te::Tensor> inputs;
int count_tuple = 0;
......@@ -231,51 +237,37 @@ class ScheduleGetter :
}
if (count_tuple) {
CHECK_EQ(call_node->args.size(), 1U)
<< "Only allow function with a single tuple input";
}
// Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is
// Int32. Following code ensures the same for the output as well.
// TODO(@icemelon): Support recursive tuple
Type call_node_type = call_node->checked_type();
if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
call_node_type = TensorType(GetShape(tt->shape), tt->dtype);
} else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
std::vector<Type> new_fields;
for (auto field : tuple_t->fields) {
if (const auto* tt = field.as<TensorTypeNode>()) {
new_fields.push_back(TensorType(GetShape(tt->shape), tt->dtype));
} else {
new_fields.push_back(field);
}
}
call_node_type = TupleType(new_fields);
<< "Only allow function with a single tuple input";
}
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
Array<te::Tensor> outputs;
OpImplementation impl;
// Skip fcompute for device copy operators as it is not registered.
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype,
te::Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
call_node_type, target_);
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
outputs = lowered_out->outputs;
impl = lowered_out->implementation;
}
int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) {
CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce)
<< "Two complicated op in a primitive function "
<< " master=" << master_op_ << " current=" << op;
<< "Two complicated op in a primitive function "
<< " master=" << master_op_ << " current=" << op;
}
if (op_pattern >= master_op_pattern_) {
master_op_ = op;
master_attrs_ = call_node->attrs;
master_op_pattern_ = op_pattern;
master_implementation_ = impl;
}
if (outputs.size() != 1) {
const auto* tuple_type =
......@@ -332,6 +324,7 @@ class ScheduleGetter :
Op master_op_;
Attrs master_attrs_;
int master_op_pattern_{0};
OpImplementation master_implementation_;
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
Array<te::Operation> scalars_;
......@@ -677,8 +670,7 @@ class CompileEngineImpl : public CompileEngineNode {
* \return The created CachedFunc.
* The funcs field in cache is not yet populated.
*/
std::pair<te::Schedule, CachedFunc> CreateSchedule(
const Function& source_func, const Target& target) {
CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
return ScheduleGetter(target).Create(source_func);
}
......@@ -713,9 +705,9 @@ class CompileEngineImpl : public CompileEngineNode {
With<Target> target_scope(key->target);
CHECK(!value->cached_func.defined());
auto spair = CreateSchedule(key->source_func, key->target);
auto cfunc = CreateSchedule(key->source_func, key->target);
auto cache_node = make_object<CachedFuncNode>(
*(spair.second.operator->()));
*(cfunc.operator->()));
// Skip lowering for device copy node.
const Expr body = (key->source_func)->body;
......@@ -735,11 +727,12 @@ class CompileEngineImpl : public CompileEngineNode {
// lower the function
if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
cache_node->funcs = (*f)(
spair.first, all_args, cache_node->func_name, key->source_func);
cfunc->schedule, all_args, cache_node->func_name, key->source_func);
} else {
tvm::BuildConfig bcfg = BuildConfig::Create();
std::unordered_map<te::Tensor, tir::Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name,
binds, bcfg);
}
value->cached_func = CachedFunc(cache_node);
return value;
......@@ -820,6 +813,11 @@ const CompileEngine& CompileEngine::Global() {
return *inst;
}
TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
return LoweredOutput(outputs, impl);
});
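A hedged Python-side sketch of a relay.backend.lower_call using the global registered above; it mirrors the C++ test registration later in this diff. Picking the first specialization/implementation is a simplification, and the wrapper method names (get_attr, compute) are assumed from this PR.

import tvm

_make_lowered_output = tvm.get_global_func("relay.backend._make_LoweredOutput")

@tvm.register_func("relay.backend.lower_call", override=True)
def _lower_call(call, inputs, target):
    # Query the op's strategy (an FTVMStrategy generic function) and realize
    # its compute; real selection would honor conditions and plevel.
    op = call.op
    fstrategy = op.get_attr("FTVMStrategy")
    with target:
        strategy = fstrategy(call.attrs, inputs, call.checked_type, target)
    impl = strategy.specializations[0].implementations[0]
    outputs = impl.compute(call.attrs, inputs, call.checked_type)
    return _make_lowered_output(outputs, impl)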
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed(CCacheKeyNode::make);
......
......@@ -30,6 +30,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op_strategy.h>
#include <string>
#include <functional>
......@@ -44,6 +45,28 @@ enum ShapeFuncParamState {
kNeedBoth = 3,
};
struct LoweredOutputNode : public Object {
/*! \brief The outputs to the function */
tvm::Array<te::Tensor> outputs;
/*! \brief The implementation used to compute the output */
OpImplementation implementation;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("outputs", &outputs);
v->Visit("implementation", &implementation);
}
static constexpr const char* _type_key = "relay.LoweredOutput";
TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object);
};
class LoweredOutput : public ObjectRef {
public:
TVM_DLL LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl);
TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode);
};
/*! \brief Node container to represent a cached function. */
struct CachedFuncNode : public Object {
/* \brief compiled target */
......@@ -54,6 +77,8 @@ struct CachedFuncNode : public Object {
tvm::Array<te::Tensor> inputs;
/* \brief The outputs to the function */
tvm::Array<te::Tensor> outputs;
/*! \brief The schedule to the function */
te::Schedule schedule;
/*! \brief The lowered functions to support the function. */
tvm::Array<tir::LoweredFunc> funcs;
/*! \brief Parameter usage states in the shape function. */
......@@ -64,6 +89,7 @@ struct CachedFuncNode : public Object {
v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("schedule", &schedule);
v->Visit("funcs", &funcs);
v->Visit("shape_func_param_states", &shape_func_param_states);
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/relay/ir/op_strategy.cc
* \brief The Relay operator Strategy and related data structure.
*/
#include <tvm/relay/op_strategy.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(OpImplementationNode);
TVM_REGISTER_NODE_TYPE(OpSpecializationNode);
TVM_REGISTER_NODE_TYPE(OpStrategyNode);
Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) {
return (*this)->fcompute(attrs, inputs, out_type);
}
te::Schedule OpImplementation::Schedule(const Attrs& attrs,
const Array<te::Tensor> &outs,
const Target& target) {
return (*this)->fschedule(attrs, outs, target);
}
void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
tvm::relay::FTVMSchedule fschedule,
std::string name,
int plevel) {
auto n = make_object<OpImplementationNode>();
n->fcompute = fcompute;
n->fschedule = fschedule;
n->name = std::move(name);
n->plevel = plevel;
(*this)->implementations.push_back(OpImplementation(n));
}
void OpStrategy::AddImplementation(FTVMCompute fcompute,
FTVMSchedule fschedule,
std::string name,
int plevel) {
auto curr_cond = te::SpecializedCondition::Current();
auto self = this->operator->();
Array<OpSpecialization> specializations = self->specializations;
for (OpSpecialization op_spec : specializations) {
if (op_spec->condition == curr_cond) {
op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
return;
}
}
ObjectPtr<OpSpecializationNode> n = make_object<OpSpecializationNode>();
n->condition = curr_cond;
OpSpecialization op_spec(n);
op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel);
self->specializations.push_back(op_spec);
}
TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute")
.set_body([](TVMArgs args, TVMRetValue* rv) {
OpImplementation imp = args[0];
Attrs attrs = args[1];
Array<te::Tensor> inputs = args[2];
Type out_type = args[3];
*rv = imp.Compute(attrs, inputs, out_type);
});
TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
OpImplementation imp = args[0];
Attrs attrs = args[1];
Array<te::Tensor> outs = args[2];
Target target = args[3];
*rv = imp.Schedule(attrs, outs, target);
});
TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectPtr<OpStrategyNode> n = make_object<OpStrategyNode>();
*rv = OpStrategy(n);
});
TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation")
.set_body([](TVMArgs args, TVMRetValue* rv) {
OpStrategy strategy = args[0];
FTVMCompute compute = args[1];
FTVMSchedule schedule = args[2];
std::string name = args[3];
int plevel = args[4];
strategy.AddImplementation(compute, schedule, name, plevel);
});
} // namespace relay
} // namespace tvm
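Tying AddImplementation above to the specialization scope (a hedged sketch; the Python wrapper names are assumed from this PR): implementations added inside a te.SpecializedCondition scope are grouped under that condition by the loop over specializations in OpStrategy::AddImplementation.

from tvm import te
from tvm.relay.op.op import OpStrategy

def build_strategy(fcompute, fschedule):
    m = te.var("m")
    strategy = OpStrategy()
    # Recorded under the generic (empty) condition.
    strategy.add_implementation(fcompute, fschedule, name="generic", plevel=10)
    # Recorded under the m % 8 == 0 specialization.
    with te.SpecializedCondition(m % 8 == 0):
        strategy.add_implementation(fcompute, fschedule,
                                    name="vectorized", plevel=15)
    return strategy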
......@@ -79,7 +79,7 @@ TVM_ADD_FILELINE)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -105,7 +105,7 @@ TVM_ADD_FILELINE)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -123,7 +123,7 @@ Mark the start of bitpacking.
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -140,7 +140,7 @@ Mark the end of bitpacking.
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -163,7 +163,7 @@ Mark a checkpoint for checkpointing memory optimization.
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
Array<te::Tensor> outputs;
for (size_t i = 0; i < inputs.size(); ++i) {
outputs.push_back(topi::identity(inputs[i]));
......@@ -184,7 +184,7 @@ Beginning of a region that is handled by a given compiler.
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -209,7 +209,7 @@ End of a region that is handled by a given compiler.
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......
......@@ -36,9 +36,8 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(DebugAttrs);
Array<te::Tensor> DebugCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
return Array<te::Tensor>{ topi::identity(inputs[0]) };
}
......
......@@ -83,7 +83,7 @@ RELAY_REGISTER_OP("memory.alloc_storage")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -179,7 +179,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -228,7 +228,7 @@ RELAY_REGISTER_OP("memory.invoke_tvm_op")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -252,7 +252,7 @@ RELAY_REGISTER_OP("memory.kill")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......@@ -340,7 +340,7 @@ RELAY_REGISTER_OP("memory.shape_func")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<te::Tensor> {
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});
......
......@@ -735,58 +735,6 @@ weight transformation in advance.
.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
// Positional relay function to create conv2d winograd nnpack operator
// used by frontend FFI.
Expr MakeConv2DWinogradNNPACK(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_without_weight_transform");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
.set_body_typed(MakeConv2DWinogradNNPACK);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_winograd_nnpack_weight_transform.
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
We do not check the shape of this input tensor, since different backends
have different layout strategies.
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
......@@ -850,55 +798,6 @@ weight transformation in advance.
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
Expr MakeConv2DNCHWcInt8(Expr data,
Expr kernel,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc_int8");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc_int8")
.set_body_typed(MakeConv2DNCHWcInt8);
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
.describe(R"code(Compute conv2d with NCHWc data layout with int8 inputs.
- **data**: Input is 5D packed tensor.
- **weight**: 7D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ConvInferCorrectLayout<Conv2DAttrs>);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
Expr MakeConv2DNCHWc(Expr data,
Expr kernel,
Array<IndexExpr> strides,
......
......@@ -153,6 +153,16 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< " But got " << out_layout;
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
bool is_depthwise = false;
if (param->groups > 1) {
CHECK(weight && weight->shape.defined()) <<
"Weight shape must be specified when groups is greater than 1.";
Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape);
if (tvm::tir::Equal(param->groups, dshape_nchw[1]) &&
tvm::tir::Equal(param->groups, wshape_oihw[0])) {
is_depthwise = true;
}
}
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined
......@@ -161,9 +171,9 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape;
if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) {
if (is_depthwise) {
// infer weight's shape for depthwise convolution
wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0],
wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0],
param->kernel_size[1]}};
} else {
wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
......
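A worked sketch of the corrected depthwise branch (illustrative Python, not code from this diff): the channel-multiplier dimension is channels divided by the input channel count, where the old code wrongly divided groups by the input channel count.

def infer_wshape(in_channels, channels, groups, kernel_size, is_depthwise):
    kh, kw = kernel_size
    if is_depthwise:
        # {dshape_nchw[1], indexdiv(channels, dshape_nchw[1]), kh, kw}
        return (in_channels, channels // in_channels, kh, kw)
    # {channels, indexdiv(dshape_nchw[1], groups), kh, kw}
    return (channels, in_channels // groups, kh, kw)

# Depthwise conv with channel multiplier 2: groups == in_channels == 32.
assert infer_wshape(32, 64, 32, (3, 3), True) == (32, 2, 3, 3)
# Dividing groups by in_channels (the old code) would give (32, 1, 3, 3).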
......@@ -93,8 +93,9 @@ RELAY_REGISTER_OP("nn.bias_add")
.add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1)
.add_type_rel("BiasAdd", BiasAddRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type, const Target& target) {
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<BiasAddAttrs>();
return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
});
......@@ -234,8 +235,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Type& out_type) {
const auto* param = attrs.as<LeakyReluAttrs>();
return Array<te::Tensor>{ topi::leaky_relu(inputs[0], param->alpha) };
});
......@@ -315,8 +315,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Type& out_type) {
const auto* param = attrs.as<PReluAttrs>();
return Array<te::Tensor>{ topi::prelu(inputs[0], inputs[1], param->axis)};
});
......@@ -351,8 +350,7 @@ RELAY_REGISTER_OP("nn.softmax")
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Type& out_type) {
const auto* param = attrs.as<SoftmaxAttrs>();
CHECK(param != nullptr);
return Array<te::Tensor>{ topi::nn::softmax(inputs[0], param->axis) };
......@@ -385,8 +383,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Type& out_type) {
const auto* param = attrs.as<SoftmaxAttrs>();
CHECK(param != nullptr);
CHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
......@@ -462,8 +459,7 @@ Example::
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Type& out_type) {
return Array<te::Tensor>{ topi::nn::flatten(inputs[0]) };
});
......@@ -489,8 +485,7 @@ RELAY_REGISTER_OP("nn.relu")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Type& out_type) {
return Array<te::Tensor>{ topi::relu(inputs[0], 0.0f) };
});
......
......@@ -161,9 +161,8 @@ bool PadRel(const Array<Type>& types,
}
Array<te::Tensor> PadCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<PadAttrs>();
CHECK(param != nullptr);
......
......@@ -164,9 +164,8 @@ bool Pool2DRel(const Array<Type>& types,
template<typename AttrType, topi::nn::PoolType mode>
Array<te::Tensor> Pool2DCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
......@@ -331,9 +330,8 @@ bool GlobalPool2DRel(const Array<Type>& types,
template<topi::nn::PoolType mode>
Array<te::Tensor> GlobalPool2DCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr);
......@@ -465,9 +463,8 @@ bool AdaptivePool2DRel(const Array<Type>& types,
template<topi::nn::PoolType mode>
Array<te::Tensor> AdaptivePool2DCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AdaptivePool2DAttrs>();
CHECK(param != nullptr);
......@@ -593,8 +590,9 @@ bool Pool2DGradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
template <typename AttrType, topi::nn::PoolType mode>
Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type, const Target& target) {
Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCHW("NCHW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
......@@ -793,9 +791,8 @@ bool Pool1DRel(const Array<Type>& types,
template<typename AttrType, topi::nn::PoolType mode>
Array<te::Tensor> Pool1DCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCW("NCW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
......@@ -985,9 +982,8 @@ bool Pool3DRel(const Array<Type>& types,
template<typename AttrType, topi::nn::PoolType mode>
Array<te::Tensor> Pool3DCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
static const Layout kNCDHW("NCDHW");
const auto* param = attrs.as<AttrType>();
CHECK(param != nullptr);
......
......@@ -32,9 +32,8 @@ namespace relay {
#define RELAY_BINARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
const Array<te::Tensor>& inputs, \
const Type& out_type, \
const Target& target) -> Array<te::Tensor> { \
const Array<te::Tensor>& inputs, \
const Type& out_type) -> Array<te::Tensor> { \
CHECK_EQ(inputs.size(), 2U); \
return {FTOPI(inputs[0], inputs[1])}; \
} \
......
......@@ -176,7 +176,6 @@ template<typename F>
Array<te::Tensor> ReduceCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target,
F f) {
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
......@@ -321,10 +320,9 @@ bool ReduceRel(const Array<Type>& types,
Array<te::Tensor> ArgMaxCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::argmax);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::argmax);
}
......@@ -341,10 +339,9 @@ values over a given axis.
Array<te::Tensor> ArgMinCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::argmin);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::argmin);
}
RELAY_REGISTER_REDUCE_OP("argmin")
......@@ -359,10 +356,9 @@ values over a given axis.
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<te::Tensor> SumCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::sum);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::sum);
}
......@@ -393,10 +389,9 @@ Example::
Array<te::Tensor> AllCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::all);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::all);
}
......@@ -430,10 +425,9 @@ Example::
Array<te::Tensor> AnyCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::any);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::any);
}
......@@ -467,10 +461,9 @@ Example::
Array<te::Tensor> MaxCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::max);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::max);
}
RELAY_REGISTER_REDUCE_OP("max")
......@@ -485,10 +478,9 @@ RELAY_REGISTER_REDUCE_OP("max")
Array<te::Tensor> MinCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::min);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::min);
}
......@@ -504,10 +496,9 @@ RELAY_REGISTER_REDUCE_OP("min")
Array<te::Tensor> ProdCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::prod);
const Array<te::Tensor>& inputs,
const Type& out_type) {
return ReduceCompute(attrs, inputs, out_type, topi::prod);
}
RELAY_REGISTER_REDUCE_OP("prod")
......@@ -534,9 +525,8 @@ Example::
Array<te::Tensor> MeanCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
......@@ -546,7 +536,7 @@ Array<te::Tensor> MeanCompute(const Attrs& attrs,
param->exclude)) {
count *= inputs[0]->shape[i];
}
auto res = ReduceCompute(attrs, inputs, out_type, target, topi::sum);
auto res = ReduceCompute(attrs, inputs, out_type, topi::sum);
return {topi::divide(res[0], count)};
}
......@@ -599,9 +589,8 @@ bool VarianceRel(const Array<Type>& types,
}
Array<te::Tensor> VarianceCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
......@@ -615,7 +604,7 @@ Array<te::Tensor> VarianceCompute(const Attrs& attrs,
}
std::vector<Integer> expand_shape;
auto sq_diff = topi::power(topi::subtract(data, mean), 2);
auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, target, topi::sum)[0], count);
auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, topi::sum)[0], count);
return {var};
}
......
......@@ -34,9 +34,8 @@ namespace relay {
#define RELAY_UNARY_COMPUTE(FTOPI) \
[] (const Attrs& attrs, \
const Array<te::Tensor>& inputs, \
const Type& out_type, \
const Target& target) -> Array<te::Tensor> { \
const Array<te::Tensor>& inputs, \
const Type& out_type) -> Array<te::Tensor> { \
return {FTOPI(inputs[0])}; \
} \
......@@ -302,9 +301,8 @@ bool ShapeOfRel(const Array<Type>& types,
}
Array<te::Tensor> ShapeOfCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 1);
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
......@@ -353,9 +351,8 @@ bool NdarraySizeRel(const Array<Type>& types,
}
Array<te::Tensor> NdarraySizeCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 1);
const auto* param = attrs.as<NdarraySizeAttrs>();
CHECK(param != nullptr);
......
......@@ -83,8 +83,7 @@ Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
.add_type_rel("YoloReorg", YoloReorgRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type,
const Target& target) {
const Type& out_type) {
const auto* params = attrs.as<YoloReorgAttrs>();
CHECK(params != nullptr);
return Array<te::Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
......
......@@ -83,7 +83,10 @@ class AlterTransformMemorizer : public TransformMemorizer {
auto ttype = expr->type_as<TensorTypeNode>();
tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
}
Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
// TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs have dynamic shapes.
// We probably need to disable AlterOpLayout when compiling dynamic models.
Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos,
ref_call->checked_type());
if (altered_value.defined()) {
new_e = altered_value;
modified = true;
......
......@@ -20,9 +20,11 @@
/*!
* \file schedule_lang.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <stack>
#include <unordered_set>
#include "graph.h"
......@@ -787,6 +789,53 @@ IterVarRelation SingletonNode::make(IterVar iter) {
return IterVarRelation(n);
}
SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) {
ObjectPtr<SpecializedConditionNode> n = make_object<SpecializedConditionNode>();
n->clauses = std::move(conditions);
data_ = std::move(n);
}
/*! \brief Entry to hold the SpecializedCondition context stack. */
struct TVMSpecializationThreadLocalEntry {
/*! \brief The current specialized condition */
std::stack<SpecializedCondition> condition_stack;
};
/*! \brief Thread local store to hold the SpecializedCondition context stack. */
typedef dmlc::ThreadLocalStore<TVMSpecializationThreadLocalEntry> TVMSpecializationThreadLocalStore;
void SpecializedCondition::EnterWithScope() {
TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
entry->condition_stack.push(*this);
}
void SpecializedCondition::ExitWithScope() {
TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
CHECK(!entry->condition_stack.empty());
CHECK(entry->condition_stack.top().same_as(*this));
entry->condition_stack.pop();
}
SpecializedCondition SpecializedCondition::Current() {
TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get();
SpecializedCondition cond;
if (entry->condition_stack.size() > 0) {
cond = entry->condition_stack.top();
}
return cond;
}
class SpecializedCondition::Internal {
public:
static void EnterScope(SpecializedCondition cond) {
cond.EnterWithScope();
}
static void ExitScope(SpecializedCondition cond) {
cond.ExitWithScope();
}
};
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
......@@ -794,6 +843,7 @@ TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(SpecializedConditionNode);
// Printer
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -848,7 +898,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ScheduleNode*>(node.get());
p->stream << "schedule(" << op << ")";
});
})
.set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SpecializedConditionNode*>(node.get());
p->stream << "specialized_condition(";
p->Print(op->clauses);
p->stream << ')';
});
TVM_REGISTER_GLOBAL("te.CreateSchedule")
......@@ -962,5 +1018,22 @@ TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
.set_body_method(&Schedule::rfactor);
TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition")
.set_body_typed([](Array<PrimExpr> condition) {
return SpecializedCondition(condition);
});
TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SpecializedCondition::Current();
});
TVM_REGISTER_GLOBAL("te.EnterSpecializationScope")
.set_body_typed(SpecializedCondition::Internal::EnterScope);
TVM_REGISTER_GLOBAL("te.ExitSpecializationScope")
.set_body_typed(SpecializedCondition::Internal::ExitScope);
} // namespace te
} // namespace tvm
......@@ -24,18 +24,56 @@
#include <tvm/relay/type.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/broadcast.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
TVM_REGISTER_GLOBAL("test.sch")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
using namespace tvm;
using namespace tvm::relay;
TVM_REGISTER_GLOBAL("test.strategy")
.set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type, const Target& target) {
FTVMCompute fcompute = [](const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) -> Array<te::Tensor> {
CHECK_EQ(inputs.size(), 2U);
return {topi::add(inputs[0], inputs[1])};
};
FTVMSchedule fschedule = [](const Attrs& attrs,
const Array<te::Tensor>& outs,
const Target& target) {
With<Target> target_scope(target);
return topi::generic::schedule_injective(target, outs);
};
auto n = make_object<OpStrategyNode>();
auto strategy = tvm::relay::OpStrategy(std::move(n));
strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10);
return strategy;
});
TVM_REGISTER_GLOBAL("relay.backend.lower_call")
.set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
const Target& target) {
static auto fstrategy = Op::GetAttr<relay::FTVMStrategy>("FTVMStrategy");
Op op = Downcast<Op>(call->op);
auto out_type = call->checked_type();
OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);
auto impl = strategy->specializations[0]->implementations[0];
auto outs = impl.Compute(call->attrs, inputs, out_type);
auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput");
if (!f) {
LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered";
}
return (*f)(outs, impl);
});
TEST(Relay, BuildModule) {
using namespace tvm;
auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
......@@ -59,14 +97,15 @@ TEST(Relay, BuildModule) {
}
// get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register");
auto s_i = tvm::runtime::Registry::Get("test.sch");
if (!reg) {
LOG(FATAL) << "no _Register";
}
if (!s_i) {
LOG(FATAL) << "no _Register";
auto fs = tvm::runtime::Registry::Get("test.strategy");
if (!fs) {
LOG(FATAL) << "No test_strategy registered.";
}
(*reg)("add", "FTVMSchedule", *s_i, 10);
auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs);
(*reg)("add", "FTVMStrategy", fgeneric, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();
......
......@@ -852,17 +852,22 @@ def test_forward_slice():
def test_forward_convolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):
weight_shape=(num_filter, data_shape[1],) + kernel_size
def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False):
if is_depthwise:
groups = data_shape[1]
weight_shape=(data_shape[1], num_filter // groups,) + kernel_size
else:
groups = 1
weight_shape=(num_filter, data_shape[1],) + kernel_size
x = np.random.uniform(size=data_shape).astype("float32")
weight = np.random.uniform(size=weight_shape).astype("float32")
bias = np.random.uniform(size=num_filter).astype("float32")
ref_res = mx.nd.Convolution(data=mx.nd.array(x), weight=mx.nd.array(weight),
bias=mx.nd.array(bias), kernel=kernel_size, stride=stride,
pad=pad, num_filter=num_filter)
pad=pad, num_filter=num_filter, num_group=groups)
mx_sym = mx.sym.Convolution(mx.sym.var("x"), mx.sym.var("weight"), mx.sym.var("bias"),
kernel=kernel_size, stride=stride,
pad=pad, num_filter=num_filter)
pad=pad, num_filter=num_filter, num_group=groups)
shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
......@@ -879,6 +884,8 @@ def test_forward_convolution():
verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8,
is_depthwise=True)
def test_forward_deconvolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):
......
......@@ -25,7 +25,7 @@ import tvm
from tvm import autotvm
from tvm.autotvm.tuner import RandomTuner
@autotvm.template
@autotvm.register_customized_task("testing/conv2d_no_batching")
def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
"""An example template for testing"""
assert N == 1, "Only consider batch_size = 1 in this template"
......@@ -114,7 +114,7 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
def get_sample_task(target=tvm.target.cuda(), target_host=None):
"""return a sample task for testing"""
task = autotvm.task.create(conv2d_no_batching,
task = autotvm.task.create("testing/conv2d_no_batching",
args=(1, 7, 7, 512, 512, 3, 3),
target=target, target_host=target_host)
return task, target
......
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
import tvm
from tvm import relay
......@@ -384,6 +385,8 @@ def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation
assert result.asnumpy().shape == ref_out_shape, \
"Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
# TODO(@kevinthesun): Need to fix the compute in conv2d_NCHWc to support any
@pytest.mark.skip
def test_any_conv2d_NCHWc():
verify_any_conv2d_NCHWc((relay.Any(), 8, relay.Any(), relay.Any(), 8), (8, 8, 3, 3, 8, 8), (1, 1), (1, 1), (1, 1),
"NCHW8c", "OIHW8i8o", "NCHW8c", (1, 8, 224, 224, 8), (1, 8, 224, 224, 8))
......
......@@ -39,25 +39,28 @@ def test_task_extraction():
target = 'llvm'
mod_list = []
params_list = []
conv2d = relay.op.get("nn.conv2d")
conv2d_transpose = relay.op.get("nn.conv2d_transpose")
dense = relay.op.get("nn.dense")
mod, params, _ = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d,))
ops=(conv2d,))
assert len(tasks) == 12
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d,))
ops=(conv2d,))
assert len(tasks) == 12
mod, params, _ = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.dense,))
ops=(dense,))
assert len(tasks) == 1
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.dense,))
ops=(dense,))
assert len(tasks) == 1
mod, params, _ = get_network('resnet-18', batch_size=1)
......@@ -65,11 +68,14 @@ def test_task_extraction():
params_list.append(params)
tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
ops=(conv2d, dense))
assert len(tasks) == 13
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
ops=(conv2d, dense))
assert len(tasks) == 13
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params)
assert len(tasks) == 13
mod, params, _ = get_network('mobilenet', batch_size=1)
......@@ -77,65 +83,19 @@ def test_task_extraction():
params_list.append(params)
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
ops=(conv2d, dense))
assert len(tasks) == 20
mod, params, _ = get_network('dcgan', batch_size=1)
tasks = autotvm.task.extract_from_program(mod, target=target,
params=params,
ops=(relay.op.nn.conv2d_transpose,))
ops=(conv2d_transpose,))
assert len(tasks) == 4
tasks = autotvm.task.extract_from_multiple_program(mod_list, params_list,
target=target,
ops=(relay.op.nn.conv2d,))
ops=(conv2d,))
assert len(tasks) == 31
def test_template_key_provided():
"""test task extraction using non-'direct' template_key"""
target = 'llvm'
import topi
template_keys = {
# topi.nn.conv2d - is left blank to test fallback logic
topi.nn.dense: 'direct_nopack',
topi.nn.depthwise_conv2d_nchw: 'direct',
}
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
template_keys=template_keys)
for task in tasks:
if 'dense' in task.name:
assert task.config_space.template_key == 'direct_nopack'
else:
assert task.config_space.template_key == 'direct'
def test_template_key_empty():
"""test task extraction using empty template_key"""
target = 'llvm'
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
template_keys=None)
for task in tasks:
assert task.config_space.template_key == 'direct'
def test_template_key_default():
"""test task extraction without template_key"""
target = 'llvm'
mod, params, _ = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
for task in tasks:
assert task.config_space.template_key == 'direct'
if __name__ == '__main__':
test_task_extraction()
test_template_key_provided()
test_template_key_empty()
test_template_key_default()