Commit 0a410a39 by Zhixun Tan Committed by Tianqi Chen

[WIP] Add OpenGL topi. (#836)

[TOPI][GL] OpenGL topi.
parent fb556ef4
......@@ -226,6 +226,13 @@ constexpr const char* channel_write_advance = "channel_write_advance";
constexpr const char* pipeline_stage_scope = "pipeline_stage_scope";
/*! \brief pipeline execution scope, implies the scope can be pipelined. */
constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
/*!
* \brief Mark that this stage is an OpenGL shader. Since OpenGL shader only
* allows writing out to one element of the output texture, the Provide node
* gets translated to a special Call::glsl_texture_store statement instead of a
* Store statement.
*/
constexpr const char* opengl_stage_scope = "opengl_stage_scope";
} // namespace attr
/*! \brief namespace of TVM Intrinsic functions */
......
......@@ -427,6 +427,8 @@ class StageNode : public Node {
std::string scope;
/*! \brief Whether this is an output stage */
bool is_output{false};
/*! \brief Whether this is an OpenGL stage */
bool is_opengl{false};
/*! \brief Whether apply double buffer optimization to this stage */
bool double_buffer{false};
/*!
......@@ -450,6 +452,7 @@ class StageNode : public Node {
v->Visit("attach_stage", &attach_stage);
v->Visit("scope", &scope);
v->Visit("is_output", &is_output);
v->Visit("is_opengl", &is_opengl);
v->Visit("double_buffer", &double_buffer);
v->Visit("group", &group);
v->Visit("num_child_stages", &num_child_stages);
......
......@@ -280,6 +280,17 @@ def mali(options=None):
return Target("opencl", opts)
def opengl(options=None):
"""Returns a OpenGL target.
Parameters
----------
options : list of str
Additional options
"""
return Target("opengl", options)
def create(target_str):
"""Get a target given target string.
......
......@@ -168,31 +168,9 @@ void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
this->stream << "}\n";
}
// GLSL texture store is special. We can only store to one output texture, and
// we must store to the index that matches the current "thread index".
void CodeGenOpenGL::VisitStmt_(const Store* op) {
auto t = op->value.type();
auto buffer = op->buffer_var.get();
auto index = op->index;
if (t.lanes() == 1) {
// Store to a scalar.
CHECK(inputs_.find(buffer) == inputs_.cend())
<< "Texture has been read from before. Must not store to it.";
if (output_ == nullptr) {
output_ = buffer; // Record that this texture is the output.
} else {
CHECK(output_ == buffer) << "GLSL can only write to 1 texture.";
}
this->PrintIndent();
this->stream << GetBufferRef(t, buffer, index) << " = "
<< PrintExpr(op->value) << ";\n";
} else {
// Store to a vector.
LOG(FATAL) << "Vectorized store not implemented.";
}
LOG(FATAL) << "Store statement not supported in OpenGL."
<< " Texture store should be a Call statement.";
}
// texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
......@@ -215,8 +193,6 @@ std::string CodeGenOpenGL::GetBufferRef(
if (buffer == this->output_) {
// This is the output texture.
CHECK_EQ(index.get(), output_iter_var_)
<< "GLSL must access corresponding elem of output texture.";
return GetVarID(buffer);
} else {
// This is an input texture.
......@@ -265,5 +241,33 @@ void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) {
LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
}
void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
auto call = op->value.as<Call>();
if (call == nullptr || call->name != Call::glsl_texture_store) {
// Fallback to normal logic.
CodeGenC::VisitStmt_(op);
}
CHECK_EQ(call->args.size(), 2);
auto buffer = call->args[0].as<Variable>();
auto value = call->args[1];
// Doesn't support store to vector.
auto type = value.type();
CHECK_EQ(type.lanes(), 1)
<< "Vectorized store not implemented, type = " << type;
CHECK(inputs_.find(buffer) == inputs_.cend())
<< "Texture has been read from before. Must not store to it.";
if (output_ == nullptr) {
output_ = buffer; // Record that this texture is the output.
} else {
CHECK(output_ == buffer) << "GLSL can only write to 1 texture.";
}
this->PrintIndent();
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
}
} // namespace codegen
} // namespace tvm
......@@ -34,6 +34,9 @@ class CodeGenOpenGL final : public CodeGenC {
void VisitExpr_(const FloatImm* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const StringImm* op, std::ostream& os) final; // NOLINT(*)
// Match glsl_texture_store Call.
void VisitStmt_(const Evaluate* op) final; // NOLINT(*)
private:
const Variable* output_{nullptr};
std::unordered_set<const Variable*> inputs_;
......
/*!
* Copyright (c) 2017 by Contributors
* \file intrin_rule_opencl.cc
* \brief OpenCL intrinsic rules.
*/
#include "./intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -29,7 +29,8 @@ using intrinsic::tvm_address_of;
class StorageFlattener : public IRMutator {
public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer, int cache_line_size) {
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
int cache_line_size) {
for (auto kv : extern_buffer) {
BufferEntry e;
e.buffer = kv.second;
......@@ -38,6 +39,7 @@ class StorageFlattener : public IRMutator {
}
cache_line_size_ = cache_line_size;
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
......@@ -90,6 +92,8 @@ class StorageFlattener : public IRMutator {
vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
return this->Mutate(op->body);
} else if (op->attr_key == attr::opengl_stage_scope) {
is_opengl_ = true;
}
return IRMutator::Mutate_(op, s);
}
......@@ -104,7 +108,15 @@ class StorageFlattener : public IRMutator {
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
return e.buffer.vstore(e.RelIndex(op->args), op->value);
if (is_opengl_) {
return Evaluate::make(Call::make(
Type(),
Call::glsl_texture_store,
{e.buffer->data, op->value},
Call::Intrinsic));
} else {
return e.buffer.vstore(e.RelIndex(op->args), op->value);
}
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
......@@ -421,6 +433,8 @@ class StorageFlattener : public IRMutator {
std::vector<ThreadScope> curr_thread_scope_;
// The size of cacheline
int cache_line_size_;
// The current stage is an OpenGL shader.
bool is_opengl_{false};
};
Stmt StorageFlatten(Stmt stmt,
......
......@@ -281,7 +281,7 @@ GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind,
if (err != GL_TRUE) {
std::unique_ptr<char[]> err_msg(new char[info_log_len + 1]);
gl->GetShaderInfoLog(shader, info_log_len, nullptr, err_msg.get());
LOG(FATAL) << err_msg.get();
LOG(FATAL) << err_msg.get() << "\n" << shader_src;
assert(false);
}
......
......@@ -433,6 +433,9 @@ Stage& Stage::opengl() {
// Bind the only dimension to threadIdx.x.
bind(fused, thread_axis(Range(nullptr), "threadIdx.x"));
// Mark this stage as OpenGL.
(*this)->is_opengl = true;
return *this;
}
......
......@@ -44,6 +44,11 @@ Stmt MakePipeline(const Stage& s,
s->op, ir::attr::realize_scope,
StringImm::make(s->scope),
pipeline);
if (s->is_opengl) {
pipeline = AttrStmt::make(
s->op, ir::attr::opengl_stage_scope, StringImm::make(""), pipeline);
}
return pipeline;
}
......
"""Example code to do convolution.
Copied from topi/tests/python/test_topi_conv2d_nchw.py.
Should be removed once we fix OpenGL testing on Jenkins."""
import os
import numpy as np
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d_nchw(A, W, stride, padding)
C = topi.nn.relu(B)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s1 = topi.generic.schedule_conv2d_nchw([B])
s2 = topi.generic.schedule_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=1400,
unroll_explicit=(device != "cuda")):
func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['opengl']:
check_device(device)
def test_conv2d_nchw():
# ResNet18 worklaods
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
# Vgg16 workloads
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
# Super resolution workloads
verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2)
verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
if __name__ == "__main__":
test_conv2d_nchw()
"""Test code for dense operator
Copied from topi/tests/python/test_topi_dense.py.
Should be removed once we fix OpenGL testing on Jenkins.
"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
def verify_dense(batch, in_dim, out_dim, use_bias=True):
A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B')
C = tvm.placeholder((out_dim,), name='C')
D = topi.nn.dense(A, B, C if use_bias else None)
D = topi.nn.relu(D)
dtype = A.dtype
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
if use_bias:
d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
else:
d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
return (a_np, b_np, c_np, d_np)
# get the test data
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_dense(D)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f(a, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
for device in ['opengl']:
check_device(device)
def test_dense():
verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False)
if __name__ == "__main__":
test_dense()
"""Test code for pooling
Copied from topi/tests/python/test_topi_pooling.py.
Should be removed once we fix OpenGL testing on Jenkins.
"""
import numpy as np
import tvm
import topi
import math
from topi.util import get_const_tuple
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
iw = ih
kw = kh
sw = sh
ph, pw = padding
A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode)
B = topi.nn.relu(B)
dtype = A.dtype
bshape = get_const_tuple(B.shape)
ashape = get_const_tuple(A.shape)
if ceil_mode:
assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1)
else:
assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1)
a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
pad_np[np.ix_(*no_zero)] = a_np
_, oc, oh, ow = get_const_tuple(B.shape)
b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
if pool_type == 'avg':
for i in range(oh):
for j in range(ow):
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
elif pool_type =='max':
for i in range(oh):
for j in range(ow):
b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_pool(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
print(tvm.lower(s, [A, B], simple_mode=True))
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['opengl']:
check_device(device)
def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False)
verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False)
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)
def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.global_pool(A, pool_type=pool_type)
B = topi.nn.relu(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
if pool_type == 'avg':
b_np = np.mean(a_np, axis=(2,3), keepdims=True)
elif pool_type =='max':
b_np = np.max(a_np, axis=(2,3), keepdims=True)
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_global_pool(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['opengl']:
check_device(device)
def test_global_pool():
verify_global_pool(1, 1024, 7, 7, 'avg')
verify_global_pool(4, 1024, 7, 7, 'avg')
verify_global_pool(1, 1024, 7, 7, 'max')
verify_global_pool(4, 1024, 7, 7, 'max')
if __name__ == "__main__":
test_pool()
test_global_pool()
"""Test code for softmax
Copied from topi/tests/python/test_topi_softmax.py.
Should be removed once we fix OpenGL testing on Jenkins.
"""
import os
import numpy as np
import tvm
import topi
import logging
from topi.util import get_const_tuple
def verify_softmax(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.softmax_python(a_np)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ["opengl"]:
check_device(device)
def test_softmax():
verify_softmax(32, 10)
verify_softmax(3, 4)
def verify_log_softmax(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.log_softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.log_softmax_python(a_np)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="log_softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ["opengl"]:
check_device(device)
def test_log_softmax():
verify_log_softmax(32, 10)
verify_log_softmax(3, 4)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_softmax()
test_log_softmax()
......@@ -21,6 +21,7 @@ from . import x86
from . import cuda
from . import rasp
from . import mali
from . import opengl
from . import testing
from . import util
from . import rocm
......
# pylint: disable=redefined-builtin, wildcard-import
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .conv2d_nchw import schedule_conv2d_nchw
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .softmax import schedule_softmax
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool
#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long
"""Schedule for conv2d_nchw with auto fusion"""
import tvm
from .. import tag
from .. import generic
@generic.schedule_conv2d_nchw.register(["opengl"])
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw.
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_nchw
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d_nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(conv2d, data):
if conv2d.op in s.outputs:
Out = conv2d
else:
Out = outs[0].op.output(0)
s[conv2d].opengl()
s[Out].opengl()
s[data].opengl()
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule conv2d_nchw
elif OP.tag.startswith('conv2d_nchw'):
conv2d = OP.output(0)
data = OP.input_tensors[0]
_schedule(conv2d, data)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
return s
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
@generic.schedule_dense.register(["opengl"])
def schedule_dense(outs):
"""Schedule for dense operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for dense.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(Dense):
if Dense.op in s.outputs:
Out = Dense
else:
Out = outs[0].op.output(0)
s[Dense].opengl()
s[Out].opengl()
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
Dense = OP.output(0)
_schedule(Dense)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
return s
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic
def _schedule_injective(op, sch):
x = op.output(0)
sch[x].opengl()
return sch
@generic.schedule_injective.register(["opengl"])
def schedule_injective(outs):
"""Schedule for injective op.
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
for out in outs:
_schedule_injective(out.op, s)
return s
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
# pylint: disable=invalid-name, unused-variable
"""Schedule for pooling operators"""
import tvm
from .. import tag
from .. import generic
@generic.schedule_global_pool.register(["opengl"])
def schedule_global_pool(outs):
"""Schedule for global_pool.
Parameters
----------
outs: Array of Tensor
The computation graph description of global_pool
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for global_pool.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(Pool):
if Pool.op in s.outputs:
Out = Pool
else:
Out = outs[0].op.output(0)
s[Pool].opengl()
s[Out].opengl()
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('global_pool'):
Pool = OP.output(0)
_schedule(Pool)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
return s
@generic.schedule_pool.register(["opengl"])
def schedule_pool(outs):
"""Schedule for pool.
Parameters
----------
outs: Array of Tensor
The computation graph description of pool
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for pool.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(PaddedInput, Pool):
s[PaddedInput].opengl()
if Pool.op in s.outputs:
Out = Pool
else:
Out = outs[0].op.output(0)
s[Pool].opengl()
s[Out].opengl()
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
PaddedInput = OP.input_tensors[0]
Pool = OP.output(0)
_schedule(PaddedInput, Pool)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
return s
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
import tvm
from .. import generic
@generic.schedule_softmax.register(["opengl"])
def schedule_softmax(outs):
"""Schedule for softmax op.
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
softmax = outs[0]
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
s[max_elem].opengl()
s[expsum].opengl()
s[softmax].opengl()
return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment