make_api.cc 6.82 KB
Newer Older
1 2 3 4
/*!
 *  Copyright (c) 2017 by Contributors
 * \file make_api.cc Build API function.
 */
5
#include <tvm/ir_pass.h>
6
#include <tvm/ir.h>
7
#include <tvm/ir_visitor.h>
8
#include <tvm/ir_mutator.h>
9
#include <tvm/buffer.h>
10
#include <tvm/runtime/device_api.h>
11 12 13 14
#include <vector>
#include <utility>
#include <unordered_set>

15
#include "./ir_util.h"
16
#include "./arg_binder.h"
17
#include "../arithmetic/compute_expr.h"
18 19

namespace tvm {
20
namespace ir {
21 22

/*!
 * \brief Create an assertion statement checking that two expressions are equal.
 * \param lhs Left-hand side of the equality.
 * \param rhs Right-hand side of the equality.
 * \param msg Message attached to the assertion.
 * \return An AssertStmt on (lhs == rhs) whose body is Evaluate(0).
 */
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
  Expr is_equal = (lhs == rhs);
  // The assertion carries a placeholder body.
  Stmt empty_body = Evaluate::make(0);
  return AssertStmt::make(is_equal, msg, empty_body);
}

/*!
 * \brief Wrap a statement into a LoweredFunc, generating argument loading,
 *        type-code checks and argument/buffer binding around it.
 *
 *  The first (api_args.size() - num_unpacked_args) entries of api_args are
 *  read from the packed triple (args, arg_type_ids, num_args); the remaining
 *  entries become direct parameters of the generated function.
 *
 * \param body The statement to be wrapped.
 * \param name The name of the generated function.
 * \param api_args The declared arguments; each entry must be a Var or a Buffer.
 * \param num_unpacked_args Number of trailing arguments passed directly
 *        instead of through the packed triple.
 *        Must be within [0, api_args.size()].
 * \param is_restricted Stored as-is into the is_restricted field of the result.
 * \return The constructed LoweredFunc. is_packed_func is set when
 *         num_unpacked_args == 0.
 *
 * \note Aborts via CHECK / LOG(FATAL) when num_unpacked_args is out of range,
 *       when a packed argument has a type other than handle/int/uint/float,
 *       when an api_args entry is neither Var nor Buffer, or when the final
 *       body still uses variables that are not defined by the arguments.
 */
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_unpacked_args,
                    bool is_restricted) {
  const Stmt nop = Evaluate::make(0);
  int num_args = static_cast<int>(api_args.size());
  // A negative count would make num_packed_args exceed num_args and emit an
  // argument-count assertion that can never be satisfied.
  CHECK_GE(num_unpacked_args, 0)
      << "num_unpacked_args cannot be negative";
  CHECK_LE(num_unpacked_args, num_args);
  int num_packed_args = num_args - num_unpacked_args;
  // Data field definitions
  // The packed fields
  Var v_packed_args("args", Handle());
  Var v_packed_arg_type_ids("arg_type_ids", Handle());
  Var v_num_packed_args("num_args", Int(32));
  // The arguments of the function.
  Array<Var> args;
  // The device context
  Var device_type("dev_type"), device_id("dev_id");
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  std::unordered_map<const Variable*, Expr> vmap;
  ArgBinder binder(&vmap);
  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_arg_value = [&](Type t, int i) {
    Array<Expr> call_args{v_packed_args,
                          IntImm::make(Int(32), i),
                          IntImm::make(Int(32), intrinsic::kTVMValueContent)};
    // load 64 bit version
    Type api_type = APIType(t);
    Expr res = Call::make(
        api_type, intrinsic::tvm_struct_get, call_args,
        Call::PureIntrinsic);
    // cast to the target version.
    if (api_type != t) {
      res = Cast::make(t, res);
    }
    return res;
  };
  // get declaration of argument i
  // (Var arguments keep their declared type, everything else is a handle)
  auto f_arg_decl = [&](int i) {
    std::ostringstream os;
    os << "arg" << i;
    const Variable* v = api_args[i].as<Variable>();
    return Var(os.str(), v ? v->type: Handle());
  };
  // ---------------------------
  // start of logic
  // add signature for packed arguments.
  if (num_packed_args != 0) {
    args.push_back(v_packed_args);
    args.push_back(v_packed_arg_type_ids);
    args.push_back(v_num_packed_args);
    std::ostringstream os;
    os << "expected num_args to be " << num_packed_args;
    seq_init.emplace_back(
        MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
  }
  for (int i = 0; i < num_args; ++i) {
    Var v_arg = f_arg_decl(i);
    if (i < num_packed_args) {
      // Value loads
      seq_init.emplace_back(LetStmt::make(
          v_arg, f_arg_value(v_arg.type(), i), nop));
      // type code checks
      Var tcode(v_arg->name_hint + ".code", Int(32));
      seq_init.emplace_back(LetStmt::make(
          tcode, Load::make(
              Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)),
          nop));
      Type t = v_arg.type();
      if (t.is_handle()) {
        std::ostringstream msg;
        msg << "Expect argument " << i << " to be pointer";
        seq_check.emplace_back(
            AssertStmt::make(tcode == kHandle ||
                             tcode == kArrayHandle ||
                             tcode == kNull, msg.str(), nop));
      } else if (t.is_int() || t.is_uint()) {
        std::ostringstream msg;
        msg << "Expect argument " << i << " to be int";
        seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str(), nop));
      } else {
        CHECK(t.is_float())
            << "Packed argument " << i
            << " must have handle, int, uint or float type";
        std::ostringstream msg;
        msg << "Expect argument " << i << " to be float";
        seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str(), nop));
      }
    } else {
      // unpacked arguments are direct parameters of the function
      args.push_back(v_arg);
    }
    // add checks for functions.
    if (api_args[i].as<Variable>()) {
      binder.Bind(Var(api_args[i].node_), v_arg, v_arg->name_hint, true);
    } else {
      // Buffer checks
      CHECK(api_args[i].as<BufferNode>())
          << "api_args can only be Buffer or Var";
      Buffer buf(api_args[i].node_);
      binder.BindDLTensor(
          buf, device_type, device_id, v_arg, v_arg->name_hint);
    }
  }

  std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
  n->name = name;
  n->args = args;
  n->handle_data_type = binder.def_handle_dtype();
  n->is_packed_func = num_unpacked_args == 0;
  n->is_restricted = is_restricted;
  body = AttrStmt::make(
      make_zero(Int(32)), attr::compute_scope,
      StringImm::make(name + "_compute_"), body);
  // Set device context
  // (only when the binder recorded a definition for device_id)
  if (vmap.count(device_id.get())) {
    Expr node = StringImm::make("default");
    CHECK(vmap.count(device_type.get()));
    seq_check.push_back(AttrStmt::make(
        node, attr::device_context_id, device_id, nop));
    seq_check.push_back(AttrStmt::make(
        node, attr::device_context_type, device_type, nop));
    // the set-device call is skipped when the device type is kCPU
    Stmt set_device = IfThenElse::make(
        device_type != kCPU, Evaluate::make(Call::make(
            Int(32), intrinsic::tvm_call_packed,
            {StringImm::make(runtime::symbol::tvm_set_device),
             device_type, device_id}, Call::Intrinsic)));
    body = Block::make(set_device, body);
  }
  n->body = MergeNest(
      {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
  LoweredFunc f(n);
  // every variable used by the body has to be defined by the arguments
  Array<Var> undefined = UndefinedVars(f->body, f->args);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " does not appear in api_args";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }
  return f;
}
170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201

/*!
 * \brief Mutator that replaces the device type variable by a constant.
 *
 *  For every AttrStmt with key attr::device_context_type whose value is a
 *  Variable, the variable is substituted by the constant device type in the
 *  whole attribute statement, and the result is wrapped in an assertion that
 *  the original value equals that constant. The substituted statement is
 *  returned as-is and is not visited again by this mutator.
 */
class DeviceTypeBinder: public IRMutator {
 public:
  explicit DeviceTypeBinder(int device_type)
      : device_type_(device_type) {}

  Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
    // Only a device_context_type attribute holding a variable is rewritten.
    const Variable* dev_var =
        (op->attr_key == attr::device_context_type)
        ? op->value.as<Variable>() : nullptr;
    if (dev_var == nullptr) {
      return IRMutator::Mutate_(op, s);
    }
    Expr bound_type = make_const(op->value.type(), device_type_);
    std::unordered_map<const Variable*, Expr> replace_map{
      {dev_var, bound_type}};
    // substitution is applied to the attribute statement itself
    Stmt rewritten = Substitute(s, replace_map);
    std::ostringstream msg;
    msg << "device_type need to be " << device_type_;
    return AssertStmt::make(op->value == bound_type, msg.str(), rewritten);
  }

  /*! \brief The device type every matching variable is bound to. */
  int device_type_;
};

/*!
 * \brief Bind the device type of a lowered function to a constant.
 * \param f The function to be processed.
 * \param device_type The device type to bind to.
 * \return A new LoweredFunc whose body was rewritten by DeviceTypeBinder.
 */
LoweredFunc BindDeviceType(LoweredFunc f,
                           int device_type) {
  // Work on a copy of the node; the body of f itself is not reassigned.
  std::shared_ptr<LoweredFuncNode> bound =
      std::make_shared<LoweredFuncNode>(*f.operator->());
  DeviceTypeBinder type_binder(device_type);
  bound->body = type_binder.Mutate(bound->body);
  return LoweredFunc(bound);
}

202
}  // namespace ir
203
}  // namespace tvm