"""The build pipeline in python. Eventually some of these pipelines will be moved to C++. But the first pipeline will be kept in python for ease of change and evolving. """ # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments from . import api from . import tensor from . import schedule from . import expr from . import ir_pass from . import codegen def build(sch, args, target, target_host="stackvm", name="default_function", binds=None, max_auto_unroll_step=8): """Build a function with arguments as signiture. Parameters ---------- sch : tvm.Schedule The schedule to be builded args : list of Buffer or Tensor or Var The argument lists to the function. target : str The target of the compilation. target_host : Host compilation target, if target is device. name : str The name of result function. binds : dict, optional Dictionary that maps the binding of symbolic buffer to Tensor. By default, a new buffer is created for each tensor in the argument. max_auto_unroll_step: int Maximum step to perform automatic unrolling Returns ------- f : Function, or pair of functions The result function. If the function requires host space allocation, a pair of functions will be returned. """ binds = {} if binds is None else binds.copy() arg_list = [] for x in args: if isinstance(x, tensor.Tensor): buf = api.Buffer(x.shape, dtype=x.dtype, name=x.op.name) assert x not in binds binds[x] = buf arg_list.append(buf) elif isinstance(x, schedule.Buffer): arg_list.append(x) elif isinstance(x, expr.Var): arg_list.append(x) else: raise ValueError("args must be Tensor, Buffer or Var") # normalize schedule first sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.CanonicalSimplify(stmt) stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) stmt = ir_pass.LiftAllocate(stmt) stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step) stmt = ir_pass.Simplify(stmt) fapi = ir_pass.MakeAPI(stmt, name, arg_list, 0) fsplits = ir_pass.SplitHostDevice(fapi) fsplits = [x for x in fsplits] for i in range(1, len(fsplits)): fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared") if len(fsplits) > 1: mhost = codegen.build(fsplits[0], target_host) if target: mdev = codegen.build(fsplits[1:], target) mhost.import_module(mdev) return mhost else: return codegen.build(fsplits[0], target)