// codegen_spirv.cc
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_spirv.cc
 * \brief Generate SPIRV block
 */
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>

#include <string>

#include "codegen_spirv.h"
#include "../../arith/compute_expr.h"

namespace tvm {
namespace codegen {

// Lower a single LoweredFunc into a SPIR-V binary module.
// Buffer (handle) arguments become storage-buffer bindings; POD scalar
// arguments are packed into a single push-constant block.  Returns the
// finalized SPIR-V words for the kernel.
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
  this->InitFuncState();
  CHECK(f->is_restricted)
      << "SPIRV only takes restricted memory model";
  std::vector<Var> pod_args;
  uint32_t num_buffer = 0;
  for (Var arg : f->args) {
    DataType t = arg.dtype();
    if (t.is_handle()) {
      // Handles must carry a known element type so the buffer binding
      // can be declared with a concrete SPIR-V type.
      auto it = f->handle_data_type.find(arg);
      if (it != f->handle_data_type.end()) {
        DataType value_type = (*it).second.dtype();
        spirv::Value arg_value = builder_->BufferArgument(
            builder_->GetSType(value_type), 0, num_buffer);
        storage_info_[arg.get()].UpdateContentType(value_type);
        var_map_[arg.get()] = arg_value;
      } else {
        LOG(FATAL) << "require all handles to be typed";
      }
      ++num_buffer;
    } else {
      pod_args.push_back(arg);
    }
  }
  spirv::Value func_ptr = builder_->NewFunction();
  builder_->StartFunction(func_ptr);

  // All the POD arguments are passed in through PushConstant
  if (pod_args.size() != 0) {
    std::vector<spirv::SType> value_types;
    for (size_t i = 0; i < pod_args.size(); ++i) {
      value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
    }
    spirv::Value ptr = builder_->DeclarePushConstant(value_types);
    for (size_t i = 0; i < pod_args.size(); ++i) {
      spirv::Value value = builder_->GetPushConstant(
          ptr, value_types[i], static_cast<uint32_t>(i));
      var_map_[pod_args[i].get()] = value;
    }
  }
  this->VisitStmt(f->body);
  // workgroup_size_ was filled in by GetThreadIndex while visiting the body.
  builder_->SetLocalSize(func_ptr, workgroup_size_);
  builder_->MakeInst(spv::OpReturn);
  builder_->MakeInst(spv::OpFunctionEnd);
  builder_->CommitKernelFunction(func_ptr, f->name);
  return builder_->Finalize();
}

void CodeGenSPIRV::InitFuncState() {
  std::fill(workgroup_size_, workgroup_size_ + 3, 1);
  var_map_.clear();
  storage_info_.clear();
87
  analyzer_.reset(new arith::Analyzer());
88 89 90 91 92
  builder_.reset(new spirv::IRBuilder());
  builder_->InitHeader();
}

// Map a TVM thread IterVar (threadIdx/blockIdx) to the matching SPIR-V
// builtin.  rank == 1 means a thread-local index (LocalInvocationId);
// its extent must be a compile-time constant because it fixes the
// workgroup size.  Other ranks map to the workgroup id.
spirv::Value CodeGenSPIRV::GetThreadIndex(
    const IterVar& iv, const PrimExpr& extent) {
  runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
  spirv::Value v;
  if (ts.rank == 1) {
    v = builder_->GetLocalID(ts.dim_index);
    int size = 0;
    CHECK(arith::GetConstInt(extent, &size))
        << "SPIRV only allows constant thread group size " << " get " << extent;
    CHECK_LT(ts.dim_index, 3);
    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
  } else {
    v = builder_->GetWorkgroupID(ts.dim_index);
  }
  // Builtins are unsigned; cast to the IterVar's declared type.
  return builder_->Cast(builder_->GetSType(iv->var.dtype()), v);
}

109 110
spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
  const std::string& sync = op->args[0].as<StringImmNode>()->value;
111 112 113 114
  spirv::Value value;
  if (sync == "warp") {
    return value;
  } else if (sync == "shared") {
115
    auto type_int = builder_->GetSType(DataType::Int(32));
116
    builder_->MakeInst(
117 118 119 120
      spv::OpControlBarrier,
      builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
      builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
      builder_->IntImm(type_int, static_cast<int64_t>(
121
        spv::MemorySemanticsSequentiallyConsistentMask |
122
        spv::MemorySemanticsWorkgroupMemoryMask)));
123 124 125 126 127 128
  } else {
    LOG(FATAL) << "Do not support sync " << sync;
  }
  return value;
}

129
spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) {
130 131 132 133 134
  auto it = var_map_.find(op);
  CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint;
  return it->second;
}

135
spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) {
136
  return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
137 138
}

139
spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
140
  return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
141 142
}

143
spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) {
144 145 146 147
  LOG(FATAL) << "StringImm is not supported in Device code";
  return spirv::Value();
}

148
spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) {
149
  return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value));
150 151
}

152
spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) {
153 154 155
  return builder_->Add(MakeValue(op->a), MakeValue(op->b));
}

156
spirv::Value CodeGenSPIRV::VisitExpr_(const SubNode* op) {
157 158 159
  return builder_->Sub(MakeValue(op->a), MakeValue(op->b));
}

160
spirv::Value CodeGenSPIRV::VisitExpr_(const MulNode* op) {
161 162 163
  return builder_->Mul(MakeValue(op->a), MakeValue(op->b));
}

164
spirv::Value CodeGenSPIRV::VisitExpr_(const DivNode* op) {
165 166 167
  return builder_->Div(MakeValue(op->a), MakeValue(op->b));
}

168
spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) {
169 170 171
  return builder_->Mod(MakeValue(op->a), MakeValue(op->b));
}

172
spirv::Value CodeGenSPIRV::VisitExpr_(const MinNode* op) {
173 174 175 176 177
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Select(builder_->LT(a, b), a, b);
}

178
spirv::Value CodeGenSPIRV::VisitExpr_(const MaxNode* op) {
179 180 181 182 183
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Select(builder_->GT(a, b), a, b);
}

184
spirv::Value CodeGenSPIRV::VisitExpr_(const LTNode* op) {
185 186 187
  return builder_->LT(MakeValue(op->a), MakeValue(op->b));
}

188
spirv::Value CodeGenSPIRV::VisitExpr_(const LENode* op) {
189 190 191
  return builder_->LE(MakeValue(op->a), MakeValue(op->b));
}

192
spirv::Value CodeGenSPIRV::VisitExpr_(const GTNode* op) {
193 194 195
  return builder_->GT(MakeValue(op->a), MakeValue(op->b));
}

196
spirv::Value CodeGenSPIRV::VisitExpr_(const GENode* op) {
197 198 199
  return builder_->GE(MakeValue(op->a), MakeValue(op->b));
}

200
spirv::Value CodeGenSPIRV::VisitExpr_(const EQNode* op) {
201 202 203
  return builder_->EQ(MakeValue(op->a), MakeValue(op->b));
}

204
spirv::Value CodeGenSPIRV::VisitExpr_(const NENode* op) {
205 206 207
  return builder_->NE(MakeValue(op->a), MakeValue(op->b));
}

208
spirv::Value CodeGenSPIRV::VisitExpr_(const AndNode* op) {
209 210 211 212 213
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->MakeValue(spv::OpLogicalAnd, a.stype, a, b);
}

214
spirv::Value CodeGenSPIRV::VisitExpr_(const OrNode* op) {
215 216 217 218 219
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->MakeValue(spv::OpLogicalOr, a.stype, a, b);
}

220
spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) {
221 222 223 224
  spirv::Value a = MakeValue(op->a);
  return builder_->MakeValue(spv::OpLogicalNot, a.stype, a);
}

225
spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
226 227 228 229 230
  return builder_->Select(MakeValue(op->condition),
                          MakeValue(op->true_value),
                          MakeValue(op->false_value));
}

231
spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
232 233
  CHECK(!var_map_.count(op->var.get()));
  var_map_[op->var.get()] = MakeValue(op->value);
234
  analyzer_->Bind(op->var, op->value);
235 236 237
  return MakeValue(op->body);
}

238
spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
239 240
  if (op->is_intrinsic("spirv_glsl450")) {
    CHECK_GE(op->args.size(), 2U);
241 242
    uint32_t inst_id = static_cast<uint32_t>(
        op->args[0].as<IntImmNode>()->value);
243 244 245 246 247
    std::vector<spirv::Value> values;
    for (size_t i = 1; i < op->args.size(); ++i) {
      values.push_back(MakeValue(op->args[i]));
    }
    return builder_->CallGLSL450(
248
        builder_->GetSType(op->dtype), inst_id, values);
249
  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
250 251 252 253
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
254
  } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
255 256 257 258
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
259
  } else if (op->is_intrinsic(CallNode::bitwise_or)) {
260 261 262 263
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
264
  } else if (op->is_intrinsic(CallNode::bitwise_not)) {
265 266 267
    CHECK_EQ(op->args.size(), 1U);
    spirv::Value a = MakeValue(op->args[0]);
    return builder_->MakeValue(spv::OpNot, a.stype, a);
268
  } else if (op->is_intrinsic(CallNode::shift_left)) {
269 270 271 272
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
273
  } else if (op->is_intrinsic(CallNode::shift_right)) {
274 275 276
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
277
    if (op->args[0].dtype().is_int()) {
278 279 280 281
      return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b);
    } else {
      return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
    }
282
  } else if (op->is_intrinsic(CallNode::reinterpret)) {
283
    return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
284
                               MakeValue(op->args[0]));
285 286 287 288 289 290
  } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
    CHECK_EQ(op->args.size(), 2U);
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    return builder_->UIntImm(builder_->GetSType(op->dtype), val);
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321
  } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
    return this->CreateStorageSync(op);
  } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
    CHECK_EQ(op->args.size(), 3U);
    spirv::Value cond = MakeValue(op->args[0]);
    spirv::Label then_label = builder_->NewLabel();
    spirv::Label else_label = builder_->NewLabel();
    spirv::Label merge_label = builder_->NewLabel();
    builder_->MakeInst(
        spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(
        spv::OpBranchConditional, cond, then_label, else_label);
    // then block, must get label after we see the value
    builder_->StartLabel(then_label);
    spirv::Value then_value = MakeValue(op->args[1]);
    spirv::Label then_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // else block
    builder_->StartLabel(else_label);
    spirv::Value else_value = MakeValue(op->args[2]);
    spirv::Label else_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // merge block
    builder_->StartLabel(merge_label);
    spirv::PhiValue phi = builder_->MakePhi(then_value.stype, 2);
    phi.SetIncoming(0, then_value, then_value_label);
    phi.SetIncoming(1, else_value, else_value_label);
    return phi;
  } else if (op->is_intrinsic("popcount")) {
    return builder_->MakeValue(
        spv::OpBitCount,
322
        builder_->GetSType(op->dtype),
323 324
        MakeValue(op->args[0]));
  } else {
325 326
    if (op->call_type == CallNode::Intrinsic ||
        op->call_type == CallNode::PureIntrinsic) {
327
      LOG(FATAL) << "Unresolved intrinsic " << op->name
328
                 << " with return type " << op->dtype;
329 330
    } else if (op->call_type == CallNode::Extern ||
               op->call_type == CallNode::PureExtern) {
331
      LOG(FATAL) << "Unresolved extern " << op->name
332
                 << " with return type " << op->dtype;
333 334 335 336 337 338 339
    } else {
      LOG(FATAL) << "Unresolved call type " << op->call_type;
    }
    return spirv::Value();
  }
}

340
spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
341 342 343 344 345 346
  std::vector<spirv::Value> values;
  spirv::Value base = MakeValue(op->base);
  for (int i = 0; i < op->lanes; ++i) {
    spirv::Value v = base;
    if (i != 0) {
      spirv::Value offset = MakeValue(
347
          make_const(op->stride.dtype(), i) * op->stride);
348 349 350 351 352 353 354
      v = builder_->Add(v, offset);
    }
    values.push_back(v);
  }
  return builder_->Concat(values);
}

355
spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
356 357 358 359 360 361 362 363
  std::vector<spirv::Value> values;
  spirv::Value v = MakeValue(op->value);
  for (int i = 0; i < op->lanes; i++) {
    values.push_back(v);
  }
  return builder_->Concat(values);
}

364
spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
365 366 367 368 369
  CHECK(is_one(op->predicate));
  auto it = storage_info_.find(op->buffer_var.get());
  CHECK(it != storage_info_.end());
  StorageInfo& info = it->second;
  if (!info.content_fixed) {
370
    info.UpdateContentType(op->dtype);
371 372 373 374 375 376 377 378 379 380 381
  }

  spirv::SType content_type = builder_->GetSType(info.content_type);
  spirv::Value buffer = MakeValue(op->buffer_var);
  spirv::SType ptr_type = builder_->GetPointerType(
      content_type, buffer.stype.storage_class);

  uint32_t mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    mask |= spv::MemoryAccessVolatileMask;
  }
382 383
  if (op->dtype.lanes() == 1) {
    CHECK_EQ(info.content_type, op->dtype)
384 385 386 387 388 389
        << "Vulkan only allow one type access to the same buffer";
    spirv::Value index = MakeValue(op->index);
    spirv::Value ptr = builder_->StructArrayAccess(
        ptr_type, buffer, index);
    return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
  } else {
390
    if (op->dtype.element_of() == info.content_type) {
391 392 393 394 395 396 397 398 399 400 401
      // because content type is element type, we can only do scalarize load.
      std::vector<spirv::Value> values;
      auto f = [&](int i, spirv::Value index) {
        spirv::Value ptr = builder_->StructArrayAccess(
            ptr_type, buffer, index);
        values.emplace_back(
            builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
      };
      this->Scalarize(op->index, f);
      return builder_->Concat(values);
    } else {
402
      if (const RampNode* ramp = op->index.as<RampNode>()) {
403
        if (is_one(ramp->stride)) {
404
          CHECK_EQ(ramp->lanes, op->dtype.lanes());
405 406 407
          arith::ModularSet me = analyzer_->modular_set(ramp->base);
          CHECK((me->coeff % ramp->lanes) == 0 &&
                (me->base % ramp->lanes)  == 0)
408
              << "Only aligned vector access is allowed in SPIRV";
409
          PrimExpr vec_index = tir::Simplify(
410
              ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
411 412 413 414 415 416 417 418 419 420 421 422
          spirv::Value ptr = builder_->StructArrayAccess(
              ptr_type, buffer, MakeValue(vec_index));
          return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
        }
      }
    }
    LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
  }
  LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
  return spirv::Value();
}

423
void CodeGenSPIRV::Scalarize(const PrimExpr& e,
424
                             std::function<void(int i, spirv::Value v)> f) {
425
  if (const RampNode* ramp = e.as<RampNode>()) {
426
    for (int i = 0; i < ramp->dtype.lanes(); ++i) {
427
      PrimExpr offset = ramp->base + ramp->stride * i;
428 429 430
      f(i, MakeValue(offset));
    }
  } else {
431
    spirv::SType etype = builder_->GetSType(e.dtype().element_of());
432
    spirv::Value value = MakeValue(e);
433
    for (int i = 0; i < e.dtype().lanes(); ++i) {
434 435 436 437 438 439
      f(i, builder_->MakeValue(
          spv::OpCompositeExtract, etype, value, i));
    }
  }
}

440
void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
441 442 443 444 445 446
  CHECK(is_one(op->predicate));
  auto it = storage_info_.find(op->buffer_var.get());
  CHECK(it != storage_info_.end());
  StorageInfo& info = it->second;

  if (!info.content_fixed) {
447
    info.UpdateContentType(op->value.dtype());
448 449 450 451 452 453 454 455 456 457 458 459 460
  }

  spirv::SType content_type = builder_->GetSType(info.content_type);
  spirv::Value buffer = MakeValue(op->buffer_var);
  spirv::Value value = MakeValue(op->value);
  spirv::SType ptr_type = builder_->GetPointerType(
      content_type, buffer.stype.storage_class);

  uint32_t mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    mask |= spv::MemoryAccessVolatileMask;
  }

461 462
  if (op->value.dtype().lanes() == 1) {
    CHECK_EQ(info.content_type, op->value.dtype())
463 464 465 466 467 468
        << "Vulkan only allow one type access to the same buffer";
    spirv::Value index = MakeValue(op->index);
    spirv::Value ptr = builder_->StructArrayAccess(
        ptr_type, buffer, index);
    builder_->MakeInst(spv::OpStore, ptr, value, mask);
  } else {
469
    if (op->value.dtype().element_of() == info.content_type) {
470 471 472 473 474 475 476 477 478 479
      // because content type is element type, we can only do scalarize load.
      auto f = [&](int i, spirv::Value index) {
        spirv::Value elem = builder_->MakeValue(
            spv::OpCompositeExtract, content_type, value, i);
        spirv::Value ptr = builder_->StructArrayAccess(
            ptr_type, buffer, index);
        builder_->MakeInst(spv::OpStore, ptr, elem, mask);
      };
      this->Scalarize(op->index, f);
    } else {
480
      if (const RampNode* ramp = op->index.as<RampNode>()) {
481
        if (is_one(ramp->stride)) {
482
          CHECK_EQ(ramp->lanes, op->value.dtype().lanes());
483 484 485
          arith::ModularSet me = analyzer_->modular_set(ramp->base);
          CHECK((me->coeff % ramp->lanes) == 0 &&
                (me->base % ramp->lanes)  == 0)
486
              << "Only aligned vector access is allowed in SPIRV";
487
          PrimExpr vec_index = tir::Simplify(
488
              ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
489 490 491 492 493 494 495 496 497 498 499
          spirv::Value ptr = builder_->StructArrayAccess(
              ptr_type, buffer, MakeValue(vec_index));
          builder_->MakeInst(spv::OpStore, ptr, value, mask);
          return;
        }
      }
      LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
    }
  }
}

500
void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
501
  CHECK(is_zero(op->min));
502
  analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535
  spirv::Value init_value = MakeValue(op->min);
  spirv::Value extent_value = MakeValue(op->extent);
  // Must get init label after making value(to make sure they are correct)
  spirv::Label init_label = builder_->CurrentLabel();
  spirv::Label head_label = builder_->NewLabel();
  spirv::Label body_label = builder_->NewLabel();
  spirv::Label continue_label = builder_->NewLabel();
  spirv::Label merge_label = builder_->NewLabel();
  builder_->MakeInst(spv::OpBranch, head_label);

  // Loop head
  builder_->StartLabel(head_label);
  spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
  loop_var.SetIncoming(0, init_value, init_label);
  spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
  uint32_t control = (
      op->for_type == ForType::Unrolled ?
      spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
  builder_->MakeInst(
      spv::OpLoopMerge, merge_label, continue_label, control);
  builder_->MakeInst(
      spv::OpBranchConditional, loop_cond, body_label, merge_label,
      weight_likely_branch_, 1);

  // loop body
  builder_->StartLabel(body_label);
  var_map_[op->loop_var.get()] = spirv::Value(loop_var);
  this->VisitStmt(op->body);
  builder_->MakeInst(spv::OpBranch, continue_label);

  // loop continue
  builder_->StartLabel(continue_label);
  spirv::Value one =
536
      op->loop_var.dtype().is_int() ?
537 538 539 540 541 542 543 544 545
      builder_->IntImm(loop_var.stype, 1) :
      builder_->UIntImm(loop_var.stype, 1);
  spirv::Value next_value = builder_->Add(loop_var, one);
  loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
  builder_->MakeInst(spv::OpBranch, head_label);
  // loop merge
  builder_->StartLabel(merge_label);
}

546
void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578
  spirv::Value cond = MakeValue(op->condition);
  spirv::Label then_label = builder_->NewLabel();
  spirv::Label merge_label = builder_->NewLabel();
  if (op->else_case.defined()) {
    spirv::Label else_label = builder_->NewLabel();
    builder_->MakeInst(
        spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(
        spv::OpBranchConditional, cond, then_label, else_label);
    // then block
    builder_->StartLabel(then_label);
    this->VisitStmt(op->then_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
    // else block
    builder_->StartLabel(else_label);
    this->VisitStmt(op->else_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
  } else {
    builder_->MakeInst(
        spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(
        spv::OpBranchConditional, cond, then_label, merge_label,
        weight_likely_branch_, 1);
    // then block
    builder_->StartLabel(then_label);
    this->VisitStmt(op->then_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
  }
  // start merge label;
  builder_->StartLabel(merge_label);
}

579
void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
580 581
  CHECK(!is_zero(op->condition));
  CHECK(!op->new_expr.defined());
582
  CHECK(!op->dtype.is_handle());
583 584 585 586 587
  int32_t constant_size = op->constant_allocation_size();
  CHECK_GT(constant_size, 0)
      << "Can only handle constant size stack allocation in GPU";
  spirv::Value buf;
  StorageInfo& info = storage_info_[op->buffer_var.get()];
588
  spirv::SType etype = builder_->GetSType(op->dtype);
589
  if (info.scope.rank == runtime::StorageRank::kLocal) {
590 591 592 593 594
    buf = builder_->Allocate(
        etype, static_cast<uint32_t>(constant_size),
        spv::StorageClassFunction);
  } else {
    // shared memory
595
    CHECK(info.scope.rank == runtime::StorageRank::kShared)
596 597 598 599 600 601 602
        << "Can only allocate shared or local memory inside kernel";
    // Shared memory
    buf = builder_->Allocate(
        etype, static_cast<uint32_t>(constant_size),
        spv::StorageClassWorkgroup);
  }
  CHECK(!info.content_fixed);
603
  info.UpdateContentType(op->dtype);
604 605 606 607 608
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}

609
void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
610
  if (op->attr_key == attr::thread_extent) {
611
    IterVar iv = Downcast<IterVar>(op->node);
612 613 614
    if (iv->thread_tag.length() != 0) {
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
615
        analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
616 617
      }
    }
618
  } else if (op->attr_key == tir::attr::storage_scope) {
619
    const VarNode* v = op->node.as<VarNode>();
620 621
    CHECK(v);
    storage_info_[v].scope =
622
        runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
623
  } else if (op->attr_key == tir::attr::volatile_scope) {
624
    const VarNode* v = op->node.as<VarNode>();
625 626 627 628 629 630
    CHECK(v);
    storage_info_[v].is_volatile = true;
  }
  this->VisitStmt(op->body);
}

631
void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
632
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
633
  this->VisitStmt(op->body);
634 635
}

636
void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
637
  CHECK(!var_map_.count(op->var.get()));
638
  CHECK(!op->var.dtype().is_handle());
639
  var_map_[op->var.get()] = MakeValue(op->value);
640
  analyzer_->Bind(op->var, op->value);
641 642 643
  this->VisitStmt(op->body);
}

644 645 646
void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {
  for (Stmt stmt : op->seq) {
    this->VisitStmt(stmt);
647 648 649
  }
}

650
void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) {
651 652 653
  MakeValue(op->value);
}

654
void CodeGenSPIRV::VisitStmt_(const ProducerConsumerNode* op) {
655 656 657 658 659
  this->VisitStmt(op->body);
}

}  // namespace codegen
}  // namespace tvm