Commit f1e0a55a by Yao Wang Committed by Tianqi Chen

Move dense compute back to python (#364)

parent 5884cd01
...@@ -51,6 +51,13 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE) ...@@ -51,6 +51,13 @@ reg.register_pattern("log_softmax", OpPattern.OPAQUE)
# dense # dense
@reg.register_compute("dense")
def compute_dense(attrs, inputs, _):
"""Compute definition of dense"""
if attrs.get_bool("use_bias"):
return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
return topi.nn.dense(inputs[0], inputs[1])
@reg.register_schedule("dense") @reg.register_schedule("dense")
def schedule_dense(_, outs, target): def schedule_dense(_, outs, target):
"""Schedule definition of dense""" """Schedule definition of dense"""
......
...@@ -82,21 +82,6 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored. ...@@ -82,21 +82,6 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
.set_attr<FInferShape>("FInferShape", DenseInferShape) .set_attr<FInferShape>("FInferShape", DenseInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
Tensor bias_val;
Tensor* bias;
const DenseParam& param = nnvm::get<DenseParam>(attrs.parsed);
if (param.use_bias) {
bias_val = inputs[2];
bias = &bias_val;
} else {
bias = nullptr;
}
return Array<Tensor>{ topi::nn::dense(inputs[0], inputs[1], bias) };
})
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment