/*!
 *  Copyright (c) 2017 by Contributors
 * \file make_api.cc Build API function.
 */
#include <tvm/ir_pass.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ir_util.h"
#include "arg_binder.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

/*!
 * \brief Build an assertion statement that checks lhs == rhs.
 * \param lhs Left-hand side of the comparison.
 * \param rhs Right-hand side of the comparison.
 * \param msg Message attached to the assertion.
 * \return An AssertStmt on (lhs == rhs) whose body is a no-op Evaluate(0).
 */
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
  return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
}

/*!
 * \brief Build the API function that wraps a lowered statement.
 *
 *  The first (api_args.size() - num_unpacked_args) arguments are unpacked
 *  from the packed calling convention ("args", "arg_type_ids", "num_args"),
 *  and a type-code check is emitted for each of them. The remaining
 *  num_unpacked_args arguments become direct parameters of the generated
 *  function. Every entry of api_args is then bound through ArgBinder
 *  (a Var via Bind, a Buffer via BindDLTensor).
 *
 * \param body The statement to wrap; it is placed inside a compute_scope
 *  attribute named "<name>_compute_".
 * \param name Name of the generated function; also prefixes the generated
 *  assertion messages.
 * \param api_args The arguments; each must be a Var or a Buffer.
 * \param num_unpacked_args Number of trailing arguments passed directly;
 *  must lie in [0, api_args.size()].
 * \param is_restricted Stored unchanged on the resulting function.
 * \return The generated LoweredFunc; is_packed_func is set iff
 *  num_unpacked_args == 0.
 * \note Aborts via LOG(FATAL) if the generated body references variables
 *  that are not bound by api_args.
 */
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_unpacked_args,
                    bool is_restricted) {
  const Stmt nop = Evaluate::make(0);
  const int num_args = static_cast<int>(api_args.size());
  // A negative count is never meaningful: it would make num_packed_args
  // exceed the number of api_args.
  CHECK_GE(num_unpacked_args, 0);
  CHECK_LE(num_unpacked_args, num_args);
  const int num_packed_args = num_args - num_unpacked_args;
  // Data field definitions
  // The packed fields
  Var v_packed_args("args", Handle());
  Var v_packed_arg_type_ids("arg_type_ids", Handle());
  Var v_num_packed_args("num_args", Int(32));
  // The arguments of the function.
  Array<Var> args;
  // The device context
  Var device_type("dev_type"), device_id("dev_id");
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const Variable*, Expr> vmap;
  ArgBinder binder(&vmap);
  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_arg_value = [&](Type t, int i) {
    Array<Expr> call_args{v_packed_args,
                          IntImm::make(Int(32), i),
                          IntImm::make(Int(32), intrinsic::kTVMValueContent)};
    // load 64 bit version
    Type api_type = APIType(t);
    Expr res = Call::make(
        api_type, intrinsic::tvm_struct_get, call_args,
        Call::PureIntrinsic);
    // cast to the target version.
    if (api_type != t) {
      res = Cast::make(t, res);
    }
    return res;
  };
  // get declaration of argument i
  auto f_arg_decl = [&](int i) {
    std::ostringstream os;
    os << "arg" << i;
    const Variable* v = api_args[i].as<Variable>();
    return Var(os.str(), v ? v->type: Handle());
  };
  // ---------------------------
  // start of logic
  // add signature for packed arguments.
  if (num_packed_args != 0) {
    args.push_back(v_packed_args);
    args.push_back(v_packed_arg_type_ids);
    args.push_back(v_num_packed_args);
    std::ostringstream os;
    os << name << ": num_args should be " << num_packed_args;
    seq_init.emplace_back(
        MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
  }
  for (int i = 0; i < num_args; ++i) {
    Var v_arg = f_arg_decl(i);
    if (i < num_packed_args) {
      // Value loads
      seq_init.emplace_back(LetStmt::make(
          v_arg, f_arg_value(v_arg.type(), i), nop));
      // type code checks: the code of argument i is read from arg_type_ids[i]
      Var tcode(v_arg->name_hint + ".code", Int(32));
      seq_init.emplace_back(LetStmt::make(
          tcode, Load::make(
              Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)),
          nop));
      Type t = v_arg.type();
      if (t.is_handle()) {
        std::ostringstream msg;
        msg << name << ": Expect arg[" << i << "] to be pointer";
        seq_check.emplace_back(
            AssertStmt::make(tcode == kHandle ||
                             tcode == kNDArrayContainer ||
                             tcode == kArrayHandle ||
                             tcode == kNull, msg.str(), nop));
      } else if (t.is_int() || t.is_uint()) {
        std::ostringstream msg;
        msg << name << ": Expect arg[" << i << "] to be int";
        seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop));
      } else {
        CHECK(t.is_float());
        std::ostringstream msg;
        msg << name << ": Expect arg[" << i << "] to be float";
        seq_check.emplace_back(
            AssertStmt::make(tcode == kDLFloat, msg.str(), nop));
      }
    } else {
      args.push_back(v_arg);
    }
    // add checks for functions.
    if (api_args[i].as<Variable>()) {
      binder.Bind(Var(api_args[i].node_), v_arg, v_arg->name_hint, true);
    } else {
      // Buffer checks
      CHECK(api_args[i].as<BufferNode>())
          << "api_args can only be Buffer or Var";
      Buffer buf(api_args[i].node_);
      binder.BindDLTensor(
          buf, device_type, device_id, v_arg, v_arg->name_hint);
    }
  }

  std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
  n->name = name;
  n->args = args;
  n->handle_data_type = binder.def_handle_dtype();
  n->is_packed_func = num_unpacked_args == 0;
  n->is_restricted = is_restricted;
  body = AttrStmt::make(
      make_zero(Int(32)), attr::compute_scope,
      StringImm::make(name + "_compute_"), body);
  // Set device context: only when binding the arguments introduced dev_id.
  if (vmap.count(device_id.get())) {
    Expr node = StringImm::make("default");
    CHECK(vmap.count(device_type.get()));
    seq_check.push_back(AttrStmt::make(
        node, attr::device_context_id, device_id, nop));
    seq_check.push_back(AttrStmt::make(
        node, attr::device_context_type, device_type, nop));
    Stmt set_device = IfThenElse::make(
        device_type != kDLCPU, Evaluate::make(Call::make(
            Int(32), intrinsic::tvm_call_packed,
            {StringImm::make(runtime::symbol::tvm_set_device),
             device_type, device_id}, Call::Intrinsic)));
    body = Block::make(set_device, body);
  }
  n->body = MergeNest(
      {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  LoweredFunc f(n);
  Array<Var> undefined = UndefinedVars(f->body, f->args);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " does not appeared in api_args";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }
  return f;
}

/*!
 * \brief Mutator that replaces the device-type variable with a constant.
 *
 *  When it reaches an AttrStmt keyed attr::device_context_type whose value
 *  is a Variable, it substitutes that variable with the constant
 *  device_type throughout the statement. It then wraps the result in an
 *  AssertStmt comparing the original variable with the constant. All other
 *  statements are handled by the default IRMutator traversal.
 */
class DeviceTypeBinder: public IRMutator {
 public:
  explicit DeviceTypeBinder(int device_type)
      : device_type_(device_type) {}

  Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
    if (op->attr_key == attr::device_context_type) {
      if (const Variable* var = op->value.as<Variable>()) {
        std::unordered_map<const Variable*, Expr> dmap;
        Expr value = make_const(op->value.type(), device_type_);
        dmap[var] = value;
        // Substitution runs on the whole statement, including this AttrStmt,
        // so the attribute value itself becomes the constant as well.
        Stmt body = Substitute(s, dmap);
        std::ostringstream os;
        os << "device_type need to be " << device_type_;
        // The assertion is built outside the substituted body and still
        // refers to the original variable.
        return AssertStmt::make(op->value == value, os.str(), body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  /*! \brief Device type constant substituted into the function. */
  int device_type_;
};

/*!
 * \brief Produce a copy of a lowered function with its device type bound.
 * \param f The function to specialize; its node is copied, not modified.
 * \param device_type The device type constant applied by DeviceTypeBinder.
 * \return A new LoweredFunc whose body has been rewritten by DeviceTypeBinder.
 */
LoweredFunc BindDeviceType(LoweredFunc f,
                           int device_type) {
  const LoweredFuncNode* source = f.operator->();
  std::shared_ptr<LoweredFuncNode> copy =
      std::make_shared<LoweredFuncNode>(*source);
  DeviceTypeBinder binder(device_type);
  copy->body = binder.Mutate(copy->body);
  return LoweredFunc(copy);
}

}  // namespace ir
}  // namespace tvm