arg_binder.cc 10.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/*!
 *  Copyright (c) 2017 by Contributors
 * \file arg_binder.cc
 * \brief Helper utility to match and bind arguments.
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/device_api.h>
#include "./ir_util.h"
#include "./arg_binder.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

/*!
 * \brief Record a runtime check that \p cond holds for argument \p arg_name.
 *
 * The condition is simplified first:
 *  - if it simplifies to constant zero (false), the binding can never be
 *    satisfied and a fatal error is raised immediately;
 *  - if it simplifies to constant one (true), nothing is recorded;
 *  - otherwise an AssertStmt on the simplified condition is appended to
 *    \p asserts.
 *
 * \param cond The condition that must hold.
 * \param arg_name Name of the argument, used in diagnostics.
 * \param asserts Output list receiving the generated assertion statements.
 */
void BinderAddAssert(Expr cond,
                     const std::string& arg_name,
                     std::vector<Stmt>* asserts) {
  Expr scond = Simplify(cond);
  if (is_zero(scond)) {
    LOG(FATAL) << "Bind have an unmet assertion: "
               << cond << ", on argument " << arg_name;
  }
  if (!is_one(scond)) {
    std::ostringstream os;
    // Include the constraint itself so a runtime failure says what was violated.
    os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond;
    asserts->emplace_back(AssertStmt::make(scond, os.str(), Evaluate::make(0)));
  }
}

bool ArgBinder::Bind_(const Expr& arg,
                      const Expr& value,
                      const std::string& arg_name,
                      bool with_lets) {
  CHECK_EQ(arg.type(), value.type());
  if (const Variable* v = arg.as<Variable>()) {
    auto it = def_map_->find(v);
    if (it == def_map_->end()) {
      Var v_arg(arg.node_);
      defs_.emplace_back(v_arg);
      if (with_lets) {
        (*def_map_)[v] = arg;
        init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
      } else {
        (*def_map_)[v] = value;
      }
      return true;
    } else {
      BinderAddAssert(it->second == value, arg_name, &asserts_);
    }
  } else {
    BinderAddAssert(arg == value, arg_name, &asserts_);
  }
  return false;
}

void ArgBinder::Bind(const Expr& arg,
                     const Expr& value,
                     const std::string& arg_name,
                     bool with_let) {
  Bind_(arg, value, arg_name, with_let);
}

/*!
 * \brief Element-wise bind two expression arrays of equal length.
 *
 * Each element pair is bound via Bind() under the name "<arg_name>[i]".
 * A length mismatch is a fatal check failure.
 */
void ArgBinder::BindArray(const Array<Expr>& arg,
                          const Array<Expr>& value,
                          const std::string& arg_name) {
  CHECK_EQ(arg.size(), value.size())
      << "Argument " << arg_name << " array size mismatch";
  const size_t count = arg.size();
  for (size_t idx = 0; idx != count; ++idx) {
    const std::string elem_name = arg_name + "[" + std::to_string(idx) + "]";
    this->Bind(arg[idx], value[idx], elem_name);
  }
}

void ArgBinder::BindBuffer(const Buffer& arg,
                           const Buffer& value,
78 79
                           const std::string& arg_name,
                           bool fuzzy_match) {
80 81 82
  CHECK_EQ(arg->scope, value->scope)
      << "Argument " << arg_name
      << " Buffer bind scope mismatch";
83 84 85 86 87 88 89 90 91
  CHECK_EQ(arg->dtype, value->dtype)
      << "Argument " << arg_name
      << " Buffer bind data type mismatch";
  if (value->data_alignment % arg->data_alignment != 0) {
    LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
                 << " required_alignment=" << arg->data_alignment
                 << ", provided_alignment=" << value->data_alignment;
  }
  // bind pointer and offset.
92 93 94 95 96 97 98 99 100 101 102 103
  if (is_zero(arg->elem_offset)) {
    CHECK(is_zero(value->elem_offset))
        << "Trying to bind a Buffer with offset into one without offset";
  }

  this->Bind(arg->data, value->data, arg_name + ".data");
  if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
    if (arg->offset_factor > 1) {
      Expr offset = value->elem_offset;
      Expr factor = make_const(offset.type(), arg->offset_factor);
      Expr zero = make_zero(offset.type());
      BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
104 105
    }
  }
106

107
  if (arg->shape.size() < value->shape.size()) {
108
    CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
109
    size_t diff = value->shape.size() - arg->shape.size();
110
    for (size_t i = 0; i < diff; ++i) {
111
      CHECK(is_one(value->shape[i]))
112 113 114
          << "Argument " << arg_name << " shape mismatch"
          << arg->shape << " vs " << value->shape;
    }
115
    for (size_t i = 0; i < arg->shape.size(); ++i) {
116 117
      std::ostringstream os;
      os << arg_name << ".shape[" << i << "]";
118
      this->Bind(arg->shape[i], value->shape[i + diff], os.str());
119
    }
120
    if (value->strides.size() != 0) {
121 122
      CHECK_EQ(arg->strides.size(), arg->shape.size());
      CHECK_EQ(value->strides.size(), value->shape.size());
123
      for (size_t i = 0; i < arg->strides.size(); ++i) {
124 125
        std::ostringstream os;
        os << arg_name << ".strides[" << i << "]";
126
        this->Bind(arg->strides[i], value->strides[i + diff], os.str());
127 128 129 130 131 132
      }
    }
  } else {
    this->BindArray(arg->shape, value->shape, arg_name + ".shape");
    this->BindArray(arg->strides, value->strides, arg_name + ".strides");
  }
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
}

