Commit 4d280905 by Tianqi Chen Committed by GitHub

[PERF] Persitent kernel (#87)

* [PERF] Persitent kernel

* fix doc
parent 27205e36
......@@ -95,3 +95,4 @@ config.mk
build_*
Win32
*.dir
perf
# TVM
hack for fun
Run the tests, set the bashrc
[![Build Status](https://travis-ci.com/tqchen/tvm.svg?token=ZQpnpAReT4LHdjWAX8jR&branch=master)](https://travis-ci.com/tqchen/tvm)
```bash
export PYTHONPATH=${PYTHONPATH}:/path/to/tvm/python
```
[Installation](docs/how_to/install.md) |
[Contributor Guide](docs/how_to/contribute.md) |
[Release Notes](NEWS.md)
Write and run tests via
```bash
nosetests tests/python
```
\ No newline at end of file
# TVM
TVM is a domain specific language(DSL) for tensor computations.
The goal of the project is to generate efficient kernels for deep learning workloads.
......@@ -3,7 +3,7 @@
TVM has been developed and used by a group of active community members.
Everyone is more than welcome to contribute. It is a way to make the project better and more accessible to more users.
- Please update [NEWS.md](../NEWS.md) to add note on your changes to the API or added a new document.
- Please update [NEWS.md](../../NEWS.md) to add note on your changes to the API or added a new document.
## Guidelines
* [Submit Pull Request](#submit-pull-request)
......
Installation Guide
==================
This page gives instructions on how to build and install the tvm package from
scratch on various systems. It consists of two steps:
1. First build the shared library from the C++ codes (`libtvm.so` for linux/osx and `libtvm.dll` for windows).
2. Setup for the language packages (e.g. Python Package).
To get started, clone tvm repo from github. It is important to clone the submodules along, with ```--recursive``` option.
```bash
git clone --recursive https://github.com/tqchen/tvm
```
For windows users who use github tools, you can open the git shell, and type the following command.
```bash
git submodule init
git submodule update
```
## Contents
- [Build the Shared Library](#build-the-shared-library)
- [Python Package Installation](#python-package-installation)
## Build the Shared Library
Our goal is to build the shared library:
- On Linux/OSX the target library is `libtvm.so`
- On Windows the target library is `libtvm.dll`
The minimal building requirement is
- A recent c++ compiler supporting C++ 11 (g++-4.8 or higher)
You can edit `make/config.mk` to change the compile options, and then build by
`make`. If everything goes well, we can go to the specific language installation section.
### Building on Windows
TVM support build via MSVC using cmake. To build with Visual Studio 2015 use cmake.
Make sure you have a recent version of cmake added to your path and then from the xgboost directory:
```bash
mkdir build
cd build
cmake .. -G"Visual Studio 14 2015 Win64"
```
This specifies an out of source build using the MSVC 12 64 bit generator. Open the .sln file in the build directory and build with Visual Studio.
### Customized Building
The configuration of xgboost can be modified by ```config.mk```
- First copy [make/config.mk](../make/config.mk) to the project root, on which
any local modification will be ignored by git, then modify the according flags.
- TVM optionally depends on LLVM. LLVM is required for CPU codegen that needs LLVM.
- LLVM 4.0 is needed for build with LLVM
- By default CUDA and OpenCL code generator do not require llvm.
## Python Package Installation
The python package is located at [python](../python).
There are several ways to install the package:
1. Set the environment variable `PYTHONPATH` to tell python where to find
the library. For example, assume we cloned `tvm` on the home directory
`~`. then we can added the following line in `~/.bashrc`.
It is ***recommended for developers*** who may change the codes.
The changes will be immediately reflected once you pulled the code and rebuild the project (no need to call ```setup``` again)
```bash
export PYTHONPATH=/path/to/tvm/python:${PYTHONPATH}
```
......@@ -184,6 +184,11 @@ constexpr const char* tvm_call_packed = "tvm_call_packed";
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
* \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier.
*/
constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
/*!
* \brief See pesudo code
*
* Expr tvm_thread_allreduce(std::string op, Expr value, Expr cond,
......
......@@ -147,9 +147,13 @@ namespace symbol {
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Local function to set the device during API entry. */
constexpr const char* tvm_entry_setdevice = "__tvm_entry_setdevice";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that uses global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
} // packed symbol
} // namespace symbol
// implementations of inline functions.
inline ModuleNode* Module::operator->() {
......
......@@ -77,7 +77,8 @@ def build(sch,
target_host="stackvm",
name="default_function",
binds=None,
max_auto_unroll_step=8):
max_auto_unroll_step=8,
detect_global_barrier=True):
"""Build a function with arguments as signiture.
Parameters
......@@ -104,6 +105,9 @@ def build(sch,
max_auto_unroll_step: int
Maximum step to perform automatic unrolling
detect_global_barrier: boolean
Whether detect and inser global barrier
Returns
-------
f : Function, or pair of functions
......@@ -122,14 +126,13 @@ def build(sch,
fapi = sch
else:
raise ValueError("sch have to be Schedule or LoweredFunc")
fsplits = ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
# device related lowering
if detect_global_barrier:
fapi = ir_pass.StorageSync(fapi, "global")
fapi = ir_pass.StorageSync(fapi, "shared")
warp_size = 32 if target == "cuda" else 1
fsplits[i] = ir_pass.LowerThreadAllreduce(fsplits[i], warp_size)
fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
if len(fsplits) > 1:
mhost = codegen.build(fsplits[0], target_host)
if target:
......
......@@ -200,7 +200,7 @@ void CodeGenC::PrintThreadIndexExpr(
os << thread_tag;
}
void CodeGenC::PrintStorageSync(const std::string& sync) { // NOLINT(*)
void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*)
}
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
......@@ -737,7 +737,7 @@ void CodeGenC::VisitStmt_(const Evaluate *op) {
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
this->PrintStorageSync(call->args[0].as<StringImm>()->value);
this->PrintStorageSync(call);
} else {
std::string vid = this->PrintExpr(op->value);
this->PrintIndent();
......
......@@ -127,7 +127,7 @@ class CodeGenC :
virtual void PrintThreadIndexExpr(
std::string tag, std::ostream& os); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const std::string& scope); // NOLINT(*)
virtual void PrintStorageSync(const Call* op); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, Type op_type,
......
......@@ -14,6 +14,13 @@
namespace tvm {
namespace codegen {
void CodeGenCUDA::Init(bool output_ssa) {
CodeGenC::Init(output_ssa);
vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state);
vid_global_barrier_expect_ = GetUniqueName("__barrier_expect");
CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}
void CodeGenCUDA::AddFunction(LoweredFunc f) {
this->stream << "extern \"C\" __global__ ";
CodeGenC::AddFunction(f);
......@@ -132,12 +139,42 @@ void CodeGenCUDA::PrintVecElemStore(
stream << vec << "." << access[i] << " = " << value << ";\n";
}
void CodeGenCUDA::PrintStorageSync(const std::string& sync) {
void CodeGenCUDA::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
if (sync == "shared") {
this->PrintIndent();
this->stream << "__syncthreads();\n";
} else if (sync == "global") {
LOG(FATAL) << "not supported";
if (!need_global_barrier_) {
need_global_barrier_ = true;
this->decl_stream << "extern \"C\" __device__ unsigned "
<< vid_global_barrier_state_ << ";\n";
}
// global synchronizer
std::string is_load = PrintExpr(op->args[1]);
std::string num_blocks = PrintExpr(op->args[2]);
this->PrintIndent();
// In theory only threadfence is needed
// but we observed problems with only threadfence
this->stream <<"__threadfence_system();\n";
this->PrintIndent();
this->stream <<"if (" << is_load << ") {\n";
int wb = this->BeginScope();
this->PrintIndent();
this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
this->PrintIndent();
std::string ptr = GetUniqueName("pf");
this->stream << "volatile unsigned* "
<< ptr << " = &" << vid_global_barrier_state_<< ";\n";
this->PrintIndent();
this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
this->PrintIndent();
this->stream <<"while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n";
this->EndScope(wb);
this->PrintIndent();
this->stream <<"}\n";
this->PrintIndent();
this->stream <<"__syncthreads();\n";
}
}
......@@ -148,5 +185,22 @@ void CodeGenCUDA::PrintStorageScope(
os << "__shared__ ";
}
}
void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) {
PrintIndent();
stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
PrintIndent();
stream << "if (threadIdx.x == 0) {\n";
PrintIndent();
stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent();
stream << "}\n";
} else {
CodeGenC::VisitStmt_(op);
}
}
} // namespace codegen
} // namespace tvm
......@@ -16,10 +16,11 @@ namespace codegen {
class CodeGenCUDA : public CodeGenC {
public:
void Init(bool output_ssa);
void AddFunction(LoweredFunc f);
// override behavior
void VisitStmt_(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final;
void PrintStorageSync(const Call* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
const std::string&op, Type t,
......@@ -30,10 +31,18 @@ class CodeGenCUDA : public CodeGenC {
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
void VisitStmt_(const Evaluate *op) final;
private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int max_auto_unroll_{1025};
// Whether global barrier is needed.
bool need_global_barrier_{false};
// Global barrier state
std::string vid_global_barrier_state_;
// Global barrier expected node.
std::string vid_global_barrier_expect_;
};
} // namespace codegen
......
......@@ -113,7 +113,8 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
stream << ");\n";
}
void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
void CodeGenOpenCL::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
if (sync == "shared") {
this->PrintIndent();
this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
......
......@@ -22,7 +22,7 @@ class CodeGenOpenCL : public CodeGenC {
void PrintThreadIndexExpr(
std::string tag, std::ostream& os) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const std::string& scope) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
std::string GetVecLoad(const Variable* buffer,
Type t, Expr base) final;
......
......@@ -84,6 +84,8 @@ class CodeGenSourceBase {
virtual void PrintSSAAssign(
const std::string& target, const std::string& src, Type t) = 0;
/*! \brief the declaration stream */
std::ostringstream decl_stream;
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief name of each variable */
......
......@@ -106,7 +106,6 @@ class ThreadAllreduceBuilder : public IRMutator {
reduce_set.insert(v);
}
size_t nmatch = 0;
std::unordered_set<const Variable*> visited;
std::vector<ThreadEntry> vred, vpar;
for (const AttrStmt* attr : thread_extents_) {
ThreadEntry e;
......@@ -119,8 +118,6 @@ class ThreadAllreduceBuilder : public IRMutator {
CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
if (!visited.count(iv->var.get())) {
visited.insert(iv->var.get());
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
......@@ -129,7 +126,6 @@ class ThreadAllreduceBuilder : public IRMutator {
}
}
}
}
CHECK_EQ(nmatch, reduce_set.size())
<< "Not all reduce index are presented in the context";
std::sort(vred.begin(), vred.end());
......@@ -261,6 +257,7 @@ class ThreadAllreduceBuilder : public IRMutator {
}
// The warp size of the device.
int warp_size_{1};
// surrounding scope of thread extent.
std::vector<const AttrStmt*> thread_extents_;
// The load remap
......
......@@ -97,6 +97,30 @@ class CUDAModuleNode : public runtime::ModuleNode {
}
return func;
}
// get a global var from primary context in device_id
CUdeviceptr GetGlobal(int device_id,
const std::string& global_name,
size_t expect_nbytes) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[device_id] == nullptr) {
CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
}
CUdeviceptr global;
size_t nbytes;
CUresult result = cuModuleGetGlobal(&global, &nbytes,
module_[device_id], global_name.c_str());
CHECK_EQ(nbytes, expect_nbytes);
if (result != CUDA_SUCCESS) {
const char *msg;
cuGetErrorName(result, &msg);
LOG(FATAL)
<< "CUDAError: cuModuleGetGlobal " << global_name
<< " failed with error: " << msg;
}
return global;
}
private:
// the binary data
......@@ -163,6 +187,34 @@ class CUDAWrappedFunc {
ThreadAxisConfig thread_axis_cfg_;
};
class CUDAPrepGlobalBarrier {
public:
CUDAPrepGlobalBarrier(CUDAModuleNode* m,
std::shared_ptr<ModuleNode> sptr)
: m_(m), sptr_(sptr) {
std::fill(pcache_.begin(), pcache_.end(), 0);
}
void operator()(const TVMArgs& args, TVMRetValue* rv) const {
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
if (pcache_[device_id] == 0) {
pcache_[device_id] = m_->GetGlobal(
device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned));
}
CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1));
}
private:
// internal module
CUDAModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
// mark as mutable, to enable lazy initialization
mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
};
void AutoSetCUDADevice(const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 3);
TVMValue* values = static_cast<TVMValue*>(args[0].operator void*());
......@@ -196,6 +248,8 @@ PackedFunc CUDAModuleNode::GetFunction(
<< "Device function do not have main";
if (name == symbol::tvm_entry_setdevice) {
return PackedFunc(AutoSetCUDADevice);
} else if (name == symbol::tvm_prepare_global_barrier) {
return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self));
}
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
......
......@@ -158,6 +158,18 @@ class SchedulePostProc : public IRMutator {
CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value;
return this->Mutate(op->body);
} else if (op->attr_key == attr::thread_extent) {
// delete duplicated thread extent attr
auto it = thread_extent_scope_.find(op->node.get());
if (it != thread_extent_scope_.end()) {
CHECK(is_zero(ir::Simplify(it->second- op->value)));
return this->Mutate(op->body);
} else {
thread_extent_scope_[op->node.get()] = op->value;
Stmt ret = IRMutator::Mutate_(op, s);
thread_extent_scope_.erase(op->node.get());
return ret;
}
} else if (op->attr_key == ir::attr::realize_scope) {
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
......@@ -267,6 +279,8 @@ class SchedulePostProc : public IRMutator {
replace_realize_[key] = repl_realize;
replace_op_[src->op.get()] = repl_op;
}
// The thread extent scope.
std::unordered_map<const Node*, Expr> thread_extent_scope_;
// The scan value
std::unordered_map<const Variable*, Expr> var_value_;
// buffer replacement
......
# Perf Examples for TVM
This folder contains perf examples of tvm under various settings.
## GPU Perf Workflow
Since TVM is work in progress, some optimization might not be perfect.
One quick way I find useful is to do codegen plus manual modification.
The workflow is:
- Generate the GPU kernels, write them into a file, say ```cuda/matexp_generated.cu```
- Copy the generated file into another one, say ```cuda/matexp_manual.cu```,
do modifications according to your intuition.
- Set use_manual flag in the script to continue the codegen workflow as normal, but piggy back the manual written code instead.
- Observe the performance difference.
- If the performance improves, mark the manual code and think of optimization pass
to generate the desired target code.
"""Matrix exponential example.
This is an example for matrix exponential,
which calculates the following recursion formula
```math
X[t] = dot(X[t-1], W)
```
"""
import tvm
import time
import os
import argparse
from tvm.addon import nvcc_compiler
import numpy as np
# Quick knobs
TASK="rnn_matexp"
USE_MANUAL_CODE = False
PERSIST_KERNEL = False
DETECT_GLOBAL_BARRIER = True
SKIP_CHECK = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf."""
ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
return ptx
def write_code(code, fname):
with open(fname, "w") as f:
f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
os.mkdir("perf")
write_code(code, "perf/%s_generated.cu" % TASK)
if USE_MANUAL_CODE:
code = open("perf/%s_manual.cu" % TASK).read()
return code
def rnn_matexp():
n_num_step = 128
n_num_hidden = 1152
n_batch_size = 4
max_auto_unroll_step = 0
detect_global_barrier = DETECT_GLOBAL_BARRIER
num_step = tvm.Var("num_step")
num_hidden = tvm.convert(n_num_hidden)
batch_size = tvm.convert(n_batch_size)
num_thread_y = 8
num_thread_x = 16 * 3
num_sm = 24
Whh = tvm.placeholder((num_hidden, num_hidden), name="Whh")
s_init = tvm.compute((1, batch_size, num_hidden),
lambda _, i, j: 1.0, name="init")
s_state = tvm.placeholder((num_step, batch_size, num_hidden))
kh = tvm.reduce_axis((0, num_hidden), name="kh")
s_update = tvm.compute(
(num_step, batch_size, num_hidden),
lambda t, i, j: tvm.sum(s_state[t-1, i, kh] * Whh[kh, j], axis=kh),
name="update")
s_scan = tvm.scan(s_init, s_update, s_state)
# schedule
s = tvm.Schedule(s_scan.op)
CL = s_update
SS = s.cache_read(s_state, "shared", [CL])
SL = s.cache_read(SS, "local", [CL])
WhhL = s.cache_read(Whh, "local", [CL])
ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
CLF = s.rfactor(CL, ko)
block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
if PERSIST_KERNEL:
s[s_scan.op].env_threads([block_x, thread_y, thread_x])
bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
tx, xi = s[s_init].split(xi, nparts=num_thread_x)
s[s_init].bind(bx, block_x)
s[s_init].bind(tx, thread_x)
bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
tx, xi = s[s_update].split(xi, nparts=num_thread_x)
s[s_update].bind(bx, block_x)
s[s_update].bind(tx, thread_x)
s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
if PERSIST_KERNEL:
s[WhhL].compute_at(s[s_scan], thread_x)
else:
s[WhhL].compute_at(s[CLF], CLF.op.axis[3])
kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
ko, ki = s[CLF].split(ki, factor=4)
s[SS].compute_at(s[CLF], kr)
s[SL].compute_at(s[CLF], ko)
xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
ty, xi = s[SS].split(xi, nparts=num_thread_y)
tx, xi = s[SS].split(xi, nparts=num_thread_x)
s[SS].bind(ty, thread_y)
s[SS].bind(tx, thread_x)
def check_device(target):
codes = []
f = tvm.build(s, [s_scan, Whh],
target,
max_auto_unroll_step=max_auto_unroll_step,
detect_global_barrier=detect_global_barrier)
ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
# launch the kernel.
res_np = np.zeros(
(n_num_step, n_batch_size, n_num_hidden)).astype("float32")
Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
Whh_np[:] = 2.0 / n_num_hidden
Whh_np[:, n_num_hidden//2:] = 0
res_a = tvm.nd.array(res_np, ctx)
Whh_a = tvm.nd.array(Whh_np, ctx)
# Skip first pass as it is compilation
f(res_a, Whh_a)
tvm.nd.sync(ctx)
# measure time cost of second step.
tstart = time.time()
f(res_a, Whh_a)
tvm.nd.sync(ctx)
tgap = time.time() - tstart
print("Time cost=%g" % tgap)
# correctness
if not SKIP_CHECK:
res_gpu = res_a.asnumpy()
res_cmp = np.ones_like(res_np).astype("float64")
Whh_np = Whh_np.astype("float64")
for t in range(1, n_num_step):
res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
for i in range(n_num_step):
for j in range(n_num_hidden):
if abs(res_cmp[i,0,j] - res_gpu[i,0,j]) > 1e-5:
print("%d, %d: %g vs %g" % (i,j, res_cmp[i,0,j], res_gpu[i,0,j]))
np.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
check_device("cuda")
if __name__ == "__main__":
rnn_matexp()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment