Commit 8c9758b6 by Tianqi Chen

Update Graph Support for Batching, Fix Swapping (#37)

* fix graph transform for batch dimension

* fix

* fix
parent a96a4a9b
@@ -3,6 +3,7 @@ import nnvm
 import tvm
 from nnvm.compiler import graph_attr
 import vta
+import vta.testing
 import os
 import numpy as np
 from PIL import Image
@@ -12,7 +13,8 @@ import logging
 import wget
 from tvm.contrib import graph_runtime, rpc, util
-factor = 16
+bfactor = 1
+cfactor = 16
 host = "pynq"
 port = 9091
 verbose = False
@@ -38,6 +40,10 @@ if verbose:
 target = tvm.target.create("llvm -device=vta")
 target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
+if vta.get_env().TARGET == "sim":
+    target_host = "llvm"
 synset = eval(open(os.path.join(CATEG_FILE)).read())
 image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
@@ -105,7 +111,7 @@ sym = vta.graph.remove_stochastic(sym)
 sym = vta.graph.clean_cast(sym)
 sym = vta.graph.clean_conv_fuse(sym)
 if target.device_name == "vta":
-    sym = vta.graph.pack(sym, shape_dict, factor)
+    sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)
 graph_attr.set_shape_inputs(sym, shape_dict)
 sym = sym.apply("InferShape")
@@ -127,7 +133,13 @@ with nnvm.compiler.build_config(opt_level=3):
 assert tvm.module.enabled("rpc")
 temp = util.tempdir()
 lib.save(temp.relpath("graphlib.o"))
-remote = rpc.connect(host, port)
+if vta.get_env().TARGET == "sim":
+    remote = rpc.LocalSession()
+    print("local session")
+else:
+    remote = rpc.connect(host, port)
 remote.upload(temp.relpath("graphlib.o"))
 lib = remote.load_module("graphlib.o")
 ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
@@ -154,16 +166,17 @@ def run_e2e(graph):
     print("t-cost=%g" % tcost.mean)
-def run_layer(old_graph):
+def run_layer(old_graph, layer_begin, layer_end):
     """Run a certain layer."""
-    for layer_id in range(1, 2):
+    for layer_id in range(layer_begin, layer_end):
+        print("run resnet[%d]..."% (layer_id))
         graph = mark_nop(old_graph, layer_id)
         m = graph_runtime.create(graph, lib, ctx)
         # set inputs
         m.set_input('data', tvm.nd.array(x.astype("float32")))
         m.set_input(**params)
         # execute
-        timer = m.module.time_evaluator("run", ctx, number=10)
+        timer = m.module.time_evaluator("run", ctx, number=1)
         tcost = timer()
         print("resnet[%d]: %g\n"% (layer_id, tcost.mean))
...
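Since `run_layer` now takes the layer range from its caller instead of hard-coding `range(1, 2)`, the layers to be timed are chosen at the call site. That call site is outside this hunk; a hypothetical invocation that times only the first ResNet layer would be:

    run_layer(graph, 0, 1)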
@@ -10,51 +10,58 @@ import nnvm
 from nnvm.compiler import graph_attr, graph_util
-def _pack_channel(data, dshape, factor):
+def _pack_batch_channel(data, dshape, bfactor, cfactor):
     """Pack the data channel dimension.
     """
-    assert dshape[1] % factor == 0
+    assert dshape[0] % bfactor == 0
+    assert dshape[1] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0], dshape[1] // factor,
-                                   factor, dshape[2], dshape[3]))
+                            shape=(dshape[0] // bfactor, bfactor,
+                                   dshape[1] // cfactor, cfactor,
+                                   dshape[2], dshape[3]))
     data = nnvm.sym.transpose(
-        data, axes=(0, 1, 3, 4, 2))
+        data, axes=(0, 2, 4, 5, 1, 3))
     return data
-def _unpack_channel(data, old_shape):
+def _unpack_batch_channel(data, old_shape):
     """Unpack the data channel dimension.
     """
-    data = nnvm.sym.transpose(data, axes=(0, 1, 4, 2, 3))
+    data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
     data = nnvm.sym.reshape(data, shape=old_shape)
     return data
-def _pack_weight(data, dshape, factor):
+def _pack_weight(data, dshape, cfactor):
     """Pack the weight into packed format.
     """
     assert len(dshape) == 4
-    assert dshape[0] % factor == 0
-    assert dshape[1] % factor == 0
+    assert dshape[0] % cfactor == 0
+    assert dshape[1] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // factor, factor,
-                                   dshape[1] // factor, factor,
+                            shape=(dshape[0] // cfactor, cfactor,
+                                   dshape[1] // cfactor, cfactor,
                                    dshape[2], dshape[3]))
     data = nnvm.sym.transpose(
         data, axes=(0, 2, 4, 5, 1, 3))
     return data
-def _pack_bias(data, dshape, factor):
+def _pack_bias(data, dshape, bfactor, cfactor):
     """Pack the bias parameter.
     """
     assert len(dshape) == 3
-    assert dshape[0] % factor == 0
+    assert dshape[0] % cfactor == 0
     data = nnvm.sym.reshape(data,
-                            shape=(dshape[0] // factor,
-                                   factor, dshape[1], dshape[2]))
+                            shape=(dshape[0] // cfactor,
+                                   cfactor, dshape[1],
+                                   dshape[2], 1))
     data = nnvm.sym.transpose(
-        data, axes=(0, 2, 3, 1))
+        data, axes=(0, 2, 3, 4, 1))
+    # broadcast batch dimension to bfactor
+    data = nnvm.sym.broadcast_to(
+        data,
+        shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
     return data
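The helpers above only emit NNVM reshape/transpose nodes, so the layout they produce is easiest to see in plain NumPy. The sketch below mirrors `_pack_batch_channel` and `_unpack_batch_channel`; the shapes and factors are illustrative and not taken from the commit:

    import numpy as np

    n, c, h, w = 4, 32, 14, 14    # illustrative NCHW activation shape
    bfactor, cfactor = 2, 16      # batch / channel packing factors
    data = np.arange(n * c * h * w).reshape(n, c, h, w)

    # _pack_batch_channel: NCHW -> (N//b, C//c, H, W, b, c)
    packed = data.reshape(n // bfactor, bfactor,
                          c // cfactor, cfactor,
                          h, w).transpose(0, 2, 4, 5, 1, 3)

    # _unpack_batch_channel: inverse transpose, then reshape back to NCHW
    restored = packed.transpose(0, 4, 1, 5, 2, 3).reshape(n, c, h, w)
    assert (restored == data).all()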
@@ -245,8 +252,8 @@ def clean_cast(graph):
     return ret
-def pack(graph, shape_dict, factor, start_name=None):
-    """Pack the graph into channel packed format.
+def pack(graph, shape_dict, bfactor, cfactor, start_name=None):
+    """Pack the graph into batch&channel packed format.
     Parameters
     ----------
@@ -256,8 +263,11 @@ def pack(graph, shape_dict, factor, start_name=None):
     shape_dict : dict of str to shapex
         The input shape.
-    factor : int
-        The packing factor
+    bfactor : int
+        The packing factor in batch
+    cfactor : int
+        The packing factor in channel
     start_name: str, optional
         Start name start packing from certain known node.
@@ -290,42 +300,44 @@ def pack(graph, shape_dict, factor, start_name=None):
             new_node = nnvm.symbol.Variable(node_name)
             if start_name and node_name == start_name:
                 start_pack = True
-                new_node = _pack_channel(new_node, oshape, factor)
+                new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
         elif op_name == "max_pool2d":
             assert not start_pack
             start_pack = True
             new_node = get_clone(children, op_name, node_name, attrs)
-            new_node = _pack_channel(new_node, oshape, factor)
+            new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
         elif op_name == "global_avg_pool2d":
             if start_pack:
                 start_pack = False
-                children[0] = _unpack_channel(children[0], ishape[0])
+                children[0] = _unpack_batch_channel(children[0], ishape[0])
                 new_node = getattr(nnvm.symbol, op_name)(
                     *children, name=node_name, **attrs)
             else:
                 new_node = get_clone(children, op_name, node_name, attrs)
         elif op_name == "quantized_conv2d":
             if start_pack:
-                attrs["pack_channel"] = str(factor)
+                attrs["pack_batch"] = str(bfactor)
+                attrs["pack_channel"] = str(cfactor)
                 data, weight = children
-                weight = _pack_weight(weight, ishape[1], factor)
+                weight = _pack_weight(weight, ishape[1], cfactor)
                 new_node = nnvm.sym.quantized_conv2d(
                     data, weight, name=node_name, **attrs)
             elif counter == 1:
-                attrs["pack_channel"] = str(factor)
+                attrs["pack_batch"] = str(bfactor)
+                attrs["pack_channel"] = str(cfactor)
                 data, weight = children
-                data = _pack_channel(data, ishape[0], factor)
-                weight = _pack_weight(weight, ishape[1], factor)
+                data = _pack_batch_channel(data, ishape[0], bfactor, cfactor)
+                weight = _pack_weight(weight, ishape[1], cfactor)
                 new_node = nnvm.sym.quantized_conv2d(
                     data, weight, name=node_name, **attrs)
-                new_node = _unpack_channel(new_node, oshape)
+                new_node = _unpack_batch_channel(new_node, oshape)
                 counter = counter + 1
             else:
                 new_node = get_clone(children, op_name, node_name, attrs)
         elif op_name.startswith("broadcast"):
             if start_pack:
                 assert len(ishape[1]) == 3
-                children[1] = _pack_bias(children[1], ishape[1], factor)
+                children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
                 new_node = getattr(nnvm.symbol, op_name)(
                     *children, name=node_name, **attrs)
             else:
@@ -341,7 +353,7 @@ def pack(graph, shape_dict, factor, start_name=None):
     ret = node_map[graph.index.output_entries[0][0]]
     if start_pack:
         oshape = shape[graph.index.output_entries[0][0]]
-        ret = _unpack_channel(ret, oshape)
+        ret = _unpack_batch_channel(ret, oshape)
     graph = nnvm.graph.create(ret)
     graph = graph_attr.set_shape_inputs(graph, shape_dict)
     graph = graph.apply("InferShape")
...
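The updated `pack` entry point is exercised by the ResNet script in the first diff above; a condensed sketch of that call pattern, with the quantization and build steps around it omitted, is:

    bfactor, cfactor = 1, 16    # batch and channel packing factors used by the script
    if target.device_name == "vta":
        sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)
    graph_attr.set_shape_inputs(sym, shape_dict)
    sym = sym.apply("InferShape")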
@@ -367,9 +367,10 @@ class UopQueue : public BaseQueue {
     }
     assert(num_op <= kMaxNumUop);
     uint32_t uop_begin = 0;
-    if (sram_end_ + num_op > kMaxElems) {
+    if (sram_end_ + num_op > kMaxNumUop) {
       // Need to evict
       cache_ptr_ = 0;
+      sram_begin_ = 0;
       sram_end_ = num_op;
     } else {
       uop_begin = sram_end_;
@@ -388,6 +389,7 @@ class UopQueue : public BaseQueue {
     dram_end_ += num_op;
     kernel->sram_begin_ = uop_begin;
     kernel->sram_end_ = sram_end_;
+    CHECK(kernel->cached());
     assert(uop_begin != sram_end_);
     cache_.insert(cache_.begin() + cache_ptr_, kernel);
     cache_.erase(cache_.begin() + evict_begin, cache_.begin() + cache_ptr_);
...
@@ -162,6 +162,7 @@ class DRAM {
    */
   void Free(void* data) {
     std::lock_guard<std::mutex> lock(mutex_);
+    if (pmap_.size() == 0) return;
     auto it = pmap_.find(data);
     CHECK(it != pmap_.end());
     Page* p = it->second.get();
...