Commit 9c0da90f by Tianqi Chen Committed by GitHub

[PASS/SETUP] Fix minior issues (#663)

* [PASS/SETUP] Fix minior issues

* fix lint
parent 46e6cae5
......@@ -22,13 +22,21 @@
namespace tvm {
namespace ir {
inline Expr Simplify(Expr a) {
return Halide::Internal::simplify(a);
}
/*!
* \brief Simplify the expression.
* \param expr The expression to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());
inline Stmt Simplify(Stmt a) {
return Halide::Internal::simplify(a);
}
/*!
* \brief Simplify the statement.
* \param stmt The statement to be simplifed.
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify by applying canonical form.
......
......@@ -18,16 +18,25 @@ else:
from setuptools import setup
from setuptools.extension import Extension
# We can not import `libinfo.py` in setup.py directly since __init__.py
# Will be invoked which introduces dependences
CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
def get_lib_path():
"""Get library path, name and version"""
# We can not import `libinfo.py` in setup.py directly since __init__.py
# Will be invoked which introduces dependences
CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
lib_path = libinfo['find_lib_path']()
version = libinfo['__version__']
libs = [lib_path[0]]
if libs[0].find("runtime") == -1:
for name in lib_path[1:]:
if name.find("runtime") != -1:
libs.append(name)
break
return libs, version
LIB_PATH = libinfo['find_lib_path']()
_, LIB_NAME = os.path.split(LIB_PATH[0])
__version__ = libinfo['__version__']
LIB_LIST, __version__ = get_lib_path()
def config_cython():
"""Try to configure cython and return cython configuration"""
......@@ -81,18 +90,21 @@ class BinaryDistribution(Distribution):
# For bdist_wheel only
if "bdist_wheel" in sys.argv:
shutil.copy(LIB_PATH[0], os.path.join(CURRENT_DIR, 'tvm'))
with open("MANIFEST.in", "w") as fo:
fo.write("include tvm/%s\n" % LIB_NAME)
for path in LIB_LIST:
shutil.copy(path, os.path.join(CURRENT_DIR, 'tvm'))
_, libname = os.path.split(path)
fo.write("include tvm/%s\n" % libname)
setup_kwargs = {
"include_package_data": True
}
else:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
rpath = os.path.relpath(LIB_PATH[0], curr_path)
for i, path in enumerate(LIB_LIST):
LIB_LIST[i] = os.path.relpath(path, curr_path)
setup_kwargs = {
"include_package_data": True,
"data_files": [('tvm', [rpath])]
"data_files": [('tvm', LIB_LIST)]
}
setup(name='tvm',
......@@ -112,4 +124,6 @@ setup(name='tvm',
# Wheel cleanup
if "bdist_wheel" in sys.argv:
os.remove("MANIFEST.in")
os.remove("tvm/%s" % LIB_NAME)
for path in LIB_LIST:
_, libname = os.path.split(path)
os.remove("tvm/%s" % LIB_NAME)
......@@ -74,7 +74,8 @@ def find_lib_path(name=None, search_path=None):
if not use_runtime:
# try to find lib_dll_path
lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
if use_runtime or not lib_found:
lib_found += [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
else:
# try to find runtime_dll_path
use_runtime = True
lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
......
......@@ -16,9 +16,17 @@ namespace ir {
TVM_REGISTER_API("ir_pass.Simplify")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsNodeType<Stmt>()) {
*ret = Simplify(args[0].operator Stmt());
if (args.size() > 1) {
*ret = Simplify(args[0].operator Stmt(), args[1]);
} else {
*ret = Simplify(args[0].operator Stmt());
}
} else {
*ret = Simplify(args[0].operator Expr());
if (args.size() > 1) {
*ret = Simplify(args[0].operator Expr(), args[1]);
} else {
*ret = Simplify(args[0].operator Expr());
}
}
});
......
......@@ -7,6 +7,7 @@
#include <tvm/arithmetic.h>
#include "./canonical.h"
#include "./compute_expr.h"
#include "arithmetic/Simplify.h"
namespace tvm {
namespace arith {
......@@ -559,5 +560,28 @@ Stmt CanonicalSimplify(Stmt stmt) {
Expr CanonicalSimplify(Expr expr) {
return arith::Canonical().Simplify(expr);
}
template<typename T>
T Simplify_(T a, Map<Var, Range> vrange) {
using namespace Halide::Internal;
Scope<Interval> rscope;
for (auto kv : vrange) {
Range r = kv.second;
rscope.push(
kv.first.get(),
Interval(r->min,
simplify(r->min + r->extent - make_const(r->min.type(), 1))));
}
return Halide::Internal::simplify(a, true, rscope);
}
Expr Simplify(Expr a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
return Simplify_(a, vrange);
}
} // namespace ir
} // namespace tvm
......@@ -27,6 +27,13 @@ def test_basic():
assert str(ret.value) == "(m - 1)"
def test_bound():
m = tvm.var('m')
vrange = tvm.convert({m: tvm.Range(tvm.const(0), tvm.const(10))})
ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m
def test_canonical():
x = tvm.var("x")
z = tvm.const(3)
......@@ -37,6 +44,7 @@ def test_canonical():
assert(tvm.ir_pass.Equal(ret, 0))
if __name__ == "__main__":
test_bound()
test_basic()
test_simplify()
test_canonical()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment