#pylint: disable=unused-argument """The base node types for the Relay language.""" import topi from ..._ffi.function import _init_api from ..base import register_relay_node from ..expr import Expr from ...api import register_func from ...build_module import lower, build from . import _make @register_relay_node class Op(Expr): """A Relay operator definition.""" def __init__(self): raise RuntimeError("Cannot create op, use get instead") def get_attr(self, attr_name): """Get additional attribute about the operator. Parameters ---------- attr_name : str The attribute name. Returns ------- value : object The attribute value """ return _OpGetAttr(self, attr_name) def get(op_name): """Get the Op for a given name Parameters ---------- op_name : str The operator name Returns ------- op : Op The op of the corresponding name """ return _GetOp(op_name) def register(op_name, attr_key, value=None, level=10): """Register an operator property of an operator. Parameters ---------- op_name : str The name of operator attr_key : str The attribute name. value : object, optional The value to set level : int, optional The priority level Returns ------- fregister : function Register function if value is not specified. """ def _register(v): """internal register function""" _Register(op_name, attr_key, v, level) return v return _register(value) if value is not None else _register class OpPattern(object): """Operator generic patterns See Also -------- top.tag : Contains explanation of the tag type. """ # Elementwise operator ELEMWISE = 0 # Broadcast operator BROADCAST = 1 # Injective mapping INJECTIVE = 2 # Communication COMM_REDUCE = 3 # Complex op, can still fuse ewise into it OUT_ELEMWISE_FUSABLE = 4 # Not fusable opaque op OPAQUE = 8 def register_schedule(op_name, schedule=None, level=10): """Register schedule function for an op Parameters ---------- op_name : str The name of the op. schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule The schedule function. level : int The priority level """ return register(op_name, "FTVMSchedule", schedule, level) def register_compute(op_name, compute=None, level=10): """Register compute function for an op. Parameters ---------- op_name : str The name of the op. compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target) -> List[Tensor] The compute function. level : int The priority level """ return register(op_name, "FTVMCompute", compute, level) def register_alter_op_layout(op_name, alter_layout=None, level=10): """Register alter op layout function for an op Parameters ---------- op_name : str The name of the operator alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr The function for changing the layout or replacing the operator level : int The priority level """ return register(op_name, "FTVMAlterOpLayout", alter_layout, level) def register_pattern(op_name, pattern, level=10): """Register operator pattern for an op. Parameters ---------- op_name : str The name of the op. pattern : int The pattern being used. level : int The priority level """ return register(op_name, "TOpPattern", pattern, level) _init_api("relay.op", __name__) @register_func("relay.op.compiler._lower") def _lower(name, schedule, inputs, outputs): return lower(schedule, list(inputs) + list(outputs), name=name) @register_func("relay.op.compiler._build") def _build(lowered_funcs): return build(lowered_funcs, target="llvm") def schedule_injective(attrs, outputs, target): """Generic schedule for binary broadcast.""" with target: return topi.generic.schedule_injective(outputs) __DEBUG_COUNTER__ = 0 def debug(expr, debug_func=None): """The main entry point to the debugger.""" global __DEBUG_COUNTER__ if debug_func: name = "debugger_func{}".format(__DEBUG_COUNTER__) register_func(name, debug_func) __DEBUG_COUNTER__ += 1 else: name = '' return _make.debug(expr, name)