/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file src/relay/backend/vm/vm.cc
 * \brief The Relay virtual machine.
 */

#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/module.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/analysis.h>

// Standard headers used directly by this file; previously these were only
// reachable through transitive includes of the TVM headers above.
#include <chrono>   // std::chrono (ENABLE_PROFILING timing)
#include <sstream>  // std::stringstream (ToString)
#include <string>
#include <vector>

namespace tvm {
namespace relay {
namespace vm {

// Only declared here; the definition lives in another translation unit.
runtime::vm::VirtualMachine CompileModule(const Module& mod);

using tvm::runtime::Object;
using tvm::runtime::ObjectTag;
using tvm::runtime::vm::VirtualMachine;

/*!
 * \brief Compile a Relay module into a VM and initialize it.
 * \param module The Relay module to compile via CompileModule.
 * \param ctxs The device contexts passed to VirtualMachine::Init.
 * \return The compiled and initialized virtual machine.
 */
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
  auto vm = CompileModule(module);
  vm.Init(ctxs);
  return vm;
}

/*!
 * \brief Compile `module`, then invoke its "main" function on `vm_args`.
 * \param module The Relay module to run.
 * \param ctxs The device contexts used to initialize the VM.
 * \param vm_args The arguments forwarded to "main".
 * \return The object returned by VirtualMachine::Invoke("main", ...).
 *
 * NOTE(review): `ctxs` is taken by (const) value rather than by reference;
 * kept unchanged in case a matching declaration exists elsewhere.
 */
Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
                      const std::vector<Object>& vm_args) {
  VirtualMachine vm = FromModule(module, ctxs);
  // TODO(zhiics): This measurement is for temporary usage. Remove it later. We
  // need to introduce a better profiling method.
#if ENABLE_PROFILING
  // No std::endl here: the logger terminates each record itself, and endl
  // would only add an empty line plus a redundant flush.
  DLOG(INFO) << "Entry function is main.";
  auto start = std::chrono::high_resolution_clock::now();
#endif  // ENABLE_PROFILING
  Object res = vm.Invoke("main", vm_args);
#if ENABLE_PROFILING
  auto end = std::chrono::high_resolution_clock::now();
  auto duration =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
  LOG(INFO) << "Inference time: " << duration << "ms\n";
#endif  // ENABLE_PROFILING
  return res;
}

/*!
 * \brief Convert a VM result object into a Relay interpreter Value.
 *
 * Tensors become TensorValues. Datatypes become ConstructorValues, with each
 * field converted recursively. Any other tag is a fatal error.
 *
 * \param module The module the object came from (must be defined).
 * \param obj The VM object to convert.
 * \return The equivalent interpreter Value.
 */
Value VMToValue(const relay::Module& module, Object obj) {
  CHECK(module.defined());
  switch (obj->tag) {
    case ObjectTag::kTensor: {
      return TensorValueNode::make(ToNDArray(obj));
    }
    case ObjectTag::kDatatype: {
      const auto& data_type = obj.AsDatatype();
      tvm::Array<Value> fields;
      for (const auto& field : data_type->fields) {
        fields.push_back(VMToValue(module, field));
      }
      return ConstructorValueNode::make(data_type->tag, fields);
    }
    default:
      LOG(FATAL) << "unsupported return value of type: " << obj->tag;
      return Value();
  }
}

// relay._vm._Tensor(arr): wrap args[0] as a VM tensor object.
TVM_REGISTER_API("relay._vm._Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = Object::Tensor(args[0]);
});

// relay._vm._Tuple(*fields): build a VM tuple object from every argument.
TVM_REGISTER_API("relay._vm._Tuple")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  std::vector<Object> fields;
  fields.reserve(args.size());
  for (int i = 0; i < args.size(); ++i) {
    fields.push_back(args[i]);
  }
  *ret = Object::Tuple(fields);
});

/*!
 * \brief Render any streamable value to a string via operator<<.
 */
template <typename T>
std::string ToString(const T& t) {
  std::stringstream s;
  s << t;
  return s.str();
}

// relay._vm._ObjectTag(obj): return the object's tag as a printable string.
TVM_REGISTER_API("relay._vm._ObjectTag")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Object obj = args[0];
  *ret = ToString(obj->tag);
});

// relay._vm._Datatype(tag, *fields): build a VM datatype object whose
// constructor tag is args[0] and whose fields are args[1..].
TVM_REGISTER_API("relay._vm._Datatype")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  int itag = args[0];
  // NOTE(review): a negative int wraps on this cast. It presumably
  // round-trips when converted back to the constructor tag in VMToValue.
  // Confirm before adding a range check.
  size_t tag = static_cast<size_t>(itag);
  std::vector<Object> fields;
  fields.reserve(args.size() - 1);
  for (int i = 1; i < args.size(); ++i) {
    fields.push_back(args[i]);
  }
  *ret = Object::Datatype(tag, fields);
});

// relay._vm._evaluate_vm(func_or_mod, device_type, device_id, *vm_args):
// compile a Function or Module, run "main" on the given device with the
// remaining arguments, and return the result converted to a Relay Value.
TVM_REGISTER_API("relay._vm._evaluate_vm")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  NodeRef to_compile = args[0];
  TVMContext ctx;
  int dev_type = args[1];
  ctx.device_type = static_cast<DLDeviceType>(dev_type);
  ctx.device_id = args[2];

  Module module;
  if (to_compile.as<FunctionNode>()) {
    // Distinct name: avoids shadowing the NodeRef `to_compile` above.
    Function func = args[0];
    module = ModuleNode::FromExpr(func);
  } else if (to_compile.as<ModuleNode>()) {
    module = args[0];
  } else {
    LOG(FATAL) << "expected function or module";
  }

  // Arguments after (target, device_type, device_id) are passed to "main".
  constexpr int kFirstVMArg = 3;
  std::vector<Object> vm_args;
  if (args.size() > kFirstVMArg) {
    vm_args.reserve(args.size() - kFirstVMArg);
  }
  for (int i = kFirstVMArg; i < args.size(); ++i) {
    Object obj = args[i];
    vm_args.push_back(obj);
  }

  auto result = EvaluateModule(module, {ctx}, vm_args);
  DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
  *ret = VMToValue(module, result);
});

}  // namespace vm
}  // namespace relay
}  // namespace tvm