Commit f364d563 by Tianqi Chen Committed by GitHub

[CONTRIB/BLAS] Add CBLAS Example to contrib (#120)

* [CONTRIB/BLAS] Add CBLAS Example to contrib

* Update makefile
parent 8a5b6c21
...@@ -7,7 +7,7 @@ endif() ...@@ -7,7 +7,7 @@ endif()
include(cmake/Util.cmake) include(cmake/Util.cmake)
tvm_option(USE_CUDA "Build with CUDA" ON) tvm_option(USE_CUDA "Build with CUDA" ON)
tvm_option(USE_OPENCL "Build with OpenCL" ON) tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_LLVM "Build with LLVM" OFF) tvm_option(USE_LLVM "Build with LLVM" OFF)
tvm_option(USE_RTTI "Build with RTTI" OFF) tvm_option(USE_RTTI "Build with RTTI" OFF)
tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(USE_MSVC_MT "Build with MT" OFF)
......
...@@ -10,47 +10,52 @@ endif ...@@ -10,47 +10,52 @@ endif
include $(config) include $(config)
# specify tensor path
.PHONY: clean all test doc pylint cpplint lint verilog cython cython2 cython3 .PHONY: clean all test doc pylint cpplint lint verilog cython cython2 cython3
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a # The source code dependencies
LIB_HALIDEIR = HalideIR/lib/libHalideIR.a
SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc) CC_SRC = $(filter-out src/contrib/%.cc src/runtime/%.cc,\
$(wildcard src/*/*.cc src/*/*/*.cc))
METAL_SRC = $(wildcard src/runtime/metal/*.mm) METAL_SRC = $(wildcard src/runtime/metal/*.mm)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)
# Objectives
METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC)) METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR) CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC))
RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
CONTRIB_OBJ =
RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc) UNAME_S := $(shell uname -s)
RUNTIME_DEP = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR) # Deps
ALL_DEP = $(CC_OBJ) $(CONTRIB_OBJ) $(LIB_HALIDEIR)
RUNTIME_DEP = $(RUNTIME_OBJ)
export LDFLAGS = -pthread -lm # The flags
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\ LDFLAGS = -pthread -lm
-Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0 CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
export OBJCFLAGS= -fno-objc-arc -Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
FRAMEWORKS =
OBJCFLAGS = -fno-objc-arc
# Dependency specific rules
ifdef CUDA_PATH ifdef CUDA_PATH
NVCC=$(CUDA_PATH)/bin/nvcc NVCC=$(CUDA_PATH)/bin/nvcc
CFLAGS += -I$(CUDA_PATH)/include CFLAGS += -I$(CUDA_PATH)/include
LDFLAGS += -L$(CUDA_PATH)/lib64 LDFLAGS += -L$(CUDA_PATH)/lib64
endif endif
ifeq ($(ENABLE_CUDA), 1) ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1 CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc LDFLAGS += -lcuda -lcudart -lnvrtc
else else
CFLAGS += -DTVM_CUDA_RUNTIME=0 CFLAGS += -DTVM_CUDA_RUNTIME=0
endif endif
FRAMEWORKS= ifeq ($(USE_OPENCL), 1)
UNAME_S := $(shell uname -s)
ifeq ($(ENABLE_OPENCL), 1)
CFLAGS += -DTVM_OPENCL_RUNTIME=1 CFLAGS += -DTVM_OPENCL_RUNTIME=1
ifeq ($(UNAME_S), Darwin) ifeq ($(UNAME_S), Darwin)
FRAMEWORKS += -framework OpenCL FRAMEWORKS += -framework OpenCL
...@@ -61,10 +66,9 @@ else ...@@ -61,10 +66,9 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0 CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif endif
ifeq ($(ENABLE_METAL), 1) ifeq ($(USE_METAL), 1)
CFLAGS += -DTVM_METAL_RUNTIME=1 CFLAGS += -DTVM_METAL_RUNTIME=1
LDFLAGS += -lObjc LDFLAGS += -lObjc
ALL_DEP += $(METAL_OBJ)
RUNTIME_DEP += $(METAL_OBJ) RUNTIME_DEP += $(METAL_OBJ)
FRAMEWORKS += -framework Metal -framework Foundation FRAMEWORKS += -framework Metal -framework Foundation
else else
...@@ -74,13 +78,15 @@ endif ...@@ -74,13 +78,15 @@ endif
# llvm configuration # llvm configuration
LLVM_CONFIG=llvm-config LLVM_CONFIG=llvm-config
ifeq ($(ENABLE_LLVM), 1) ifeq ($(USE_LLVM), 1)
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3) LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags)) LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs) LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION) CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION)
endif endif
include make/contrib/cblas.mk
ifdef ADD_CFLAGS ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS) CFLAGS += $(ADD_CFLAGS)
endif endif
...@@ -106,7 +112,7 @@ build/%.o: src/%.mm ...@@ -106,7 +112,7 @@ build/%.o: src/%.mm
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@ $(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
lib/libtvm.so: $(ALL_DEP) lib/libtvm.so: $(ALL_DEP) $(RUNTIME_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
...@@ -114,11 +120,11 @@ lib/libtvm_runtime.so: $(RUNTIME_DEP) ...@@ -114,11 +120,11 @@ lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm.a: $(ALL_DEP) lib/libtvm.a: $(ALL_DEP) $(RUNTIME_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
ar crv $@ $(filter %.o, $?) ar crv $@ $(filter %.o, $?)
$(LIB_HALIDE_IR): LIBHALIDEIR $(LIB_HALIDEIR): LIBHALIDEIR
LIBHALIDEIR: LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR) + cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
......
...@@ -4,8 +4,11 @@ TVM Change Log ...@@ -4,8 +4,11 @@ TVM Change Log
This file records the changes in TVM library in reverse chronological order. This file records the changes in TVM library in reverse chronological order.
## Initial version (0.1rc) ## Initial version (0.1rc)
- CUDA/OpenCL codegen - External function and contrib libraries
- LLVM codegen - Metal backend
- AOT and module system - OpenCL backend
- External function call - CUDA backend
- Beta verilog codegen - LLVM backend
- DLPack integration support
- AOT and module system
- Basic code structure ready.
\ No newline at end of file
...@@ -346,7 +346,7 @@ TVM_DLL int TVMFuncRegisterGlobal( ...@@ -346,7 +346,7 @@ TVM_DLL int TVMFuncRegisterGlobal(
* \brief Get a global function. * \brief Get a global function.
* *
* \param name The name of the function. * \param name The name of the function.
* \param out the result function pointer. * \param out the result function pointer, NULL if it does not exist.
* *
* \note The function handle of global function is managed by TVM runtime, * \note The function handle of global function is managed by TVM runtime,
* So TVMFuncFree is should not be called when it get deleted. * So TVMFuncFree is should not be called when it get deleted.
......
/*!
* Copyright (c) 2017 by Contributors
* \file util.h
* \brief Useful runtime util.
*/
#ifndef TVM_RUNTIME_UTIL_H_
#define TVM_RUNTIME_UTIL_H_
#include "./c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief Check whether type matches the given spec.
* \param t The type
* \param code The type code.
* \param bits The number of bits to be matched.
* \param lanes The number of lanes sin the type.
*/
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_UTIL_H_
...@@ -16,11 +16,6 @@ ...@@ -16,11 +16,6 @@
# $ make -j8 # $ make -j8
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
#---------------------
# choice of compiler
#--------------------
export NVCC = nvcc
# whether compile with debug # whether compile with debug
DEBUG = 0 DEBUG = 0
...@@ -31,22 +26,27 @@ ADD_LDFLAGS = ...@@ -31,22 +26,27 @@ ADD_LDFLAGS =
ADD_CFLAGS = ADD_CFLAGS =
#--------------------------------------------- #---------------------------------------------
# matrix computation libraries for CPU/GPU # Backend runtimes.
#--------------------------------------------- #---------------------------------------------
# whether enable CUDA during compile # whether enable CUDA during compile
ENABLE_CUDA = 1 USE_CUDA = 1
# whether enable OpenCL during compile # whether enable OpenCL during compile
ENABLE_OPENCL = 0 USE_OPENCL = 0
# whether enable Metal during compile # whether enable Metal during compile
ENABLE_METAL = 0 USE_METAL = 0
# whether build with LLVM support # whether build with LLVM support
# This requires llvm-config to be in your PATH # This requires llvm-config to be in your PATH
# Requires LLVM version >= 4.0 # Requires LLVM version >= 4.0
ENABLE_LLVM = 0 USE_LLVM = 0
#---------------------------------------------
# Contrib optional libraries.
#---------------------------------------------
# Whether use BLAS, choices: openblas, atlas, blas, apple
USE_BLAS = none
# add the path to CUDA library to link and compile flag # add the path to CUDA library to link and compile flag
# if you have already add them to environment variable. # if you have already add them to environment variable.
......
CBLAS_CONTRIB_SRC = $(wildcard src/contrib/cblas/*.cc)
CBLAS_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(CBLAS_CONTRIB_SRC))
ifeq ($(USE_BLAS), openblas)
ADD_LDFLAGS += -lopenblas
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
else ifeq ($(USE_BLAS), atlas)
ADD_LDFLAGS += -lcblas
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
else ifeq ($(USE_BLAS), blas)
ADD_LDFLAGS += -lblas
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
else ifeq ($(USE_BLAS), apple)
ADD_CFLAGS += -I/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Versions/Current/Headers/
FRAMEWORKS += -framework Accelerate
RUNTIME_DEP += $(CBLAS_CONTRIB_OBJ)
endif
...@@ -196,7 +196,7 @@ def register_func(func_name, f=None, override=False): ...@@ -196,7 +196,7 @@ def register_func(func_name, f=None, override=False):
return register return register
def get_global_func(name): def get_global_func(name, allow_missing=False):
"""Get a global function by name """Get a global function by name
Parameters Parameters
...@@ -204,14 +204,24 @@ def get_global_func(name): ...@@ -204,14 +204,24 @@ def get_global_func(name):
name : str name : str
The name of the global function The name of the global function
allow_missing : bool
Whether allow missing function or raise an error.
Returns Returns
------- -------
func : tvm.Function func : tvm.Function
The function to be returned. The function to be returned, None if function is missing.
""" """
handle = FunctionHandle() handle = FunctionHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
return Function(handle, False) if handle.value:
return Function(handle, False)
else:
if allow_missing:
return None
else:
raise ValueError("Cannot find global function %s" % name)
def list_global_func_names(): def list_global_func_names():
......
"""External function interface to BLAS libraroes."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to calle external libraries.
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
Returns
-------
C : Tensor
The result tensor.
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
...@@ -106,6 +106,12 @@ class Schedule(NodeBase): ...@@ -106,6 +106,12 @@ class Schedule(NodeBase):
include_inputs : boolean, optional include_inputs : boolean, optional
Whether include input operations in the group if they are used by outputs. Whether include input operations in the group if they are used by outputs.
Returns
-------
group : Stage
A virtual stage represents the group, user can use compute_at to move
the attachment point of the group.
""" """
if isinstance(outputs, _tensor.Tensor): if isinstance(outputs, _tensor.Tensor):
outputs = [outputs] outputs = [outputs]
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
Header files in include are public APIs that share across modules. Header files in include are public APIs that share across modules.
There can be internal header files within each module that sit in src. There can be internal header files within each module that sit in src.
The current code modules in src. ## Modules
- common Internal common utilities. - common Internal common utilities.
- api API function registration - api API function registration
- lang The definition of DSL related data structure - lang The definition of DSL related data structure
......
/*!
* Copyright (c) 2017 by Contributors
* \file Use external cblas library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
extern "C" {
#include <cblas.h>
}
namespace tvm {
namespace contrib {
using namespace runtime;
// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kFloat, 32));
CHECK(TypeMatch(B->dtype, kFloat, 32));
CHECK(TypeMatch(C->dtype, kFloat, 32));
cblas_sgemm(CblasColMajor,
transb ? CblasTrans : CblasNoTrans,
transa ? CblasTrans : CblasNoTrans,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transa ? B->shape[1] : B->shape[0],
1.0f,
static_cast<float*>(B->data), B->shape[1],
static_cast<float*>(A->data), A->shape[1],
0.0f,
static_cast<float*>(C->data), C->shape[1]);
});
} // namespace contrib
} // namespace tvm
...@@ -105,9 +105,11 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { ...@@ -105,9 +105,11 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
API_BEGIN(); API_BEGIN();
const tvm::runtime::PackedFunc* fp = const tvm::runtime::PackedFunc* fp =
tvm::runtime::Registry::Get(name); tvm::runtime::Registry::Get(name);
CHECK(fp != nullptr) if (fp != nullptr) {
<< "Cannot find global function " << name; *out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*)
*out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*) } else {
*out = nullptr;
}
API_END(); API_END();
} }
......
import tvm
import numpy as np
from tvm.contrib import cblas
def test_matmul_add():
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = cblas.matmul(A, B)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
print("skip because extern function is not avalable")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
np.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb)
verify()
if __name__ == "__main__":
test_matmul_add()
...@@ -18,17 +18,17 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then ...@@ -18,17 +18,17 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
fi fi
cp make/config.mk config.mk cp make/config.mk config.mk
echo "ENABLE_CUDA=0" >> config.mk echo "USE_CUDA=0" >> config.mk
if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
echo "ENABLE_OPENCL=1" >> config.mk echo "USE_OPENCL=1" >> config.mk
echo "ENABLE_METAL=1" >> config.mk echo "USE_METAL=1" >> config.mk
else else
# use g++-4.8 for linux # use g++-4.8 for linux
if [ ${CXX} == "g++" ]; then if [ ${CXX} == "g++" ]; then
export CXX=g++-4.8 export CXX=g++-4.8
fi fi
echo "ENABLE_OPENCL=0" >> config.mk echo "USE_OPENCL=0" >> config.mk
fi fi
if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
# #
# A **Schedule** is a set of transformation of computation that # A **Schedule** is a set of transformation of computation that
# transforms the loop of computations in the program. # transforms the loop of computations in the program.
#
# declare some variables for use later # declare some variables for use later
n = tvm.var('n') n = tvm.var('n')
...@@ -50,7 +51,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False)) ...@@ -50,7 +51,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
###################################################################### ######################################################################
# split # split
# -------------------------- # -----
# :code:`split` can split a specified axis into two axises by # :code:`split` can split a specified axis into two axises by
# :code:`factor`. # :code:`factor`.
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
...@@ -72,7 +73,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False)) ...@@ -72,7 +73,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
###################################################################### ######################################################################
# tile # tile
# -------------------------- # ----
# :code:`tile` help you execute the computation tile by tile over two # :code:`tile` help you execute the computation tile by tile over two
# axises. # axises.
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
...@@ -84,7 +85,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False)) ...@@ -84,7 +85,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
###################################################################### ######################################################################
# fuse # fuse
# -------------------------- # ----
# :code:`fuse` can fuse two consecutive axises of one computation. # :code:`fuse` can fuse two consecutive axises of one computation.
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B') B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')
...@@ -98,7 +99,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False)) ...@@ -98,7 +99,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
###################################################################### ######################################################################
# reorder # reorder
# -------------------------- # -------
# :code:`reorder` can reorder the axises in the specified order. # :code:`reorder` can reorder the axises in the specified order.
A = tvm.placeholder((m, n), name='A') A = tvm.placeholder((m, n), name='A')
B = tvm.compute((m, n), lambda i, j: A[i, j], name='B') B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')
...@@ -112,7 +113,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False)) ...@@ -112,7 +113,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
###################################################################### ######################################################################
# bind # bind
# -------------------------- # ----
# :code:`bind` can bind a specified axis with a thread axis, often used # :code:`bind` can bind a specified axis with a thread axis, often used
# in gpu programming. # in gpu programming.
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
...@@ -126,7 +127,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False)) ...@@ -126,7 +127,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
###################################################################### ######################################################################
# compute_at # compute_at
# -------------------------- # ----------
# For a schedule consists of multiple operators, tvm will compute # For a schedule consists of multiple operators, tvm will compute
# tensors at the root separately by default. # tensors at the root separately by default.
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
...@@ -149,7 +150,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False)) ...@@ -149,7 +150,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
###################################################################### ######################################################################
# compute_inline # compute_inline
# -------------------------- # --------------
# :code:`compute_inline` can mark one stage as inline, then the body of # :code:`compute_inline` can mark one stage as inline, then the body of
# computation will be expanded and inserted at the address where the # computation will be expanded and inserted at the address where the
# tensor is required. # tensor is required.
...@@ -163,7 +164,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False)) ...@@ -163,7 +164,7 @@ print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
###################################################################### ######################################################################
# compute_root # compute_root
# -------------------------- # ------------
# :code:`compute_root` can move computation of one stage to the root. # :code:`compute_root` can move computation of one stage to the root.
A = tvm.placeholder((m,), name='A') A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i]+1, name='B') B = tvm.compute((m,), lambda i: A[i]+1, name='B')
......
# Verilog Code Guidline # Verilog Code Guidline
The verilog backend is still at early alpha and not yet ready to use.
- Use ```my_port_name``` for variable naming. - Use ```my_port_name``` for variable naming.
- Always use suffix to indicate certain usage. - Always use suffix to indicate certain usage.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment