/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Tiny graph runtime that can run graphs
 *  containing only tvm PackedFunc.
 * \file graph_runtime.h
 */
#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_
#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_

#include <dlpack/dlpack.h>
#include <dmlc/memory_io.h>
#include <dmlc/json.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include <string>

namespace tvm {
namespace runtime {

/*! \brief macro to do C API call */
#define TVM_CCALL(func)                     \
  {                                         \
    int ret = (func);                       \
    CHECK_EQ(ret, 0)                        \
        << TVMGetLastError();               \
  }

/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;

/*! \brief operator attributes about tvm op */
struct TVMOpParam {
  std::string func_name;
  uint32_t num_inputs;
  uint32_t num_outputs;
  uint32_t flatten_data;
};

/*!
 * \brief Tiny graph runtime.
 *
 *  This runtime can be accessed in various languages via
 *  the TVM runtime PackedFunc API.
 */
class GraphRuntime : public ModuleNode {
  struct OpArgs {
    std::vector<DLTensor> args;
    std::vector<TVMValue> arg_values;
    std::vector<int> arg_tcodes;
    std::vector<int64_t> shape_data;
  };

 public:
  /*!
   * \brief Get member function to front-end.
   * \param name The name of the function.
   * \param sptr_to_self The pointer to the module node.
   * \return The corresponding member function.
   */
  virtual PackedFunc GetFunction(const std::string& name,
                                 const ObjectPtr<Object>& sptr_to_self);

  /*!
   * \return The type key of the executor.
   */
  const char* type_key() const final {
    return "GraphRuntime";
  }
  void Run();

  /*!
   * \brief Initialize the graph executor with graph and context.
   * \param graph_json The execution graph.
   * \param module The module containing the compiled functions for the host
   *  processor.
   * \param ctxs The contexts of the host and devices on which the graph nodes
   *  will be executed.
   */
  void Init(const std::string& graph_json,
            tvm::runtime::Module module,
            const std::vector<TVMContext>& ctxs);

  /*!
   * \brief Get the input index given the name of input.
   * \param name The name of the input.
   * \return The index of input.
   */
  int GetInputIndex(const std::string& name);

  /*!
   * \brief Set the index-th input to the graph.
   * \param index The input index.
   * \param data_in The input data.
   */
  void SetInput(int index, DLTensor* data_in);
  /*!
   * \brief Set the index-th input to the graph without copying the data.
   * \param index The input index.
   * \param data_ref The input data that is referred.
   */
  void SetInputZeroCopy(int index, DLTensor* data_ref);
  /*!
   * \brief Get the number of outputs.
   *
   * \return The number of outputs from graph.
   */
  int NumOutputs() const;
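  /*!
   * A minimal usage sketch showing how the setup/run methods above and the
   * accessors below fit together.  Illustrative only: the names `graph_json`,
   * `mod`, `ctx`, `params_blob`, `input_tensor` and the input name "data" are
   * assumptions made for the example and are not part of this header.
   *
   * \code
   *   auto exec = make_object<GraphRuntime>();
   *   exec->Init(graph_json, mod, {ctx});
   *   exec->LoadParams(params_blob);
   *   exec->SetInput(exec->GetInputIndex("data"), &input_tensor);
   *   exec->Run();
   *   NDArray out = exec->GetOutput(0);
   * \endcode
   */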
  /*!
   * \brief Return NDArray for given input index.
   * \param index The input index.
   *
   * \return NDArray corresponding to given input node index.
   */
  NDArray GetInput(int index) const;
  /*!
   * \brief Return NDArray for given output index.
   * \param index The output index.
   *
   * \return NDArray corresponding to given output node index.
   */
  NDArray GetOutput(int index) const;
  /*!
   * \brief Copy the index-th output to data_out.
   * \param index The output index.
   * \param data_out The output data.
   */
  void CopyOutputTo(int index, DLTensor* data_out);
  /*!
   * \brief Load parameters from binary stream.
   * \param strm The input stream.
   */
  void LoadParams(dmlc::Stream* strm);
  /*!
   * \brief Load parameters from parameter blob.
   * \param param_blob A binary blob of parameters.
   */
  void LoadParams(const std::string& param_blob);
  /*!
   * \brief Share parameters from a pre-existing GraphRuntime instance.
   * \param other A GraphRuntime instance on which |LoadParams| was previously
   *  called with the identical |param_blob|.
   * \param strm The input stream.
   */
  void ShareParams(const GraphRuntime& other, dmlc::Stream* strm);

  /*!
   * \brief Get total number of nodes.
   * \return Total number of nodes.
   */
  uint32_t GetNumOfNodes() const {
    return static_cast<uint32_t>(nodes_.size());
  }

  std::string GetNodeName(uint32_t nid) const {
    return nodes_[nid].name;
  }

 protected:
  // Memory pool entry.
  struct PoolEntry {
    size_t size;
    int device_type;
    PoolEntry(int s, int dev_type) : size(s), device_type(dev_type) {}
  };
  // Node entry
  struct NodeEntry {
    uint32_t node_id;
    uint32_t index;
    uint32_t version;
    // JSON Loader
    void Load(dmlc::JSONReader *reader) {
      reader->BeginArray();
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&node_id);
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&index);
      if (reader->NextArrayItem()) {
        reader->Read(&version);
        CHECK(!reader->NextArrayItem()) << "invalid json format";
      } else {
        version = 0;
      }
    }
  };
  // Node
  struct Node {
    // operator type in string
    std::string op_type;
    // name of the op
    std::string name;
    // parameters
    TVMOpParam param;
    // inputs
    std::vector<NodeEntry> inputs;
    // control deps
    std::vector<uint32_t> control_deps;
    // JSON Loader
    void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) {
      int bitmask = 0;
      std::string key, value;
      reader->BeginObject();
      while (reader->NextObjectItem(&key)) {
        reader->Read(&value);
        if (key == "func_name") {
          param->func_name = value;
          bitmask |= 1;
        } else if (key == "num_inputs") {
          param->num_inputs = strtoul(value.c_str(), nullptr, 10);
          bitmask |= 2;
        } else if (key == "num_outputs") {
          param->num_outputs = strtoul(value.c_str(), nullptr, 10);
          bitmask |= 4;
        } else if (key == "flatten_data") {
          param->flatten_data = strtoul(value.c_str(), nullptr, 10);
          bitmask |= 8;
        }
      }
      CHECK_EQ(bitmask, 1|2|4|8) << "invalid format";
    }
    // JSON Loader
    void Load(dmlc::JSONReader *reader) {
      reader->BeginObject();
      int bitmask = 0;
      std::string key;
      while (reader->NextObjectItem(&key)) {
        if (key == "op") {
          reader->Read(&op_type);
          bitmask |= 1;
        } else if (key == "name") {
          reader->Read(&name);
          bitmask |= 2;
        } else if (key == "inputs") {
          reader->Read(&inputs);
          bitmask |= 4;
        } else if (key == "attr" || key == "attrs") {
          this->LoadAttrs(reader, &param);
        } else if (key == "control_deps") {
          reader->Read(&control_deps);
        } else {
          LOG(FATAL) << "do not support key " << key;
        }
      }
      CHECK_EQ(bitmask, 1|2|4) << "invalid format";
    }
  };
  struct GraphAttr {
    size_t storage_num_not_alloctaed{0};
    std::vector<int> storage_id;
    std::vector<int> device_index;
    std::vector<std::string> dltype;
    std::vector<std::vector<int64_t> > shape;
    // The graph attribute fields.
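    // A hedged illustration of the serialized layout that Load() below
    // accepts: each attribute is stored as a [type_tag, value] pair keyed
    // by its name; dltype, storage_id and shape are required while
    // device_index is optional.  The concrete values are made up for the
    // example.
    //
    //   "attrs": {
    //     "dltype":       ["list_str",   ["float32", "float32"]],
    //     "storage_id":   ["list_int",   [0, 1]],
    //     "shape":        ["list_shape", [[1, 3, 224, 224], [1, 1000]]],
    //     "device_index": ["list_int",   [1, 1]]
    //   }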
    void Load(dmlc::JSONReader *reader) {
      reader->BeginObject();
      int bitmask = 0;
      std::string key, type;
      while (reader->NextObjectItem(&key)) {
        if (key == "dltype") {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          CHECK_EQ(type, "list_str");
          CHECK(reader->NextArrayItem());
          reader->Read(&dltype);
          CHECK(!reader->NextArrayItem());
          bitmask |= 1;
        } else if (key == "storage_id") {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          CHECK_EQ(type, "list_int");
          CHECK(reader->NextArrayItem());
          reader->Read(&storage_id);
          CHECK(!reader->NextArrayItem());
          bitmask |= 2;
        } else if (key == "shape") {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          CHECK_EQ(type, "list_shape");
          CHECK(reader->NextArrayItem());
          reader->Read(&shape);
          CHECK(!reader->NextArrayItem());
          bitmask |= 4;
        } else if (key == "device_index") {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          CHECK_EQ(type, "list_int");
          CHECK(reader->NextArrayItem());
          reader->Read(&device_index);
          CHECK(!reader->NextArrayItem());
        } else {
          reader->BeginArray();
          CHECK(reader->NextArrayItem());
          reader->Read(&type);
          if (type == "list_int") {
            CHECK(reader->NextArrayItem());
            std::vector<int> temp;
            reader->Read(&temp);
          } else if (type == "size_t") {
            CHECK(reader->NextArrayItem());
            size_t temp;
            reader->Read(&temp);
          } else {
            LOG(FATAL) << "cannot skip graph attr " << key;
          }
          CHECK(!reader->NextArrayItem());
        }
      }
      CHECK_EQ(bitmask, 1|2|4) << "invalid format";
    }
  };

  // The graph JSON loader.
  void Load(dmlc::JSONReader *reader) {
    reader->BeginObject();
    int bitmask = 0;
    std::string key;
    while (reader->NextObjectItem(&key)) {
      if (key == "nodes") {
        reader->Read(&nodes_);
        bitmask |= 1;
      } else if (key == "arg_nodes") {
        reader->Read(&input_nodes_);
        bitmask |= 2;
      } else if (key == "node_row_ptr") {
        reader->Read(&node_row_ptr_);
        bitmask |= 4;
      } else if (key == "heads") {
        reader->Read(&outputs_);
        bitmask |= 8;
      } else if (key == "attrs") {
        reader->Read(&attrs_);
        bitmask |= 16;
      } else if (key == "metadata") {
        break;
      } else {
        LOG(FATAL) << "key " << key << " is not supported";
      }
    }
    CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format";
  }
  /*! \brief Setup the temporary storage. */
  void SetupStorage();
  /*! \brief Setup the executors. */
  void SetupOpExecs();
  /*!
   * \brief Create an execution function given input.
   * \param attrs The node attributes.
   * \param args The arguments to the functor, including inputs and outputs.
   * \param num_inputs Number of inputs.
   * \return The created executor.
   */
  std::pair<std::function<void()>, std::shared_ptr<OpArgs> > CreateTVMOp(
      const TVMOpParam& attrs, const std::vector<DLTensor>& args,
      size_t num_inputs);
  // Get node entry index.
  uint32_t entry_id(uint32_t nid, uint32_t index) const {
    return node_row_ptr_[nid] + index;
  }
  // Get node entry index.
  uint32_t entry_id(const NodeEntry& e) const {
    return entry_id(e.node_id, e.index);
  }
  // Number of node entries.
  uint32_t num_node_entries() const {
    return node_row_ptr_.back();
  }
  /*! \brief The graph nodes. */
  std::vector<Node> nodes_;
  /*! \brief The argument nodes. */
  std::vector<uint32_t> input_nodes_;
  /*! \brief Map of input names to input indices. */
  std::unordered_map<std::string, uint32_t> input_map_;
  /*! \brief Used for quick node input DLTensor* lookup given an input eid. */
  std::vector<std::vector<DLTensor*>> input_dltensors_;
  /*! \brief Used for quick entry indexing. */
  std::vector<uint32_t> node_row_ptr_;
  /*! \brief Output entries. */
  std::vector<NodeEntry> outputs_;
  /*! \brief Additional graph attributes. */
  GraphAttr attrs_;
  /*! \brief The code module that contains both host and device code. */
  tvm::runtime::Module module_;
  /*! \brief Execution context of all devices including the host. */
  std::vector<TVMContext> ctxs_;
  /*! \brief Common storage pool for all devices. */
  std::vector<NDArray> storage_pool_;
  /*! \brief Data entry of each node. */
  std::vector<NDArray> data_entry_;
  /*! \brief Data alignment of each node. */
  std::vector<size_t> data_alignment_;
  /*! \brief Operator on each node. */
  std::vector<std::function<void()> > op_execs_;
};

/*! \brief Get all the contexts specified in \p args. */
std::vector<TVMContext> GetAllContext(const TVMArgs& args);

}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_