Commit 9c0da90f by Tianqi Chen Committed by GitHub

[PASS/SETUP] Fix minior issues (#663)

* [PASS/SETUP] Fix minior issues

* fix lint
parent 46e6cae5
...@@ -22,13 +22,21 @@ ...@@ -22,13 +22,21 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
inline Expr Simplify(Expr a) { /*!
return Halide::Internal::simplify(a); * \brief Simplify the expression.
} * \param expr The expression to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());
inline Stmt Simplify(Stmt a) { /*!
return Halide::Internal::simplify(a); * \brief Simplify the statement.
} * \param stmt The statement to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
/*! /*!
* \brief Simplify by applying canonical form. * \brief Simplify by applying canonical form.
......
...@@ -18,16 +18,25 @@ else: ...@@ -18,16 +18,25 @@ else:
from setuptools import setup from setuptools import setup
from setuptools.extension import Extension from setuptools.extension import Extension
# We can not import `libinfo.py` in setup.py directly since __init__.py def get_lib_path():
# Will be invoked which introduces dependences """Get library path, name and version"""
CURRENT_DIR = os.path.dirname(__file__) # We can not import `libinfo.py` in setup.py directly since __init__.py
libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py') # Will be invoked which introduces dependences
libinfo = {'__file__': libinfo_py} CURRENT_DIR = os.path.dirname(__file__)
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo) libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
lib_path = libinfo['find_lib_path']()
version = libinfo['__version__']
libs = [lib_path[0]]
if libs[0].find("runtime") == -1:
for name in lib_path[1:]:
if name.find("runtime") != -1:
libs.append(name)
break
return libs, version
LIB_PATH = libinfo['find_lib_path']() LIB_LIST, __version__ = get_lib_path()
_, LIB_NAME = os.path.split(LIB_PATH[0])
__version__ = libinfo['__version__']
def config_cython(): def config_cython():
"""Try to configure cython and return cython configuration""" """Try to configure cython and return cython configuration"""
...@@ -81,18 +90,21 @@ class BinaryDistribution(Distribution): ...@@ -81,18 +90,21 @@ class BinaryDistribution(Distribution):
# For bdist_wheel only # For bdist_wheel only
if "bdist_wheel" in sys.argv: if "bdist_wheel" in sys.argv:
shutil.copy(LIB_PATH[0], os.path.join(CURRENT_DIR, 'tvm'))
with open("MANIFEST.in", "w") as fo: with open("MANIFEST.in", "w") as fo:
fo.write("include tvm/%s\n" % LIB_NAME) for path in LIB_LIST:
shutil.copy(path, os.path.join(CURRENT_DIR, 'tvm'))
_, libname = os.path.split(path)
fo.write("include tvm/%s\n" % libname)
setup_kwargs = { setup_kwargs = {
"include_package_data": True "include_package_data": True
} }
else: else:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
rpath = os.path.relpath(LIB_PATH[0], curr_path) for i, path in enumerate(LIB_LIST):
LIB_LIST[i] = os.path.relpath(path, curr_path)
setup_kwargs = { setup_kwargs = {
"include_package_data": True, "include_package_data": True,
"data_files": [('tvm', [rpath])] "data_files": [('tvm', LIB_LIST)]
} }
setup(name='tvm', setup(name='tvm',
...@@ -112,4 +124,6 @@ setup(name='tvm', ...@@ -112,4 +124,6 @@ setup(name='tvm',
# Wheel cleanup # Wheel cleanup
if "bdist_wheel" in sys.argv: if "bdist_wheel" in sys.argv:
os.remove("MANIFEST.in") os.remove("MANIFEST.in")
for path in LIB_LIST:
_, libname = os.path.split(path)
os.remove("tvm/%s" % LIB_NAME) os.remove("tvm/%s" % LIB_NAME)
...@@ -74,7 +74,8 @@ def find_lib_path(name=None, search_path=None): ...@@ -74,7 +74,8 @@ def find_lib_path(name=None, search_path=None):
if not use_runtime: if not use_runtime:
# try to find lib_dll_path # try to find lib_dll_path
lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)] lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
if use_runtime or not lib_found: lib_found += [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
else:
# try to find runtime_dll_path # try to find runtime_dll_path
use_runtime = True use_runtime = True
lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)] lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
......
...@@ -16,10 +16,18 @@ namespace ir { ...@@ -16,10 +16,18 @@ namespace ir {
TVM_REGISTER_API("ir_pass.Simplify") TVM_REGISTER_API("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) { if (args[0].IsNodeType<Stmt>()) {
if (args.size() > 1) {
*ret = Simplify(args[0].operator Stmt(), args[1]);
} else {
*ret = Simplify(args[0].operator Stmt()); *ret = Simplify(args[0].operator Stmt());
}
} else {
if (args.size() > 1) {
*ret = Simplify(args[0].operator Expr(), args[1]);
} else { } else {
*ret = Simplify(args[0].operator Expr()); *ret = Simplify(args[0].operator Expr());
} }
}
}); });
TVM_REGISTER_API("ir_pass.CanonicalSimplify") TVM_REGISTER_API("ir_pass.CanonicalSimplify")
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/arithmetic.h> #include <tvm/arithmetic.h>
#include "./canonical.h" #include "./canonical.h"
#include "./compute_expr.h" #include "./compute_expr.h"
#include "arithmetic/Simplify.h"
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -559,5 +560,28 @@ Stmt CanonicalSimplify(Stmt stmt) { ...@@ -559,5 +560,28 @@ Stmt CanonicalSimplify(Stmt stmt) {
Expr CanonicalSimplify(Expr expr) { Expr CanonicalSimplify(Expr expr) {
return arith::Canonical().Simplify(expr); return arith::Canonical().Simplify(expr);
} }
template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
using namespace Halide::Internal;
Scope<Interval> rscope;
for (auto kv : vrange) {
Range r = kv.second;
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return Halide::Internal::simplify(a, true, rscope);
}
Expr Simplify(Expr a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -27,6 +27,13 @@ def test_basic(): ...@@ -27,6 +27,13 @@ def test_basic():
assert str(ret.value) == "(m - 1)" assert str(ret.value) == "(m - 1)"
def test_bound():
m = tvm.var('m')
vrange = tvm.convert({m: tvm.Range(tvm.const(0), tvm.const(10))})
ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m
def test_canonical(): def test_canonical():
x = tvm.var("x") x = tvm.var("x")
z = tvm.const(3) z = tvm.const(3)
...@@ -37,6 +44,7 @@ def test_canonical(): ...@@ -37,6 +44,7 @@ def test_canonical():
assert(tvm.ir_pass.Equal(ret, 0)) assert(tvm.ir_pass.Equal(ret, 0))
if __name__ == "__main__": if __name__ == "__main__":
test_bound()
test_basic() test_basic()
test_simplify() test_simplify()
test_canonical() test_canonical()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment