Commit e35666ae by Jian Weng Committed by Leyuan Wang

[Hybrid Script] Add `max_num_threads` (#2672)

* i think it works for now?

* fix lint

* fix 2/3 compat

* fix py2 again

* fine, i gave up
parent d9dc65f8
......@@ -4,6 +4,7 @@ semantic support."""
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from .. import target as _tgt
from ..container import Array
from .. import ir_pass
from ..stmt import For
......@@ -123,7 +124,7 @@ def ceil_div(func_id, args):
_internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
_internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
a, b = args[0], args[1]
return (a + b - 1) / b
return (a + b - 1) // b
def likely(func_id, args):
......@@ -131,3 +132,14 @@ def likely(func_id, args):
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'likely', *args)
def max_num_threads(func_id, args):
_internal_assert(func_id == "max_num_threads", "This function cannot be directly invoked!")
_internal_assert(args.__len__() <= 1, "At most one argument accepted!")
if args.__len__() == 0:
res = _tgt.current_target().max_num_threads
else:
_internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint")
res = _tgt.current_target(args[0].value).max_num_threads
return _api.convert(res)
......@@ -219,6 +219,8 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node):
name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']:
return _api.convert(eval(name)) #pylint: disable=eval-used
ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
......@@ -248,6 +250,10 @@ class HybridParser(ast.NodeVisitor):
return _api.const(node.n, dtype)
def visit_NameConstant(self, node):
return _api.convert(node.value)
def visit_AugAssign(self, node):
buf = self.visit(node.target)
rhs = self.visit(node.value)
......@@ -450,17 +456,18 @@ class HybridParser(ast.NodeVisitor):
func_id = node.func.id
args = [self.visit(i) for i in node.args]
try:
# Intrinsics'
if hasattr(calls, func_id):
return getattr(calls, func_id)(func_id, args)
except AttributeError:
_internal_assert(func_id in self.symbols.keys(), \
"The function called is not in the context either!")
ty, entry = self.symbols[func_id]
_internal_assert(ty is Symbol.Callable, \
"Are you sure what you call is a function?!")
outs = entry(*args)
op = outs.op if isinstance(outs, Tensor) else outs[0].op
return op
# Contexts'
_internal_assert(func_id in self.symbols.keys(), \
"The function called (%s) is not in the context either!" % func_id)
ty, entry = self.symbols[func_id]
_internal_assert(ty is Symbol.Callable, \
"Are you sure what you call is a function?!")
outs = entry(*args)
op = outs.op if isinstance(outs, Tensor) else outs[0].op
return op
def visit_For(self, node):
......
......@@ -59,6 +59,9 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_Name(self, node):
# If it is True or False, we do not worry about it!
if sys.version_info[0] == 2 and node.id in ['True', 'False']:
return
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
return
......
"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""
import numpy
from .. import target
class bind(object): #pylint: disable=invalid-name
......@@ -72,34 +73,40 @@ def sigmoid(x):
return 1 / (1 + numpy.exp(-x))
def max_num_threads(allow_none=True):
"""Get max number of threads for GPU targets."""
return target.current_target(allow_none).max_num_threads
HYBRID_GLOBALS = {
'unroll' : range,
'vectorize' : range,
'parallel' : range,
'const_range' : range,
'bind' : bind,
'allocate' : allocate,
'output_tensor': allocate,
'sqrt' : numpy.sqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
'uint32' : numpy.uint32,
'uint64' : numpy.uint64,
'int8' : numpy.int8,
'int16' : numpy.int16,
'int32' : numpy.int32,
'int64' : numpy.int64,
'float16' : numpy.float16,
'float32' : numpy.float32,
'float64' : numpy.float64,
'ceil_div' : lambda a, b: (a + b - 1) / b
'unroll' : range,
'vectorize' : range,
'parallel' : range,
'const_range' : range,
'bind' : bind,
'allocate' : allocate,
'output_tensor' : allocate,
'sqrt' : numpy.sqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount,
'likely' : lambda cond: cond,
'uint8' : numpy.uint8,
'uint16' : numpy.uint16,
'uint32' : numpy.uint32,
'uint64' : numpy.uint64,
'int8' : numpy.int8,
'int16' : numpy.int16,
'int32' : numpy.int32,
'int64' : numpy.int64,
'float16' : numpy.float16,
'float32' : numpy.float32,
'float64' : numpy.float64,
'ceil_div' : lambda a, b: (a + b - 1) // b,
'max_num_threads': max_num_threads
}
......
......@@ -400,6 +400,8 @@ void CodeGenHybrid::ReserveKeywords() {
GetUniqueName("for");
GetUniqueName("in");
GetUniqueName("range");
GetUniqueName("True");
GetUniqueName("False");
GetUniqueName("unroll");
GetUniqueName("const_range");
GetUniqueName("parallel");
......@@ -434,6 +436,7 @@ void CodeGenHybrid::ReserveKeywords() {
GetUniqueName("float32");
GetUniqueName("float64");
GetUniqueName("ceil_div");
GetUniqueName("max_num_threads");
}
void CodeGenHybrid::DumpStmt(const Stmt &stmt,
......
......@@ -350,6 +350,22 @@ def test_bind():
func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
@tvm.hybrid.script
def max_threads(a):
b = output_tensor(a.shape, a.dtype)
n = a.shape[0]
m = max_num_threads(True)
for i in bind('threadIdx.x', m):
for j in bind('blockIdx.x', ceil_div(n, m)):
if i * m + j < n:
b[i * m + j] = a[i * m + j] + a[i * m + j]
return b
a = tvm.placeholder((10000, ), 'float32')
with tvm.target.create('cuda'):
func, ins, outs = run_and_check(max_threads, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
def test_math_intrin():
@script
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment