/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20
/*!
 *  Lower TVM related builtin intrinsics such as packed call.
 * \file tir/transforms/lower_tvm_builtin.cc
 */
24 25
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
26 27 28
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>

29
#include <unordered_set>
30

31
#include "ir_util.h"
32

33
namespace tvm {
34
namespace tir {
35

36
/*!
 * \brief Build a 32-bit integer constant expression from a size_t index.
 *
 * Aborts (via CHECK) when the index cannot be represented as a signed
 * 32-bit int, instead of silently truncating.
 */
inline PrimExpr ConstInt32(size_t index) {
  CHECK_LE(index, std::numeric_limits<int>::max());
  const int narrowed = static_cast<int>(index);
  return make_const(DataType::Int(32), narrowed);
}

41 42
/*!
 * \brief Build a call to the tvm_stack_alloca intrinsic.
 * \param type Tag of the scratch area to reserve
 *             (e.g. "shape", "array", "arg_value", "arg_tcode").
 * \param num Number of elements to reserve.
 * \return Handle-typed call expression allocating the stack area.
 */
inline PrimExpr StackAlloca(std::string type, size_t num) {
  Array<PrimExpr> call_args;
  call_args.push_back(StringImmNode::make(type));
  call_args.push_back(ConstInt32(num));
  return CallNode::make(DataType::Handle(), intrinsic::tvm_stack_alloca,
                        call_args, CallNode::Intrinsic);
}

// Calculate the statistics of packed function usage and lower the
// related builtin intrinsics (tvm_call_packed, tvm_stack_make_shape,
// tvm_stack_make_array, device Allocate) into runtime calls.
// These statistics (maximum depth of the shape/array/argument scratch
// stacks) are needed during codegen: Build() wraps the lowered body in
// LetStmts that allocate each stack at its maximum observed size.
class BuiltinLower : public StmtExprMutator {
 public:
  // Entry point: lower `stmt`, then wrap it with the scratch-stack
  // allocations. A stack is only materialized when it was actually used
  // (its recorded maximum is non-zero).
  Stmt Build(Stmt stmt) {
    stack_shape_ = Var("stack_shape", DataType::Handle());
    stack_array_ = Var("stack_array", DataType::Handle());
    stack_value_ = Var("stack_value", DataType::Handle());
    stack_tcode_ = Var("stack_tcode", DataType::Handle());
    stmt = this->VisitStmt(stmt);
    if (max_shape_stack_ != 0) {
      stmt = LetStmtNode::make(
          stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
    }
    if (max_array_stack_ != 0) {
      stmt = LetStmtNode::make(
          stack_array_, StackAlloca("array", max_array_stack_), stmt);
    }
    if (max_arg_stack_ != 0) {
      // The value stack and tcode stack are used in lock-step (one slot
      // per packed-call argument), so both get the same size.
      stmt = LetStmtNode::make(
          stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
      stmt = LetStmtNode::make(
          stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
    }
    return stmt;
  }

  Stmt VisitStmt(const Stmt& s) final {
    auto stmt = StmtExprMutator::VisitStmt(s);
    // Every shape/array slot pushed while lowering this statement must
    // have been consumed by a packed call, which restores the counters.
    CHECK_EQ(run_shape_stack_, 0);
    CHECK_EQ(run_array_stack_, 0);
    // Emit the preparation stores collected for this statement (stack
    // slot initialization for packed calls) right before the statement.
    if (prep_seq_.size() != 0) {
      Stmt ret = SeqStmt::Flatten(prep_seq_, stmt);
      prep_seq_.clear();
      return ret;
    } else {
      return stmt;
    }
  }

  Stmt VisitStmt_(const AllocateNode* op) {
    // Lower allocate to device allocate when needed.
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    // Get constant allocation bound.
    int64_t nbytes = GetVectorBytes(op->dtype);
    if (device_type_.defined()) {
      // Small constant-size CPU allocations are kept as-is: they fit on
      // the native stack (below kMaxStackAlloca) and need no workspace.
      if (const auto* dev_type = device_type_.as<IntImmNode>()) {
        if (dev_type->value == kDLCPU) {
          int32_t constant_size = op->constant_allocation_size();
          if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
            return stmt;
          }
        }
      }
    }
    // Total size in bytes = element size * product of all extents.
    PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      total_bytes = total_bytes * op->extents[i];
    }
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    Stmt throw_last_error = EvaluateNode::make(
        CallNode::make(DataType::Int(32),
                       intrinsic::tvm_throw_last_error, {},
                       CallNode::Intrinsic));

    // If the returned workspace handle is null, surface the runtime error
    // before executing the original body.
    Stmt body = SeqStmt({
        IfThenElseNode::make(
            CallNode::make(DataType::Bool(1),
                           intrinsic::tvm_handle_is_null,
                           {op->buffer_var}, CallNode::PureIntrinsic),
            throw_last_error),
        op->body});

    // Bind the buffer var to a TVMBackendAllocWorkspace call.
    Stmt alloca = LetStmtNode::make(
        op->buffer_var,
        CallNode::make(op->buffer_var.dtype(),
                       "TVMBackendAllocWorkspace",
                       {cast(DataType::Int(32), device_type_),
                        cast(DataType::Int(32), device_id_),
                        cast(DataType::UInt(64), total_bytes),
                        IntImm(DataType::Int(32), op->dtype.code()),
                        IntImm(DataType::Int(32), op->dtype.bits())},
                       CallNode::Extern),
        body);

    // Free the workspace after the body; a non-zero status means failure.
    PrimExpr free_op = CallNode::make(DataType::Int(32),
                                  "TVMBackendFreeWorkspace",
                                  {cast(DataType::Int(32), device_type_),
                                   cast(DataType::Int(32), device_id_),
                                   op->buffer_var},
                                  CallNode::Extern);
    Stmt free_stmt = IfThenElseNode::make(
        free_op != make_zero(DataType::Int(32)), throw_last_error);
    body = SeqStmt({alloca, free_stmt});
    body = AttrStmtNode::make(
        op->buffer_var, attr::storage_alignment,
        make_const(DataType::Int(32), runtime::kTempAllocaAlignment),
        body);
    return body;
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    // Capture the device context (id/type) from the attribute and strip
    // it; the CHECKs enforce at most one of each on this path.
    if (op->attr_key == attr::device_context_id) {
      CHECK(!device_id_.defined());
      device_id_ = op->value;
      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::device_context_type) {
      CHECK(!device_type_.defined());
      device_type_ = op->value;
      return this->VisitStmt(op->body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  PrimExpr VisitExpr_(const CallNode* op) final {
    // Dispatch on the builtin intrinsic being lowered.
    if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
      return MakeCallPacked(op);
    } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
      return MakeCallTracePacked(op);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
      return MakeShape(op);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
      return MakeArray(op);
    } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
      return make_zero(op->dtype);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

  // call shape
  // Lower tvm_stack_make_shape: store each dimension (as int64) into the
  // shape stack and return the address of the first stored slot. The
  // pushed slots are popped by the enclosing MakeCallPacked.
  PrimExpr MakeShape(const CallNode* op) {
    size_t stack_begin = run_shape_stack_;
    run_shape_stack_ += op->args.size();
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    for (size_t i = 0; i < op->args.size(); ++i) {
      prep_seq_.emplace_back(
          StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
                      ConstInt32(stack_begin + i), const_true(1)));
    }
    return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
  }

  // make array
  // Lower tvm_stack_make_array: fill one DLTensor slot on the array stack
  // and return its address.
  // args layout: [data, shape, strides, ndim, dtype-carrier, byte_offset]
  // (only the dtype of args[4] is consumed, not its value).
  PrimExpr MakeArray(const CallNode* op) {
    size_t idx = run_array_stack_;
    run_array_stack_ += 1;
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
    PrimExpr strides = op->args[2];
    if (!strides.defined() || is_zero(strides)) {
      // Compact layout: represent "no strides" as a null handle.
      strides = make_zero(DataType::Handle());
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
    DataType dtype = op->args[4].dtype();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
                     make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
                     make_const(DataType::UInt(8), dtype.bits())));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
                     make_const(DataType::UInt(16), dtype.lanes())));
    // set byte offset: convert the element offset in args[5] to bytes.
    int data_bytes = GetVectorBytes(dtype);
    PrimExpr byte_offset = op->args[5];
    if (!is_zero(byte_offset)) {
      byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
                     cast(DataType::UInt(64), byte_offset)));
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
                     cast(DataType::Int(32), device_id_)));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
                     cast(DataType::Int(32), device_type_)));
    return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
  }

  // call packed.
  // Lower tvm_call_packed: args[0] is the function name, args[1..] the
  // call arguments. Each argument is written into the value/tcode stacks
  // and the call becomes tvm_call_packed_lowered with the half-open slot
  // range [arg_stack_begin, arg_stack_begin + nargs).
  PrimExpr MakeCallPacked(const CallNode* op) {
    size_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    // Specially handle the buffer packed intrinsic
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
      PrimExpr arg = op->args[i];
      DataType t = arg.dtype();
      DataType api_type = APIType(t);
      // Widen/convert to the FFI API type when it differs.
      if (t != api_type) {
        arg = CastNode::make(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(
          stack_value_, static_cast<int>(arg_stack_begin + i - 1),
          intrinsic::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      // String literals and DLTensor handles need their own type codes.
      if (api_type.is_handle() && arg.as<StringImmNode>()) {
        arg_tcode = kTVMStr;
      }
      if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
      prep_seq_.emplace_back(
          StoreNode::make(stack_tcode_,
                      ConstInt32(arg_tcode),
                      stack_index, const_true(1)));
    }
    // UPDATE stack value: record the maxima reached while visiting this
    // call (including nested make_shape/make_array), then pop everything
    // this call pushed so sibling calls can reuse the same slots.
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    run_arg_stack_ = arg_stack_begin;
    Array<PrimExpr> packed_args = {
      op->args[0],
      stack_value_,
      stack_tcode_,
      ConstInt32(arg_stack_begin),
      ConstInt32(arg_stack_begin + op->args.size() - 1)
    };
    return CallNode::make(
        DataType::Int(32), intrinsic::tvm_call_packed_lowered,
        packed_args, CallNode::Intrinsic);
  }

  // Lower tvm_call_trace_packed: same stack layout as MakeCallPacked,
  // except the last argument (the traced value) is also appended to the
  // lowered call so the expression still evaluates to it.
  PrimExpr MakeCallTracePacked(const CallNode *op) {
    size_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    size_t args_size = op->args.size();
    CHECK_GT(args_size, 0);
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
      PrimExpr arg = op->args[i];
      DataType t = arg.dtype();
      DataType api_type = APIType(t);
      // Widen/convert to the FFI API type when it differs.
      if (t != api_type) {
        arg = CastNode::make(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(
          stack_value_, static_cast<int>(arg_stack_begin + i - 1),
          intrinsic::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
      prep_seq_.emplace_back(
          StoreNode::make(stack_tcode_,
                      ConstInt32(arg_tcode),
                      stack_index, const_true(1)));
    }
    // UPDATE stack value
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    // Update the top of the stack, so we can use more than one
    // packed function's arguments with the one stack.
    // NOTE: unlike MakeCallPacked, one slot (the traced value's) is kept
    // live, so the arg stack is restored to begin + args_size - 1.
    run_arg_stack_ = arg_stack_begin + args_size - 1;
    Array<PrimExpr> packed_args = {
      op->args[0],
      stack_value_,
      stack_tcode_,
      ConstInt32(arg_stack_begin),
      ConstInt32(arg_stack_begin + op->args.size() - 1),
      // Pass traced value.
      op->args[args_size - 1]
    };
    return CallNode::make(
        op->dtype, intrinsic::tvm_call_trace_packed_lowered,
        packed_args, CallNode::Intrinsic);
  }

 private:
  // Whether `arg` is the address of a DLTensor slot produced by MakeArray
  // (i.e. a tvm_struct_get of kArrAddr).
  bool IsArrayHandle(const PrimExpr& arg) {
    // specially set array handle.
    if (const CallNode* buf = arg.as<CallNode>()) {
      if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
          buf->args[2].as<IntImmNode>()->value == intrinsic::kArrAddr) {
        return true;
      }
    }
    return false;
  }

  // The preparation sequence to be emitted before the current statement.
  std::vector<Stmt> prep_seq_;
  // Device context captured from enclosing device_context_* AttrStmts.
  PrimExpr device_type_;
  PrimExpr device_id_;
  // Var handle for each stack.
  Var stack_shape_;
  Var stack_array_;
  Var stack_tcode_;
  Var stack_value_;
  // The running statistics: current depth of each stack while visiting.
  uint64_t run_shape_stack_{0};
  uint64_t run_array_stack_{0};
  uint64_t run_arg_stack_{0};
  // statistics of stacks: maximum depths observed, used by Build().
  uint64_t max_shape_stack_{0};
  uint64_t max_array_stack_{0};
  uint64_t max_arg_stack_{0};
};

371 372 373 374 375 376 377 378 379
namespace transform {

Pass LowerTVMBuiltin() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = BuiltinLower().Build(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
380 381
}

382 383 384 385
TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin")
.set_body_typed(LowerTVMBuiltin);

}  // namespace transform
386
}  // namespace tir
387
}  // namespace tvm