Commit 889573cf by Zhixun Tan Committed by Tianqi Chen

Add the equivalence of graph_runtime.py in tvm_runtime.js (#950)

parent 549aa165
......@@ -656,6 +656,8 @@ var tvm_runtime = tvm_runtime || {};
v = convertFunc(v);
this.temp.push(v);
this.setHandle(i, v._tvm_function.handle, kFuncHandle);
} else if (v instanceof TVMModule) {
this.setHandle(i, v.handle, kModuleHandle);
} else {
throwError("Unsupported argument type " + tp);
}
......@@ -977,6 +979,107 @@ var tvm_runtime = tvm_runtime || {};
};
var loadModuleFromFile = this.loadModuleFromFile;
/**
* Wrapper runtime module.
* Wraps around set_input, load_params, run, and get_output.
*
* @class
* @memberof tvm
*/
function GraphModule(tvm_graph_module, ctx) {
CHECK(tvm_graph_module instanceof TVMModule,
"tvm_graph_module must be TVMModule");
CHECK(ctx instanceof TVMContext, "ctx must be TVMContext");
this.tvm_graph_module = tvm_graph_module;
this.ctx = ctx;
this._set_input = tvm_graph_module.getFunction("set_input");
this._load_params = tvm_graph_module.getFunction("load_params");
this._run = tvm_graph_module.getFunction("run");
this._get_output = tvm_graph_module.getFunction("get_output");
};
GraphModule.prototype = {
/**
* Set input to graph module.
*
* @param {string} key The name of the input.
* @param {NDArray} value The input value.
*/
"set_input" : function(key, value) {
CHECK(typeof key == "string", "key must be string");
CHECK(value instanceof NDArray, "value must be NDArray");
this._set_input(key, value);
},
/**
* Load parameters from serialized byte array of parameter dict.
*
* @param {Uint8Array} params The serialized parameter dict.
*/
"load_params" : function(params) {
CHECK(params instanceof Uint8Array, "params must be Uint8Array");
this._load_params(params);
},
/**
* Load parameters from serialized base64 string of parameter dict.
*
* @param {string} base64_params The serialized parameter dict.
*/
"load_base64_params" : function(base64_params) {
CHECK(typeof base64_params == "string", "base64_params must be string");
var decoded_string = atob(base64_params);
var decoded_u8 = new Uint8Array(decoded_string.length);
for (var i = 0; i < decoded_string.length; i++) {
decoded_u8[i] = decoded_string[i].charCodeAt(0);
}
this.load_params(decoded_u8);
},
/**
* Run forward execution of the graph.
*/
"run" : function() {
this._run();
},
/**
* Get index-th output to out.
*
* @param {NDArray} out The output array container.
* @return {NDArray} The output array container.
*/
"get_output" : function(index, out) {
CHECK(typeof index == "number", "index must be number");
CHECK(out instanceof NDArray, "out must be NDArray");
this._get_output(new TVMConstant(index, "int32"), out);
return out;
}
};
/**
* Create a runtime executor module given a graph and a module.
* @param {string} graph_json_str The Json string of the graph.
* @param {TVMModule} libmod The TVM module.
* @param {TVMContext} ctx The context to deploy the module.
* @return {GraphModule} Runtime graph module for executing the graph.
*/
this.createGraphRuntime = function(graph_json_str, libmod, ctx) {
CHECK(typeof graph_json_str == "string", "graph_json_str must be string");
CHECK(libmod instanceof TVMModule, "libmod must be TVMModule");
CHECK(ctx instanceof TVMContext, "ctx must be TVMContext");
var fcreate = getGlobalFunc("tvm.graph_runtime.create");
CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create");
var tvm_graph_module = fcreate(graph_json_str, libmod,
new TVMConstant(ctx.device_type, "int32"),
new TVMConstant(ctx.device_id, "int32"));
return new GraphModule(tvm_graph_module, ctx);
};
//-----------------------------------------
// Class defintions
// ----------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment