/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file arg_binder.cc
 * \brief Helper utility to match and bind arguments.
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/device_api.h>
#include "ir_util.h"
#include "arg_binder.h"
#include "../../arith/compute_expr.h"

namespace tvm {
namespace tir {

void BinderAddAssert(PrimExpr cond,
35 36
                     const std::string& arg_name,
                     std::vector<Stmt>* asserts) {
37
  PrimExpr scond = Simplify(cond);
38
  if (is_zero(scond)) {
39 40 41
    LOG(FATAL) << "Bind have an unmet assertion: "
               << cond << ", " << " on argument " << arg_name;
  }
42
  if (!is_one(scond)) {
43 44
    std::ostringstream os;
    os << "Argument " << arg_name << " has an unsatisfied constraint";
45 46
    asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()),
                                               EvaluateNode::make(0)));
47 48 49
  }
}

bool ArgBinder::Bind_(const PrimExpr& arg,
                      const PrimExpr& value,
52 53
                      const std::string& arg_name,
                      bool with_lets) {
54
  CHECK_EQ(arg.dtype(), value.dtype());
55
  if (const VarNode* v = arg.as<VarNode>()) {
56 57
    auto it = def_map_->find(v);
    if (it == def_map_->end()) {
58
      Var v_arg = Downcast<Var>(arg);
59 60 61
      defs_.emplace_back(v_arg);
      if (with_lets) {
        (*def_map_)[v] = arg;
62
        init_nest_.emplace_back(LetStmtNode::make(v_arg, value, EvaluateNode::make(0)));
63 64 65 66 67 68 69 70 71 72 73 74 75
      } else {
        (*def_map_)[v] = value;
      }
      return true;
    } else {
      BinderAddAssert(it->second == value, arg_name, &asserts_);
    }
  } else {
    BinderAddAssert(arg == value, arg_name, &asserts_);
  }
  return false;
}

void ArgBinder::Bind(const PrimExpr& arg,
                     const PrimExpr& value,
78 79 80 81 82
                     const std::string& arg_name,
                     bool with_let) {
  Bind_(arg, value, arg_name, with_let);
}

/*!
 * \brief Element-wise Bind of two equally-sized expression arrays.
 *
 * Element i is bound under the name "<arg_name>[i]". A size mismatch is a fatal CHECK.
 */
void ArgBinder::BindArray(const Array<PrimExpr>& arg,
                          const Array<PrimExpr>& value,
                          const std::string& arg_name) {
  CHECK_EQ(arg.size(), value.size())
      << "Argument " << arg_name << " array size mismatch";
  const size_t count = arg.size();
  for (size_t idx = 0; idx < count; ++idx) {
    const std::string elem_name = arg_name + "[" + std::to_string(idx) + "]";
    this->Bind(arg[idx], value[idx], elem_name);
  }
}

void ArgBinder::BindBuffer(const Buffer& arg,
                           const Buffer& value,
97 98
                           const std::string& arg_name,
                           bool fuzzy_match) {
99 100 101
  CHECK_EQ(arg->scope, value->scope)
      << "Argument " << arg_name
      << " Buffer bind scope mismatch";
102 103 104 105 106 107 108 109 110
  CHECK_EQ(arg->dtype, value->dtype)
      << "Argument " << arg_name
      << " Buffer bind data type mismatch";
  if (value->data_alignment % arg->data_alignment != 0) {
    LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
                 << " required_alignment=" << arg->data_alignment
                 << ", provided_alignment=" << value->data_alignment;
  }
  // bind pointer and offset.
111 112
  if (is_zero(arg->elem_offset)) {
    CHECK(is_zero(value->elem_offset))
113 114 115
        << "Trying to bind a Buffer with offset into one without offset "
        << " required elem_offset=" << arg->elem_offset
        << ", provided elem_offset=" << value->elem_offset;
116 117 118 119 120
  }

  this->Bind(arg->data, value->data, arg_name + ".data");
  if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
    if (arg->offset_factor > 1) {
121 122 123
      PrimExpr offset = value->elem_offset;
      PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
      PrimExpr zero = make_zero(offset.dtype());
124 125
      BinderAddAssert(truncmod(offset, factor) == zero,
                      arg_name + ".elem_offset", &asserts_);
126 127
    }
  }
128

129
  if (arg->shape.size() < value->shape.size()) {
130
    CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
131
    size_t diff = value->shape.size() - arg->shape.size();
132
    for (size_t i = 0; i < diff; ++i) {
133
      CHECK(is_one(Simplify(value->shape[i])))
134 135 136
          << "Argument " << arg_name << " shape mismatch"
          << arg->shape << " vs " << value->shape;
    }
137
    for (size_t i = 0; i < arg->shape.size(); ++i) {
138 139
      std::ostringstream os;
      os << arg_name << ".shape[" << i << "]";
140
      this->Bind(arg->shape[i], value->shape[i + diff], os.str());
141
    }
142
    if (value->strides.size() != 0) {
143 144
      CHECK_EQ(arg->strides.size(), arg->shape.size());
      CHECK_EQ(value->strides.size(), value->shape.size());
145
      for (size_t i = 0; i < arg->strides.size(); ++i) {
146 147
        std::ostringstream os;
        os << arg_name << ".strides[" << i << "]";
148
        this->Bind(arg->strides[i], value->strides[i + diff], os.str());
149 150 151 152 153 154
      }
    }
  } else {
    this->BindArray(arg->shape, value->shape, arg_name + ".shape");
    this->BindArray(arg->strides, value->strides, arg_name + ".strides");
  }
155 156
}

/*!
 * \brief Build an expression reading field \p kind of the array structure behind \p arr.
 * \param t Data type of the field being read.
 * \param arr Handle variable referencing the array structure.
 * \param kind Which array field to read (e.g. kArrNDim, kArrShape, kArrData).
 * \return A TVMStructGet expression on \p arr with struct index 0.
 */
inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) {
  constexpr int kStructIndex = 0;
  return TVMStructGet(t, arr, kStructIndex, kind);
}

void ArgBinder::BindDLTensor(const Buffer& buffer,
162 163
                             const PrimExpr& device_type,
                             const PrimExpr& device_id,
164 165
                             const Var& handle,
                             const std::string& arg_name) {
166 167
  const DataType tvm_shape_type = DataType::ShapeIndex();
  const DataType tvm_ndim_type = DataType::Int(32);
168
  const Stmt nop = EvaluateNode::make(0);
169
  // dimension checks
170 171
  PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
  PrimExpr a_ndim = make_const(tvm_ndim_type,
172 173 174 175 176
                           static_cast<int64_t>(buffer->shape.size()));
  std::ostringstream ndim_err_msg;
  ndim_err_msg << arg_name
               << ".ndim is expected to equal "
               << buffer->shape.size();
177 178
  auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
179
  // type checks
180
  DataType dtype = buffer->dtype;
181 182
  std::ostringstream type_err_msg;
  type_err_msg << arg_name << ".dtype is expected to be " << dtype;
183
  PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
184
               IntImm(DataType::UInt(8), dtype.code()) &&
185
               TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
186
               IntImm(DataType::UInt(8), dtype.bits()) &&
187
               TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
188
               IntImm(DataType::UInt(16), dtype.lanes()));
189 190 191
  if (!(dtype == DataType::Int(4) ||
        dtype == DataType::UInt(4) ||
        dtype == DataType::Int(1))) {
192 193 194
    auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
    asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
    asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
195
  }
196
  // data field
197
  if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
198 199
            arg_name + ".data", true)) {
    Var vptr(buffer->data);
200
    def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
201
    // mark alignment of external bufs
202
    init_nest_.emplace_back(AttrStmtNode::make(
203
        vptr, tir::attr::storage_alignment,
204
        IntImm(DataType::Int(32), buffer->data_alignment), nop));
205 206
  }

207
  Var v_shape(arg_name + ".shape", DataType::Handle());
208
  def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
209
  init_nest_.emplace_back(LetStmtNode::make(
210
      v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
211
  for (size_t k = 0; k < buffer->shape.size(); ++k) {
212 213 214 215 216
    if (dtype == DataType::Int(4) ||
        dtype == DataType::UInt(4) ||
        dtype == DataType::Int(1)) {
      break;
    }
217 218 219
    std::ostringstream field_name;
    field_name << v_shape->name_hint << '[' << k << ']';
    Bind_(buffer->shape[k],
220
          cast(buffer->shape[k].dtype(),
221
               LoadNode::make(tvm_shape_type, v_shape,
222
                          IntImm(DataType::Int(32), k), const_true(1))),
223 224 225
          field_name.str(), true);
  }
  // strides field
226
  Var v_strides(arg_name + ".strides", DataType::Handle());
227
  def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
228
  init_nest_.emplace_back(LetStmtNode::make(
229
      v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides),
230
      nop));
231
  PrimExpr is_null = CallNode::make(
232
    DataType::Bool(1), intrinsic::tvm_handle_is_null,
233
    {v_strides}, CallNode::PureIntrinsic);
234
  if (buffer->strides.size() == 0) {
235
    // Assert the buffer is compact
236
    DataType stype = buffer->DefaultIndexType();
237 238
    PrimExpr expect_stride = make_const(stype, 1);
    Array<PrimExpr> conds;
239 240
    for (size_t i = buffer->shape.size(); i != 0; --i) {
      size_t k = i - 1;
241
      PrimExpr svalue = cast(
242
          stype,
243
          LoadNode::make(tvm_shape_type, v_strides,
244
                     IntImm(DataType::Int(32), k), const_true(1)));
245 246 247
      conds.push_back(expect_stride == svalue);
      expect_stride = expect_stride * buffer->shape[k];
    }
248 249
    std::ostringstream stride_err_msg;
    stride_err_msg << arg_name << ".strides:"
250
                   << " expected to be compact array";
251
    if (conds.size() != 0) {
252
      auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
253
      Stmt check =
254
          AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
255
                           stride_msg, EvaluateNode::make(0));
256 257
      check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
      asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
258
    }
259
  } else if (buffer->buffer_type == kAutoBroadcast) {
260
    DataType stype = buffer->DefaultIndexType();
261
    PrimExpr stride = make_const(stype, 1);
262 263 264 265
    for (size_t i = buffer->shape.size(); i != 0; --i) {
      size_t k = i - 1;
      std::ostringstream field_name;
      field_name << v_strides->name_hint << '[' << k << ']';
266
      PrimExpr value = cast(buffer->shape[k].dtype(),
267
                        LoadNode::make(tvm_shape_type, v_strides,
268
                                   IntImm(DataType::Int(32), k), const_true(1)));
269 270 271 272 273
      value = tvm::if_then_else(is_null, stride, value);
      value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
      Bind_(buffer->strides[k], value, field_name.str(), true);
      stride = Simplify(stride * buffer->shape[k]);
    }
274
  } else {
275 276
    std::ostringstream stride_null_err_msg;
    stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
277 278
    asserts_.emplace_back(AssertStmtNode::make(
        NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop));
279

280 281 282 283
    for (size_t k = 0; k < buffer->strides.size(); ++k) {
      std::ostringstream field_name;
      field_name << v_strides->name_hint << '[' << k << ']';
      Bind_(buffer->strides[k],
284
            cast(buffer->shape[k].dtype(),
285
                 LoadNode::make(tvm_shape_type, v_strides,
286
                            IntImm(DataType::Int(32), k), const_true(1))),
287 288 289 290 291 292 293
            field_name.str(), true);
    }
  }
  // Byte_offset field.
  int data_bytes = GetVectorBytes(buffer->dtype);
  int64_t const_offset;
  if (arith::GetConst(buffer->elem_offset, &const_offset)) {
294 295
    Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),
               TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
296 297
          arg_name + ".byte_offset", true);
  } else {
298
    if (Bind_(buffer->elem_offset,
299 300 301
              cast(buffer->elem_offset.dtype(),
                   (TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset) /
                    make_const(DataType::UInt(64), data_bytes))),
302 303
              arg_name + ".elem_offset", true)) {
      if (buffer->offset_factor > 1) {
304 305 306
        PrimExpr offset = buffer->elem_offset;
        PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
        PrimExpr zero = make_zero(offset.dtype());
307
        BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
308 309
      }
    }
310 311 312
  }
  // device info.
  Bind_(device_type,
313
        TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType),
314 315
        arg_name + ".device_type", true);
  Bind_(device_id,
316
        TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId),
317 318 319
        arg_name + ".device_id", true);
}

}  // namespace tir
}  // namespace tvm