/*! * Copyright (c) 2017 by Contributors * \file packed_func_ext.cc * \brief Registeration of extension type. */ #include <tvm/expr.h> #include <tvm/packed_func_ext.h> #include <nnvm/op.h> #include <nnvm/compiler/packed_func_ext.h> #include <nnvm/compiler/op_attr_types.h> #include <tvm/runtime/c_runtime_api.h> #include "./node_attr.h" #include "compile_engine.h" namespace tvm { namespace runtime { TVM_REGISTER_EXT_TYPE(nnvm::Graph); TVM_REGISTER_EXT_TYPE(nnvm::Symbol); TVM_REGISTER_EXT_TYPE(nnvm::compiler::AttrDict); } // namespace runtime } // namespace tvm namespace nnvm { DMLC_JSON_ENABLE_ANY(int, int); } // namespace nnvm namespace nnvm { namespace compiler { using tvm::Tensor; using tvm::Array; using tvm::Node; using tvm::runtime::TVMArgs; using tvm::runtime::TVMRetValue; TVM_REGISTER_GLOBAL("nnvm.compiler._dict_get") .set_body([](TVMArgs args, TVMRetValue *rv) { const AttrDict& dict = args[0].AsExtension<AttrDict>(); std::string key = args[1]; auto it = dict.find(key); if (it != dict.end()) { *rv = it->second; } else { *rv = nullptr; } }); TVM_REGISTER_GLOBAL("nnvm.compiler._dict_size") .set_body([](TVMArgs args, TVMRetValue *rv) { const AttrDict& dict = args[0].AsExtension<AttrDict>(); *rv = static_cast<int64_t>(dict.size()); }); TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys") .set_body([](TVMArgs args, TVMRetValue *rv) { const AttrDict& dict = args[0].AsExtension<AttrDict>(); tvm::Array<tvm::Expr> keys; for (const auto& kv : dict) { keys.push_back(kv.first); } *rv = keys; }); TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout") .set_body([](TVMArgs args, TVMRetValue *rv) { // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* f = new PackedFunc(args[1].operator PackedFunc()); Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]); auto fpack = [f](const NodeAttrs& attrs, const Symbol& inputs, const Array<Tensor>& tinfos, Symbol* ret_symbol) { TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos); if (ret.type_code() == TVMTypeCode::kNull) { return false; } CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code) << " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code << ") but get code = " << ret.type_code(); *ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle)); return true; }; op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", fpack, args[2]); }); // custom version of TVM compute TVM_REGISTER_GLOBAL("nnvm._register_compute") .set_body([](TVMArgs args, TVMRetValue *rv) { // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* f = new PackedFunc(args[1].operator PackedFunc()); Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]); auto fcompute = [f](const NodeAttrs& attrs, const Array<Tensor>& inputs, const Array<Tensor>& out_info) -> Array<Tensor> { TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info); if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) { return {ret.operator Tensor()}; } else { return ret; } }; op.set_attr<FTVMCompute>("FTVMCompute", fcompute, args[2]); }); TVM_REGISTER_GLOBAL("nnvm._register_schedule") .set_body([](TVMArgs args, TVMRetValue *rv) { // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* f = new PackedFunc(args[1].operator PackedFunc()); Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]); auto fschedule = [f](const NodeAttrs& attrs, const Array<Tensor>& outs, const std::string& target) { return (*f)(GetAttrDict(attrs), outs, target).operator Schedule(); }; op.set_attr<FTVMSchedule>("FTVMSchedule", fschedule, args[2]); }); TVM_REGISTER_GLOBAL("nnvm._register_pattern") .set_body([](TVMArgs args, TVMRetValue *rv) { Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]); op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]); }); TVM_REGISTER_GLOBAL("nnvm.graph._move_module") .set_body([](TVMArgs args, TVMRetValue *rv) { const nnvm::Graph& g = args[0].AsExtension<Graph>(); *rv = const_cast<nnvm::Graph*>(&g)-> MoveCopyAttr<tvm::runtime::Module>(args[1]); }); TVM_REGISTER_GLOBAL("nnvm.graph._move_graph") .set_body([](TVMArgs args, TVMRetValue *rv) { const nnvm::Graph& g = args[0].AsExtension<Graph>(); std::string key = args[1]; if (g.attrs.count(key)) { *rv = const_cast<nnvm::Graph*>(&g)-> MoveCopyAttr<nnvm::Graph>(key); } else { *rv = nullptr; } }); } // namespace compiler } // namespace nnvm