Commit 87c929f5 by Hu Shiwen Committed by Tianqi Chen

add msvc in cc (#531)

parent 85c545c7
......@@ -4,6 +4,10 @@ from __future__ import absolute_import as _abs
import sys
import subprocess
import os
from .util import tempdir
def create_shared(output,
objects,
options=None,
......@@ -24,26 +28,85 @@ def create_shared(output,
cc : str, optional
The compile string.
"""
if sys.platform == "darwin" or sys.platform.startswith('linux'):
_linux_shared(output, objects, options, cc)
elif sys.platform == "win32":
_windows_shared(output, objects, options)
else:
raise ValueError("Unsupported platform")
def _linux_shared(output, objects, options, cc="g++"):
cmd = [cc]
cmd += ["-shared", "-fPIC"]
if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"]
cmd += ["-o", output]
if isinstance(objects, str):
cmd += [objects]
else:
cmd += objects
if options:
cmd += options
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += str(out)
raise RuntimeError(msg)
def _windows_shared(output, objects, options):
cl_cmd = ["cl"]
cl_cmd += ["-c"]
if isinstance(objects, str):
objects = [objects]
cl_cmd += objects
if options:
cl_cmd += options
temp = tempdir()
dllmain_path = temp.relpath("dllmain.cc")
with open(dllmain_path, "w") as dllmain_obj:
dllmain_obj.write('#include <windows.h>\
BOOL APIENTRY DllMain( HMODULE hModule,\
DWORD ul_reason_for_call,\
LPVOID lpReserved)\
{return TRUE;}')
cl_cmd += [dllmain_path]
temp_path = dllmain_path.replace("dllmain.cc", "")
cl_cmd += ["-Fo:" + temp_path]
proc = subprocess.Popen(
cl_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += str(out)
raise RuntimeError(msg)
link_cmd = ["link"]
link_cmd += ["-dll", "-FORCE:MULTIPLE"]
for obj in objects:
if obj.endswith(".cc"):
(_, temp_file_name) = os.path.split(obj)
(shot_name, _) = os.path.splitext(temp_file_name)
link_cmd += [os.path.join(temp_path, shot_name + ".obj")]
if obj.endswith(".o"):
link_cmd += [obj]
link_cmd += ["-EXPORT:__tvm_main__"]
link_cmd += [temp_path + "dllmain.obj"]
link_cmd += ["-out:" + output]
proc = subprocess.Popen(
link_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += out
msg += str(out)
raise RuntimeError(msg)
......@@ -16,7 +16,7 @@ class TempDirectory(object):
def remove(self):
"""Remote the tmp dir"""
if self.temp_dir:
self._rmtree(self.temp_dir)
self._rmtree(self.temp_dir, ignore_errors=True)
self.temp_dir = None
def __del__(self):
......
......@@ -2,14 +2,17 @@
from __future__ import absolute_import as _abs
from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api
from .contrib import cc as _cc, tar as _tar, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"])
class Module(ModuleBase):
"""Module container of all TVM generated functions"""
def __repr__(self):
return "Module(%s, %x)" % (self.type_key, self.handle.value)
......@@ -135,11 +138,13 @@ class Module(ModuleBase):
try:
feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number)
def evaluator(*args):
"""Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future.
mean = feval(*args)
return ProfileResult(mean=mean)
return evaluator
except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled")
......
......@@ -226,6 +226,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(
name);
gv->setAlignment(data_layout_->getTypeAllocSize(p_type));
gv->setInitializer(llvm::Constant::getNullValue(p_type));
gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
return gv;
}
......
......@@ -117,6 +117,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
ftype, llvm::Function::ExternalLinkage,
f->name, module_.get());
function_->setCallingConv(llvm::CallingConv::C);
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
// set var map and align information
auto arg_it = function_->arg_begin();
for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
......
......@@ -41,6 +41,7 @@
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/CodeGen/TargetLoweringObjectFileImpl.h>
#include <utility>
#include <string>
......
......@@ -2,6 +2,7 @@ import tvm
from tvm.contrib import cc, util
import ctypes
import os
import sys
import numpy as np
import subprocess
......@@ -88,7 +89,13 @@ def test_device_module_dump():
return
temp = util.tempdir()
name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
if sys.platform == "darwin" or sys.platform.startswith('linux'):
f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
elif sys.platform == "win32":
f = tvm.build(s, [A, B], device, "llvm", name=name)
else:
raise ValueError("Unsupported platform")
path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso)
......@@ -96,10 +103,11 @@ def test_device_module_dump():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b)
f2 = tvm.module.system_lib()
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
f2[name](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
if sys.platform != "win32":
f2 = tvm.module.system_lib()
f2[name](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_device("cuda")
check_device("opencl")
......@@ -164,8 +172,9 @@ def test_combine_module_llvm():
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
mm['myadd2'](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_system_lib()
if sys.platform != "win32":
check_system_lib()
check_llvm()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment