Commit e8c6adc6 by Peter Yeh Committed by masahi

Add .hsaco save/load for ROCm target (#3852)

fix lld
parent 54150cd5
......@@ -46,7 +46,7 @@ def find_lld(required=True):
major = codegen.llvm_version_major()
lld_list += ["ld.lld-%d.0" % major]
lld_list += ["ld.lld-%d" % major]
lld_list += ["lld"]
lld_list += ["ld.lld"]
valid_list = [util.which(x) for x in lld_list]
valid_list = [x for x in valid_list if x]
if not valid_list and required:
......
......@@ -170,8 +170,8 @@ inline int DetectROCMComputeVersion(const std::string& target) {
return val.operator int();
}
}
LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803";
return 803;
LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx900";
return 900;
}
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
......
......@@ -33,7 +33,7 @@ import numpy as np
# Global declarations of environment.
tgt_host="llvm"
# Change it to respective GPU if gpu is enabled Ex: cuda, opencl
# Change it to respective GPU if gpu is enabled Ex: cuda, opencl, rocm
tgt="cuda"
######################################################################
......@@ -113,7 +113,7 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
# compute grid. These are GPU specific constructs that allow us
# to generate code that runs on GPU.
#
if tgt == "cuda" or tgt.startswith('opencl'):
if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
......@@ -168,7 +168,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
#
# The following code fetches the device module and prints the content code.
#
if tgt == "cuda" or tgt.startswith('opencl'):
if tgt == "cuda" or tgt == "rocm" or tgt.startswith('opencl'):
dev_module = fadd.imported_modules[0]
print("-----GPU code-----")
print(dev_module.get_source())
......@@ -212,6 +212,8 @@ temp = util.tempdir()
fadd.save(temp.relpath("myadd.o"))
if tgt == "cuda":
fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
if tgt == "rocm":
fadd.imported_modules[0].save(temp.relpath("myadd.hsaco"))
if tgt.startswith('opencl'):
fadd.imported_modules[0].save(temp.relpath("myadd.cl"))
cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
......@@ -238,6 +240,10 @@ if tgt == "cuda":
fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
fadd1.import_module(fadd1_dev)
if tgt == "rocm":
fadd1_dev = tvm.module.load(temp.relpath("myadd.hsaco"))
fadd1.import_module(fadd1_dev)
if tgt.startswith('opencl'):
fadd1_dev = tvm.module.load(temp.relpath("myadd.cl"))
fadd1.import_module(fadd1_dev)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment