Commit 1bc5d0ad by Yuwei HU Committed by Tianqi Chen

register softmax (#16)

parent 48038a9c
......@@ -23,7 +23,6 @@ def compute_conv2d(attrs, inputs):
out = topi.broadcast_add(out, bias)
return out
@reg.register_schedule("conv2d")
def schedule_conv2d(_, outs, target):
"""Schedule definition of conv2d"""
......@@ -33,3 +32,22 @@ def schedule_conv2d(_, outs, target):
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("conv2d", OpPattern.COMPLEX)
# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.softmax(inputs[0])
@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
"""Schedule definition of softmax"""
if target == "cuda":
return topi.cuda.schedule_softmax(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("softmax", OpPattern.COMPLEX)
import numpy as np
import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
USE_GPU=True
def default_target():
if USE_GPU:
return 'cuda'
else:
return 'llvm'
def default_ctx():
if USE_GPU:
return tvm.gpu(0)
else:
return tvm.cpu(0)
def test_softmax():
x = sym.Variable("x")
y = sym.softmax(x)
dtype = "float32"
dshape = (10, 1000)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
set_input("x", data)
# execute
run()
# get outputs
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
y_np = topi.testing.softmax_python(data.asnumpy())
np.testing.assert_allclose(out.asnumpy(), y_np, rtol=1e-5)
if __name__ == "__main__":
test_softmax()
......@@ -5,7 +5,7 @@ def test_dense():
x1 = sym.dense(x, units=3, name="dense")
x2 = sym.flatten(x1)
x3 = sym.softmax(x2)
assert x2.list_input_names() == ['x', 'dense_weight', 'dense_bias']
assert x3.list_input_names() == ['x', 'dense_weight', 'dense_bias']
def test_concatenate_split():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment