Commit 87c929f5 by Hu Shiwen Committed by Tianqi Chen

add msvc in cc (#531)

parent 85c545c7
...@@ -4,6 +4,10 @@ from __future__ import absolute_import as _abs ...@@ -4,6 +4,10 @@ from __future__ import absolute_import as _abs
import sys import sys
import subprocess import subprocess
import os
from .util import tempdir
def create_shared(output, def create_shared(output,
objects, objects,
options=None, options=None,
...@@ -24,26 +28,85 @@ def create_shared(output, ...@@ -24,26 +28,85 @@ def create_shared(output,
cc : str, optional cc : str, optional
The compile string. The compile string.
""" """
if sys.platform == "darwin" or sys.platform.startswith('linux'):
_linux_shared(output, objects, options, cc)
elif sys.platform == "win32":
_windows_shared(output, objects, options)
else:
raise ValueError("Unsupported platform")
def _linux_shared(output, objects, options, cc="g++"):
cmd = [cc] cmd = [cc]
cmd += ["-shared", "-fPIC"] cmd += ["-shared", "-fPIC"]
if sys.platform == "darwin": if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"] cmd += ["-undefined", "dynamic_lookup"]
cmd += ["-o", output] cmd += ["-o", output]
if isinstance(objects, str): if isinstance(objects, str):
cmd += [objects] cmd += [objects]
else: else:
cmd += objects cmd += objects
if options: if options:
cmd += options cmd += options
proc = subprocess.Popen( proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate() (out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += str(out)
raise RuntimeError(msg)
def _windows_shared(output, objects, options):
cl_cmd = ["cl"]
cl_cmd += ["-c"]
if isinstance(objects, str):
objects = [objects]
cl_cmd += objects
if options:
cl_cmd += options
temp = tempdir()
dllmain_path = temp.relpath("dllmain.cc")
with open(dllmain_path, "w") as dllmain_obj:
dllmain_obj.write('#include <windows.h>\
BOOL APIENTRY DllMain( HMODULE hModule,\
DWORD ul_reason_for_call,\
LPVOID lpReserved)\
{return TRUE;}')
cl_cmd += [dllmain_path]
temp_path = dllmain_path.replace("dllmain.cc", "")
cl_cmd += ["-Fo:" + temp_path]
proc = subprocess.Popen(
cl_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += str(out)
raise RuntimeError(msg)
link_cmd = ["link"]
link_cmd += ["-dll", "-FORCE:MULTIPLE"]
for obj in objects:
if obj.endswith(".cc"):
(_, temp_file_name) = os.path.split(obj)
(shot_name, _) = os.path.splitext(temp_file_name)
link_cmd += [os.path.join(temp_path, shot_name + ".obj")]
if obj.endswith(".o"):
link_cmd += [obj]
link_cmd += ["-EXPORT:__tvm_main__"]
link_cmd += [temp_path + "dllmain.obj"]
link_cmd += ["-out:" + output]
proc = subprocess.Popen(
link_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
msg = "Compilation error:\n" msg = "Compilation error:\n"
msg += out msg += str(out)
raise RuntimeError(msg) raise RuntimeError(msg)
...@@ -16,7 +16,7 @@ class TempDirectory(object): ...@@ -16,7 +16,7 @@ class TempDirectory(object):
def remove(self): def remove(self):
"""Remote the tmp dir""" """Remote the tmp dir"""
if self.temp_dir: if self.temp_dir:
self._rmtree(self.temp_dir) self._rmtree(self.temp_dir, ignore_errors=True)
self.temp_dir = None self.temp_dir = None
def __del__(self): def __del__(self):
......
...@@ -2,14 +2,17 @@ ...@@ -2,14 +2,17 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from collections import namedtuple from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .contrib import cc as _cc, tar as _tar, util as _util from .contrib import cc as _cc, tar as _tar, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"]) ProfileResult = namedtuple("ProfileResult", ["mean"])
class Module(ModuleBase): class Module(ModuleBase):
"""Module container of all TVM generated functions""" """Module container of all TVM generated functions"""
def __repr__(self): def __repr__(self):
return "Module(%s, %x)" % (self.type_key, self.handle.value) return "Module(%s, %x)" % (self.type_key, self.handle.value)
...@@ -135,11 +138,13 @@ class Module(ModuleBase): ...@@ -135,11 +138,13 @@ class Module(ModuleBase):
try: try:
feval = _RPCTimeEvaluator( feval = _RPCTimeEvaluator(
self, func_name, ctx.device_type, ctx.device_id, number) self, func_name, ctx.device_type, ctx.device_id, number)
def evaluator(*args): def evaluator(*args):
"""Internal wrapped evaluator.""" """Internal wrapped evaluator."""
# Wrap feval so we can add more stats in future. # Wrap feval so we can add more stats in future.
mean = feval(*args) mean = feval(*args)
return ProfileResult(mean=mean) return ProfileResult(mean=mean)
return evaluator return evaluator
except NameError: except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled") raise NameError("time_evaluate is only supported when RPC is enabled")
......
...@@ -226,6 +226,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr( ...@@ -226,6 +226,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(
name); name);
gv->setAlignment(data_layout_->getTypeAllocSize(p_type)); gv->setAlignment(data_layout_->getTypeAllocSize(p_type));
gv->setInitializer(llvm::Constant::getNullValue(p_type)); gv->setInitializer(llvm::Constant::getNullValue(p_type));
gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
return gv; return gv;
} }
......
...@@ -117,6 +117,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { ...@@ -117,6 +117,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
ftype, llvm::Function::ExternalLinkage, ftype, llvm::Function::ExternalLinkage,
f->name, module_.get()); f->name, module_.get());
function_->setCallingConv(llvm::CallingConv::C); function_->setCallingConv(llvm::CallingConv::C);
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
// set var map and align information // set var map and align information
auto arg_it = function_->arg_begin(); auto arg_it = function_->arg_begin();
for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) { for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include <llvm/Target/TargetMachine.h> #include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h> #include <llvm/Target/TargetOptions.h>
#include <llvm/IRReader/IRReader.h> #include <llvm/IRReader/IRReader.h>
#include <llvm/CodeGen/TargetLoweringObjectFileImpl.h>
#include <utility> #include <utility>
#include <string> #include <string>
......
...@@ -2,6 +2,7 @@ import tvm ...@@ -2,6 +2,7 @@ import tvm
from tvm.contrib import cc, util from tvm.contrib import cc, util
import ctypes import ctypes
import os import os
import sys
import numpy as np import numpy as np
import subprocess import subprocess
...@@ -88,7 +89,13 @@ def test_device_module_dump(): ...@@ -88,7 +89,13 @@ def test_device_module_dump():
return return
temp = util.tempdir() temp = util.tempdir()
name = "myadd_%s" % device name = "myadd_%s" % device
if sys.platform == "darwin" or sys.platform.startswith('linux'):
f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name) f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
elif sys.platform == "win32":
f = tvm.build(s, [A, B], device, "llvm", name=name)
else:
raise ValueError("Unsupported platform")
path_dso = temp.relpath("dev_lib.so") path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso) f.export_library(path_dso)
...@@ -96,8 +103,9 @@ def test_device_module_dump(): ...@@ -96,8 +103,9 @@ def test_device_module_dump():
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b) f1(a, b)
f2 = tvm.module.system_lib()
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
if sys.platform != "win32":
f2 = tvm.module.system_lib()
f2[name](a, b) f2[name](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
...@@ -165,6 +173,7 @@ def test_combine_module_llvm(): ...@@ -165,6 +173,7 @@ def test_combine_module_llvm():
mm['myadd2'](a, b) mm['myadd2'](a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
if sys.platform != "win32":
check_system_lib() check_system_lib()
check_llvm() check_llvm()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment