make_api.cc 6.86 KB
Newer Older
1 2 3 4
/*!
 *  Copyright (c) 2017 by Contributors
 * \file make_api.cc Build API function.
 */
5
#include <tvm/ir_pass.h>
6
#include <tvm/ir.h>
7
#include <tvm/ir_visitor.h>
8
#include <tvm/ir_mutator.h>
9
#include <tvm/buffer.h>
10
#include <tvm/runtime/device_api.h>
11 12 13 14
#include <vector>
#include <utility>
#include <unordered_set>

15
#include "./ir_util.h"
16
#include "./arg_binder.h"
17
#include "../arithmetic/compute_expr.h"
18 19

namespace tvm {
20
namespace ir {
21 22

/*!
 * \brief Build a runtime assertion that `lhs == rhs`.
 * \param lhs Left-hand side of the equality test.
 * \param rhs Right-hand side of the equality test.
 * \param msg Message attached to the AssertStmt.
 * \return An AssertStmt on `lhs == rhs` whose body is the no-op Evaluate(0).
 */
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
  Stmt no_op = Evaluate::make(0);
  Expr is_equal = (lhs == rhs);
  return AssertStmt::make(is_equal, msg, no_op);
}

/*!
 * \brief Wrap a statement into a LoweredFunc that follows the API convention.
 *
 * The first `api_args.size() - num_unpacked_args` entries of `api_args` are
 * received through the packed fields (`args`, `arg_type_ids`, `num_args`).
 * The remaining entries become ordinary function arguments.
 *
 * For every packed argument the generated prologue:
 *  - loads its value with tvm_struct_get;
 *  - loads its type code from `arg_type_ids`;
 *  - asserts that the type code matches the declared type.
 *
 * Every api_arg is then bound:
 *  - a Var via ArgBinder::Bind;
 *  - a Buffer via ArgBinder::BindDLTensor, which is also handed the
 *    `dev_type` / `dev_id` variables.
 *
 * \param body The statement to wrap.
 * \param name Name of the generated function; also used in assertion messages.
 * \param api_args The API arguments; each must be a Var or a Buffer.
 * \param num_unpacked_args Number of trailing arguments passed unpacked.
 *        Must be in [0, api_args.size()].
 * \param is_restricted Stored in LoweredFuncNode::is_restricted.
 * \return The constructed LoweredFunc. Aborts via LOG(FATAL) if the generated
 *         body uses a Var that is not defined by the function arguments.
 */
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_unpacked_args,
                    bool is_restricted) {
  const Stmt nop = Evaluate::make(0);
  int num_args = static_cast<int>(api_args.size());
  // A negative count would make num_packed_args exceed num_args. The generated
  // "num_args" assertion could then never hold, so reject it up front.
  CHECK_GE(num_unpacked_args, 0);
  CHECK_LE(num_unpacked_args, num_args);
  int num_packed_args = num_args - num_unpacked_args;
  // Data field definitions
  // The packed fields
  Var v_packed_args("args", Handle());
  Var v_packed_arg_type_ids("arg_type_ids", Handle());
  Var v_num_packed_args("num_args", Int(32));
  // The arguments of the function.
  Array<Var> args;
  // The device context
  Var device_type("dev_type");
  Var device_id("dev_id");
  // seq_init gives the sequence of initializations.
  // seq_check gives the sequence of later checks, run after init.
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const Variable*, Expr> vmap;
  ArgBinder binder(&vmap);
  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_arg_value = [&](Type t, int i) {
    Array<Expr> call_args{v_packed_args,
                          IntImm::make(Int(32), i),
                          IntImm::make(Int(32), intrinsic::kTVMValueContent)};
    // load 64 bit version
    Type api_type = APIType(t);
    Expr res = Call::make(
        api_type, intrinsic::tvm_struct_get, call_args,
        Call::PureIntrinsic);
    // cast to the target version.
    if (api_type != t) {
      res = Cast::make(t, res);
    }
    return res;
  };
  // get declaration of argument i: a fresh Var named "arg<i>", typed like the
  // api_arg when it is a Var, otherwise a Handle (Buffer arguments).
  auto f_arg_decl = [&](int i) {
    std::ostringstream os;
    os << "arg" << i;
    const Variable* v = api_args[i].as<Variable>();
    return Var(os.str(), v ? v->type : Handle());
  };
  // ---------------------------
  // start of logic
  // add signature for packed arguments.
  if (num_packed_args != 0) {
    args.push_back(v_packed_args);
    args.push_back(v_packed_arg_type_ids);
    args.push_back(v_num_packed_args);
    std::ostringstream os;
    os << name << ": num_args should be " << num_packed_args;
    seq_init.emplace_back(
        MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
  }
  for (int i = 0; i < num_args; ++i) {
    Var v_arg = f_arg_decl(i);
    if (i < num_packed_args) {
      // Value loads
      seq_init.emplace_back(LetStmt::make(
          v_arg, f_arg_value(v_arg.type(), i), nop));
      // type code checks: read arg_type_ids[i] into "<arg>.code"
      Var tcode(v_arg->name_hint + ".code", Int(32));
      seq_init.emplace_back(LetStmt::make(
          tcode, Load::make(
              Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)),
          nop));
      Type t = v_arg.type();
      if (t.is_handle()) {
        // Pointer arguments accept handle, array-handle or null type codes.
        std::ostringstream msg;
        msg << name << ": Expect arg[" << i << "] to be pointer";
        seq_check.emplace_back(
            AssertStmt::make(tcode == kHandle ||
                             tcode == kArrayHandle ||
                             tcode == kNull, msg.str(), nop));
      } else if (t.is_int() || t.is_uint()) {
        // Both signed and unsigned ints are checked against kDLInt.
        std::ostringstream msg;
        msg << name << ": Expect arg[" << i << "] to be int";
        seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop));
      } else {
        CHECK(t.is_float());
        std::ostringstream msg;
        msg << name << ": Expect arg[" << i << "] to be float";
        seq_check.emplace_back(
            AssertStmt::make(tcode == kDLFloat, msg.str(), nop));
      }
    } else {
      args.push_back(v_arg);
    }
    // add checks for functions.
    if (api_args[i].as<Variable>()) {
      binder.Bind(Var(api_args[i].node_), v_arg, v_arg->name_hint, true);
    } else {
      // Buffer checks
      CHECK(api_args[i].as<BufferNode>())
          << "api_args can only be Buffer or Var";
      Buffer buf(api_args[i].node_);
      binder.BindDLTensor(
          buf, device_type, device_id, v_arg, v_arg->name_hint);
    }
  }

  std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
  n->name = name;
  n->args = args;
  n->handle_data_type = binder.def_handle_dtype();
  n->is_packed_func = num_unpacked_args == 0;
  n->is_restricted = is_restricted;
  body = AttrStmt::make(
      make_zero(Int(32)), attr::compute_scope,
      StringImm::make(name + "_compute_"), body);
  // Set device context.
  // This only happens when the binder recorded a definition for dev_id in vmap.
  // Presumably that comes from BindDLTensor on a Buffer argument.
  if (vmap.count(device_id.get())) {
    Expr node = StringImm::make("default");
    CHECK(vmap.count(device_type.get()));
    seq_check.push_back(AttrStmt::make(
        node, attr::device_context_id, device_id, nop));
    seq_check.push_back(AttrStmt::make(
        node, attr::device_context_type, device_type, nop));
    // Call the packed runtime function tvm_set_device unless the device
    // type is kDLCPU.
    Stmt set_device = IfThenElse::make(
        device_type != kDLCPU, Evaluate::make(Call::make(
            Int(32), intrinsic::tvm_call_packed,
            {StringImm::make(runtime::symbol::tvm_set_device),
             device_type, device_id}, Call::Intrinsic)));
    body = Block::make(set_device, body);
  }
  n->body = MergeNest(
      {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  LoweredFunc f(n);
  Array<Var> undefined = UndefinedVars(f->body, f->args);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " does not appear in api_args";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }
  return f;
}

/*!
 * \brief Specialize the device-type variable of an API function to a constant.
 *
 * The mutator looks for an AttrStmt with key attr::device_context_type whose
 * value is a Variable. When it finds one, it does three things:
 *  - replaces every occurrence of that Variable inside the AttrStmt with the
 *    constant `device_type`;
 *  - guards the rewritten statement with an AssertStmt that checks the
 *    Variable equals the constant;
 *  - returns the rewritten statement without visiting its children.
 *
 * All other statements are handled by the default IRMutator.
 */
class DeviceTypeBinder final : public IRMutator {
 public:
  explicit DeviceTypeBinder(int device_type)
      : device_type_(device_type) {}

  Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
    if (op->attr_key == attr::device_context_type) {
      if (const Variable* var = op->value.as<Variable>()) {
        // Constant with the same type as the variable it replaces.
        Expr value = make_const(op->value.type(), device_type_);
        std::unordered_map<const Variable*, Expr> dmap;
        dmap[var] = value;
        Stmt body = Substitute(s, dmap);
        std::ostringstream os;
        os << "device_type need to be " << device_type_;
        // The condition still uses the original variable. The runtime value
        // is therefore checked against the constant baked into `body`.
        return AssertStmt::make(op->value == value, os.str(), body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  /*! \brief Device type constant substituted for the device-type variable. */
  const int device_type_;
};

/*!
 * \brief Return a copy of `f` whose device-type variable is fixed to
 *        `device_type`. The rewrite is performed by DeviceTypeBinder.
 * \param f The function to specialize; its node is copied, not modified.
 * \param device_type The device type constant to bind.
 * \return A new LoweredFunc with the rewritten body.
 */
LoweredFunc BindDeviceType(LoweredFunc f,
                           int device_type) {
  const LoweredFuncNode* source = f.operator->();
  std::shared_ptr<LoweredFuncNode> copy =
      std::make_shared<LoweredFuncNode>(*source);
  DeviceTypeBinder binder(device_type);
  copy->body = binder.Mutate(copy->body);
  return LoweredFunc(copy);
}

204
}  // namespace ir
205
}  // namespace tvm