Commit 9fb13a69 by Tianqi Chen

[TVM] upgrade to generic schedule (#173)

parent 08e71b73
@@ -161,7 +161,7 @@ def optimize(graph, shape, dtype="float32"):
     return graph
-def build(graph, target, shape, dtype="float32", params=None):
+def build(graph, target=None, shape=None, dtype="float32", params=None):
     """Build graph into runtime library.
     The build function will optimize the graph and do the compilation.
@@ -175,10 +175,10 @@ def build(graph, target, shape, dtype="float32", params=None):
     graph : Graph
         The graph to be used in lowering
-    target : str
+    target : str or :any:`tvm.target.Target`, optional
         The build target
-    shape : dict of str to tuple
+    shape : dict of str to tuple, optional
         The input shape to the graph
     dtype : str or dict of str to str
@@ -201,8 +201,12 @@ def build(graph, target, shape, dtype="float32", params=None):
         The updated parameters of graph if params is passed.
         This can be different from the params passed in.
     """
-    if not isinstance(target, str):
-        raise TypeError("require target to be str")
+    target = target if target else tvm.target.current_target()
+    if target is None:
+        raise ValueError("Target is not set in env or passed as argument.")
+    target = tvm.target.create(target)
+    shape = shape if shape else {}
     if not isinstance(shape, dict):
         raise TypeError("require shape to be dict")
     cfg = BuildConfig.current
@@ -223,13 +227,14 @@ def build(graph, target, shape, dtype="float32", params=None):
     # Operator Fusion and generation
     graph = graph_attr.set_shape_inputs(graph, shape)
     graph = graph_attr.set_dtype_inputs(graph, dtype)
-    graph._set_json_attr("target", target, "str")
+    graph._set_json_attr("target", str(target), "str")
     if cfg.pass_enabled("OpFusion"):
         graph._set_json_attr("opt_level", 1, "int")
     else:
         graph._set_json_attr("opt_level", 0, "int")
     graph = graph.apply("InferShape").apply("InferType")
-    graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
+    with target:
+        graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
     libmod = graph_attr._move_out_module(graph, "module")
     return graph, libmod, params
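
With this change both `target` and `shape` are optional: when `target` is omitted, `build` falls back to `tvm.target.current_target()` and raises a ValueError if no target is current. A minimal usage sketch of the two call styles (the softmax graph and the shape dict below are illustrative placeholders, not part of this commit):

import nnvm.compiler
import nnvm.graph
import nnvm.symbol as sym
import tvm

# Hypothetical one-operator graph, used only to exercise build().
x = sym.Variable("x")
graph = nnvm.graph.create(sym.softmax(x))
shape = {"x": (1, 16)}

# Explicit target, as before.
g, lib, params = nnvm.compiler.build(graph, target="llvm", shape=shape)

# New: pick the target up from the enclosing target context.
with tvm.target.create("llvm"):
    g, lib, params = nnvm.compiler.build(graph, shape=shape)

# With neither an argument nor a surrounding context, build() raises
# ValueError("Target is not set in env or passed as argument.").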
...
@@ -50,10 +50,9 @@ def compute_softmax(attrs, inputs, _):
 @reg.register_schedule("softmax")
 def schedule_softmax(_, outs, target):
     """Schedule definition of softmax"""
-    if target == "cuda":
-        return topi.cuda.schedule_softmax(outs)
-    # naive schedule
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        return topi.generic.schedule_softmax(outs)
 reg.register_pattern("softmax", OpPattern.OPAQUE)
@@ -68,10 +67,8 @@ def compute_log_softmax(attrs, inputs, _):
 @reg.register_schedule("log_softmax")
 def schedule_log_softmax(_, outs, target):
     """Schedule definition of softmax"""
-    if target == "cuda":
-        return topi.cuda.schedule_softmax(outs)
-    # naive schedule
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        return topi.generic.schedule_softmax(outs)
 # Mark softmax as extern as we do not fuse it in all cases
 reg.register_pattern("log_softmax", OpPattern.OPAQUE)
@@ -87,10 +84,8 @@ def compute_dense(attrs, inputs, _):
 @reg.register_schedule("dense")
 def schedule_dense(_, outs, target):
     """Schedule definition of dense"""
-    if target == "cuda":
-        return topi.cuda.schedule_dense(outs)
-    # naive schedule
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        return topi.generic.schedule_dense(outs)
 reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -123,18 +118,10 @@ def compute_conv2d(attrs, inputs, _):
 def schedule_conv2d(attrs, outs, target):
     """Schedule definition of conv2d"""
     groups = attrs.get_int("groups")
-    if target == "cuda":
-        if groups == 1:
-            return topi.cuda.schedule_conv2d_nchw(outs)
-        return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
-    # naive schedule
-    if tvm.target.current_target() == tvm.target.rasp():
-        if groups == 1:
-            return topi.rasp.schedule_conv2d(outs)
-        return topi.rasp.schedule_depthwise_conv2d(outs)
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        if groups == 1:
+            return topi.generic.schedule_conv2d_nchw(outs)
+        return topi.generic.schedule_depthwise_conv2d_nchw(outs)
 reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -155,10 +142,8 @@ def compute_max_pool2d(attrs, inputs, _):
 @reg.register_schedule("max_pool2d")
 def schedule_max_pool2d(_, outs, target):
     """Schedule definition of max_pool2d"""
-    if target == "cuda":
-        return topi.cuda.schedule_pool(outs)
-    # naive schedule
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        return topi.generic.schedule_pool(outs)
 reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -179,10 +164,8 @@ def compute_avg_pool2d(attrs, inputs, _):
 @reg.register_schedule("avg_pool2d")
 def schedule_avg_pool2d(_, outs, target):
     """Schedule definition of avg_pool2d"""
-    if target == "cuda":
-        return topi.cuda.schedule_pool(outs)
-    # naive schedule
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        return topi.generic.schedule_pool(outs)
 reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -198,10 +181,8 @@ def compute_global_max_pool2d(attrs, inputs, _):
 @reg.register_schedule("global_max_pool2d")
 def schedule_global_max_pool2d(_, outs, target):
     """Schedule definition of global_max_pool2d"""
-    if target == "cuda":
-        return topi.cuda.schedule_global_pool(outs)
-    # naive schedule
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        return topi.generic.schedule_global_pool(outs)
 reg.register_pattern("global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -217,9 +198,7 @@ def compute_global_avg_pool2d(attrs, inputs, _):
 @reg.register_schedule("global_avg_pool2d")
 def schedule_global_avg_pool2d(_, outs, target):
     """Schedule definition of global_avg_pool2d"""
-    if target == "cuda":
-        return topi.cuda.schedule_global_pool(outs)
-    # naive schedule
-    return tvm.create_schedule([x.op for x in outs])
+    with tvm.target.create(target):
+        return topi.generic.schedule_global_pool(outs)
 reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -10,14 +10,9 @@ from .registry import OpPattern
 def _schedule_reduce(_, outs, target):
     """Generic schedule for reduce"""
-    if target == "cuda":
-        return topi.cuda.schedule_reduce(outs)
-    assert target.startswith("llvm")
-    s = tvm.create_schedule([x.op for x in outs])
-    x = outs[0]
-    tvm.schedule.AutoInlineInjective(s)
-    s[x].fuse(s[x].op.axis)
-    return s
+    with tvm.target.create(target):
+        return topi.generic.schedule_reduce(outs)
 _fschedule_reduce = tvm.convert(_schedule_reduce)
...
@@ -10,15 +10,8 @@ from .registry import OpPattern
 def _schedule_injective(_, outs, target):
     """Generic schedule for binary bcast"""
-    if target == "cuda":
-        return topi.cuda.schedule_injective(outs)
-    assert target.startswith("llvm")
-    s = tvm.create_schedule([x.op for x in outs])
-    x = outs[0]
-    tvm.schedule.AutoInlineInjective(s)
-    s[x].fuse(s[x].op.axis)
-    return s
+    with tvm.target.create(target):
+        return topi.generic.schedule_injective(outs)
 def _compute_binary_scalar(f):
     """auxiliary function"""
@@ -174,7 +167,7 @@ reg.register_schedule("broadcast_div", _fschedule_broadcast)
 # broadcast_to
 @reg.register_compute("broadcast_to")
-def compute_softmax(attrs, inputs, out_info):
+def compute_broadcast_to(attrs, inputs, out_info):
     """Compute definition of softmax"""
     return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
 reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
...