/*!
 * \brief Build an expression reading field \p kind, typed as \p t, from the
 *        array handle \p arr at struct index 0.
 */
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
  // NOTE(review): index 0 presumably selects the single DLTensor behind the
  // handle; confirm against TVMStructGet in ir_util.h.
  const int struct_index = 0;
  return TVMStructGet(t, arr, struct_index, kind);
}

void ArgBinder::BindDLTensor(const Buffer& buffer,
                             const Expr& device_type,
                             const Expr& device_id,
                             const Var& handle,
                             const std::string& arg_name) {
  const Type tvm_shape_type = TVMShapeIndexType();
  const Type tvm_ndim_type = Int(32);
  const Stmt nop = Evaluate::make(0);
  // dimension checks
  Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
  Expr a_ndim = make_const(tvm_ndim_type,
                           static_cast<int64_t>(buffer->shape.size()));
  std::ostringstream ndim_err_msg;
  ndim_err_msg << arg_name
               << ".ndim is expected to equal "
               << buffer->shape.size();
155
  asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
156 157 158 159 160 161 162 163 164 165
  // type checks
  Type dtype = buffer->dtype;
  std::ostringstream type_err_msg;
  type_err_msg << arg_name << ".dtype is expected to be " << dtype;
  Expr cond = (TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeCode) ==
               UIntImm::make(UInt(8), dtype.code()) &&
               TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeBits) ==
               UIntImm::make(UInt(8), dtype.bits()) &&
               TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) ==
               UIntImm::make(UInt(16), dtype.lanes()));
166
  asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str(), nop));
167 168 169 170 171 172 173 174
  // data field
  if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData),
            arg_name + ".data", true)) {
    Var vptr(buffer->data);
    def_handle_dtype_.Set(vptr, make_const(buffer->dtype, 0));
    // mark alignment of external bufs
    init_nest_.emplace_back(AttrStmt::make(
        vptr, ir::attr::storage_alignment,
175
        IntImm::make(Int(32), buffer->data_alignment), nop));
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
  }

  Var v_shape(arg_name + ".shape", Handle());
  def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
  init_nest_.emplace_back(LetStmt::make(
      v_shape, TVMArrayGet(Handle(), handle, intrinsic::kArrShape), nop));
  for (size_t k = 0; k < buffer->shape.size(); ++k) {
    std::ostringstream field_name;
    field_name << v_shape->name_hint << '[' << k << ']';
    Bind_(buffer->shape[k],
          cast(buffer->shape[k].type(),
               Load::make(tvm_shape_type, v_shape,
                          IntImm::make(Int(32), k), const_true(1))),
          field_name.str(), true);
  }
  // strides field
  Var v_strides(arg_name + ".strides", Handle());
  def_handle_dtype_.Set(v_strides, make_const(tvm_shape_type, 0));
  init_nest_.emplace_back(LetStmt::make(
      v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides),
      nop));
197 198 199
  Expr is_null = Call::make(
    Bool(1), intrinsic::tvm_handle_is_null,
    {v_strides}, Call::PureIntrinsic);
200
  if (buffer->strides.size() == 0) {
201
    // Assert the buffer is compact
202
    Type stype = buffer->DefaultIndexType();
203 204 205 206 207 208 209 210 211 212 213
    Expr expect_stride = make_const(stype, 1);
    Array<Expr> conds;
    for (size_t i = buffer->shape.size(); i != 0; --i) {
      size_t k = i - 1;
      Expr svalue = cast(
          stype,
          Load::make(tvm_shape_type, v_strides,
                     IntImm::make(Int(32), k), const_true(1)));
      conds.push_back(expect_stride == svalue);
      expect_stride = expect_stride * buffer->shape[k];
    }
214 215
    std::ostringstream stride_err_msg;
    stride_err_msg << arg_name << ".strides:"
216
                   << " expected to be compact array";
217 218 219 220 221 222 223
    if (conds.size() != 0) {
      Stmt check =
          AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
                           stride_err_msg.str(), Evaluate::make(0));
      check = IfThenElse::make(Not::make(is_null), check, Stmt());
      init_nest_.emplace_back(Block::make(check, Evaluate::make(0)));
    }
224
  } else {
225 226 227 228
    std::ostringstream stride_null_err_msg;
    stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
    asserts_.emplace_back(AssertStmt::make(Not::make(is_null), stride_null_err_msg.str(), nop));

229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
    for (size_t k = 0; k < buffer->strides.size(); ++k) {
      std::ostringstream field_name;
      field_name << v_strides->name_hint << '[' << k << ']';
      Bind_(buffer->strides[k],
            cast(buffer->shape[k].type(),
                 Load::make(tvm_shape_type, v_strides,
                            IntImm::make(Int(32), k), const_true(1))),
            field_name.str(), true);
    }
  }
  // Byte_offset field.
  int data_bytes = GetVectorBytes(buffer->dtype);
  int64_t const_offset;
  if (arith::GetConst(buffer->elem_offset, &const_offset)) {
    Bind_(make_const(UInt(64), const_offset * data_bytes),
               TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset),
          arg_name + ".byte_offset", true);
  } else {
247 248 249 250 251 252 253 254 255 256 257 258
    if (Bind_(buffer->elem_offset,
              cast(buffer->elem_offset.type(),
                   (TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
                    make_const(UInt(64), data_bytes))),
              arg_name + ".elem_offset", true)) {
      if (buffer->offset_factor > 1) {
        Expr offset = buffer->elem_offset;
        Expr factor = make_const(offset.type(), buffer->offset_factor);
        Expr zero = make_zero(offset.type());
        BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
      }
    }
259 260 261 262 263 264 265 266 267 268 269 270
  }
  // device info.
  Bind_(device_type,
        TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceType),
        arg_name + ".device_type", true);
  Bind_(device_id,
        TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceId),
        arg_name + ".device_id", true);
}

}  // namespace ir
}  // namespace tvm