Commit 007a06db by tqchen Committed by Tianqi Chen

[FIX] Make master compile

parent 40901446
......@@ -141,9 +141,9 @@ def compute_max_pool2d(attrs, inputs, _):
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs["ceil_mode"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
assert ceil_mode == "False", "not support ceil_mode now"
assert not ceil_mode, "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='max')
@reg.register_schedule("max_pool2d")
......@@ -165,9 +165,9 @@ def compute_avg_pool2d(attrs, inputs, _):
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
layout = attrs["layout"]
ceil_mode = attrs["ceil_mode"]
ceil_mode = attrs.get_bool("ceil_mode")
assert layout == "NCHW", "only support nchw for now"
assert ceil_mode == "False", "not support ceil_mode now"
assert not ceil_mode, "not support ceil_mode now"
return topi.nn.pool(inputs[0], pool_size, strides, padding, pool_type='avg')
@reg.register_schedule("avg_pool2d")
......
......@@ -13,13 +13,29 @@
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/lowered_func.h>
#include <dmlc/parameter.h>
#include "./compile_engine.h"
#include "../../tvm/src/runtime/graph/graph_runtime.h"
namespace nnvm {
namespace compiler {
using tvm::runtime::TVMOpParam;
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs).set_default(1);
DMLC_DECLARE_FIELD(num_outputs).set_default(1);
DMLC_DECLARE_FIELD(flatten_data).set_default(0);
}
};
DMLC_REGISTER_PARAMETER(TVMOpParam);
// parser
inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
......@@ -368,7 +384,7 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
nnvm::NodePtr np = nnvm::Node::Create();
np->attrs.op = tvm_op;
np->attrs.name = inode.source->attrs.name;
runtime::TVMOpParam param;
TVMOpParam param;
param.func_name = fe.compiled_func->func_name;
param.num_inputs = static_cast<uint32_t>(fe.imap.size());
param.num_outputs = static_cast<uint32_t>(fe.subgraph.outputs.size());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment