Commit e35666ae by Jian Weng Committed by Leyuan Wang

[Hybrid Script] Add `max_num_threads` (#2672)

* i think it works for now?

* fix lint

* fix 2/3 compat

* fix py2 again

* fine, i gave up
parent d9dc65f8
...@@ -4,6 +4,7 @@ semantic support.""" ...@@ -4,6 +4,7 @@ semantic support."""
from .. import api as _api from .. import api as _api
from .. import expr as _expr from .. import expr as _expr
from .. import make as _make from .. import make as _make
from .. import target as _tgt
from ..container import Array from ..container import Array
from .. import ir_pass from .. import ir_pass
from ..stmt import For from ..stmt import For
...@@ -123,7 +124,7 @@ def ceil_div(func_id, args): ...@@ -123,7 +124,7 @@ def ceil_div(func_id, args):
_internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div") _internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
_internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div") _internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
a, b = args[0], args[1] a, b = args[0], args[1]
return (a + b - 1) / b return (a + b - 1) // b
def likely(func_id, args): def likely(func_id, args):
...@@ -131,3 +132,14 @@ def likely(func_id, args): ...@@ -131,3 +132,14 @@ def likely(func_id, args):
"Only one expression can be likely") "Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!") _internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'likely', *args) return call_pure_intrin(args[0].dtype, 'likely', *args)
def max_num_threads(func_id, args):
    """Intrinsic handler for ``max_num_threads`` in hybrid script.

    Returns the ``max_num_threads`` attribute of the current TVM target as a
    converted TVM expression.

    Parameters
    ----------
    func_id : str
        Must be exactly "max_num_threads"; this handler is dispatched by name
        and cannot be invoked for any other intrinsic.
    args : list
        Either empty (use the current target, which must exist) or a single
        ``_expr.UIntImm`` boolean-like value forwarded to
        ``current_target(allow_none)``.

    Returns
    -------
    res : Expr
        The target's max thread count, converted via ``_api.convert``.
    """
    _internal_assert(func_id == "max_num_threads",
                     "This function cannot be directly invoked!")
    # Idiomatic len() instead of calling the dunder __len__ directly.
    _internal_assert(len(args) <= 1, "At most one argument accepted!")
    if not args:
        res = _tgt.current_target().max_num_threads
    else:
        # Hybrid script models Python bools as unsigned-int immediates.
        _internal_assert(isinstance(args[0], _expr.UIntImm),
                         "In tvm bool should be uint")
        res = _tgt.current_target(args[0].value).max_num_threads
    return _api.convert(res)
...@@ -219,6 +219,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -219,6 +219,8 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
name = node.id name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']:
return _api.convert(eval(name)) #pylint: disable=eval-used
ty, entry = self.symbols[name] ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name) _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]: if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
...@@ -248,6 +250,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -248,6 +250,10 @@ class HybridParser(ast.NodeVisitor):
return _api.const(node.n, dtype) return _api.const(node.n, dtype)
def visit_NameConstant(self, node):
    """Handle Python 3 ``NameConstant`` nodes (True/False/None) by
    converting the literal value into a TVM node."""
    const_value = node.value
    return _api.convert(const_value)
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
buf = self.visit(node.target) buf = self.visit(node.target)
rhs = self.visit(node.value) rhs = self.visit(node.value)
...@@ -450,17 +456,18 @@ class HybridParser(ast.NodeVisitor): ...@@ -450,17 +456,18 @@ class HybridParser(ast.NodeVisitor):
func_id = node.func.id func_id = node.func.id
args = [self.visit(i) for i in node.args] args = [self.visit(i) for i in node.args]
try: # Intrinsics'
if hasattr(calls, func_id):
return getattr(calls, func_id)(func_id, args) return getattr(calls, func_id)(func_id, args)
except AttributeError: # Contexts'
_internal_assert(func_id in self.symbols.keys(), \ _internal_assert(func_id in self.symbols.keys(), \
"The function called is not in the context either!") "The function called (%s) is not in the context either!" % func_id)
ty, entry = self.symbols[func_id] ty, entry = self.symbols[func_id]
_internal_assert(ty is Symbol.Callable, \ _internal_assert(ty is Symbol.Callable, \
"Are you sure what you call is a function?!") "Are you sure what you call is a function?!")
outs = entry(*args) outs = entry(*args)
op = outs.op if isinstance(outs, Tensor) else outs[0].op op = outs.op if isinstance(outs, Tensor) else outs[0].op
return op return op
def visit_For(self, node): def visit_For(self, node):
......
...@@ -59,6 +59,9 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -59,6 +59,9 @@ class PyVariableUsage(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
# If it is True or False, we do not worry about it!
if sys.version_info[0] == 2 and node.id in ['True', 'False']:
return
# If it is from the argument list or loop variable, we do not worry about it! # If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys(): if node.id in self._args.keys():
return return
......
"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime""" """Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""
import numpy import numpy
from .. import target
class bind(object): #pylint: disable=invalid-name class bind(object): #pylint: disable=invalid-name
...@@ -72,34 +73,40 @@ def sigmoid(x): ...@@ -72,34 +73,40 @@ def sigmoid(x):
return 1 / (1 + numpy.exp(-x)) return 1 / (1 + numpy.exp(-x))
def max_num_threads(allow_none=True):
    """Emulation-runtime counterpart of the ``max_num_threads`` intrinsic:
    look up the max thread count on the current target."""
    current = target.current_target(allow_none)
    return current.max_num_threads
HYBRID_GLOBALS = { HYBRID_GLOBALS = {
'unroll' : range, 'unroll' : range,
'vectorize' : range, 'vectorize' : range,
'parallel' : range, 'parallel' : range,
'const_range' : range, 'const_range' : range,
'bind' : bind, 'bind' : bind,
'allocate' : allocate, 'allocate' : allocate,
'output_tensor': allocate, 'output_tensor' : allocate,
'sqrt' : numpy.sqrt, 'sqrt' : numpy.sqrt,
'log' : numpy.log, 'log' : numpy.log,
'tanh' : numpy.tanh, 'tanh' : numpy.tanh,
'power' : numpy.power, 'power' : numpy.power,
'exp' : numpy.exp, 'exp' : numpy.exp,
'sigmoid' : sigmoid, 'sigmoid' : sigmoid,
'popcount' : popcount, 'popcount' : popcount,
'likely' : lambda cond: cond, 'likely' : lambda cond: cond,
'uint8' : numpy.uint8, 'uint8' : numpy.uint8,
'uint16' : numpy.uint16, 'uint16' : numpy.uint16,
'uint32' : numpy.uint32, 'uint32' : numpy.uint32,
'uint64' : numpy.uint64, 'uint64' : numpy.uint64,
'int8' : numpy.int8, 'int8' : numpy.int8,
'int16' : numpy.int16, 'int16' : numpy.int16,
'int32' : numpy.int32, 'int32' : numpy.int32,
'int64' : numpy.int64, 'int64' : numpy.int64,
'float16' : numpy.float16, 'float16' : numpy.float16,
'float32' : numpy.float32, 'float32' : numpy.float32,
'float64' : numpy.float64, 'float64' : numpy.float64,
'ceil_div' : lambda a, b: (a + b - 1) / b 'ceil_div' : lambda a, b: (a + b - 1) // b,
'max_num_threads': max_num_threads
} }
......
...@@ -400,6 +400,8 @@ void CodeGenHybrid::ReserveKeywords() { ...@@ -400,6 +400,8 @@ void CodeGenHybrid::ReserveKeywords() {
GetUniqueName("for"); GetUniqueName("for");
GetUniqueName("in"); GetUniqueName("in");
GetUniqueName("range"); GetUniqueName("range");
GetUniqueName("True");
GetUniqueName("False");
GetUniqueName("unroll"); GetUniqueName("unroll");
GetUniqueName("const_range"); GetUniqueName("const_range");
GetUniqueName("parallel"); GetUniqueName("parallel");
...@@ -434,6 +436,7 @@ void CodeGenHybrid::ReserveKeywords() { ...@@ -434,6 +436,7 @@ void CodeGenHybrid::ReserveKeywords() {
GetUniqueName("float32"); GetUniqueName("float32");
GetUniqueName("float64"); GetUniqueName("float64");
GetUniqueName("ceil_div"); GetUniqueName("ceil_div");
GetUniqueName("max_num_threads");
} }
void CodeGenHybrid::DumpStmt(const Stmt &stmt, void CodeGenHybrid::DumpStmt(const Stmt &stmt,
......
...@@ -350,6 +350,22 @@ def test_bind(): ...@@ -350,6 +350,22 @@ def test_bind():
func, ins, outs = run_and_check(foo, [a], target='cuda') func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda') run_and_check(func, ins, outs=outs, target='cuda')
@tvm.hybrid.script
def max_threads(a):
b = output_tensor(a.shape, a.dtype)
n = a.shape[0]
m = max_num_threads(True)
for i in bind('threadIdx.x', m):
for j in bind('blockIdx.x', ceil_div(n, m)):
if i * m + j < n:
b[i * m + j] = a[i * m + j] + a[i * m + j]
return b
a = tvm.placeholder((10000, ), 'float32')
with tvm.target.create('cuda'):
func, ins, outs = run_and_check(max_threads, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
def test_math_intrin(): def test_math_intrin():
@script @script
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment