/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/tir/stmt.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>


namespace tvm {
namespace tir {

Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
  CHECK(value.defined());
  CHECK(body.defined());
  CHECK_EQ(value.dtype(), var.dtype());

  ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
  node->var = std::move(var);
  node->value = std::move(value);
  node->body = std::move(body);
  return Stmt(node);
}

TVM_REGISTER_GLOBAL("tir.LetStmt")
.set_body_typed(LetStmtNode::make);

Stmt AttrStmtNode::make(ObjectRef node,
                    std::string attr_key,
                    PrimExpr value,
                    Stmt body) {
  auto n = make_object<AttrStmtNode>();
  n->node = node;
  n->attr_key = std::move(attr_key);
  n->value = std::move(value);
  n->body = std::move(body);
  return Stmt(n);
}

TVM_REGISTER_GLOBAL("tir.AttrStmt")
.set_body_typed(AttrStmtNode::make);

Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
  CHECK(condition.defined());
  CHECK(message.dtype() == DataType::Int(32) ||
        message.as<StringImmNode>())
      << "TypeError: AssertStmt message must be an int or string:"
      << message << "\n";

  ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>();
  node->condition = std::move(condition);
  node->message = std::move(message);
  node->body = std::move(body);
  return Stmt(node);
}

TVM_REGISTER_GLOBAL("tir.AssertStmt")
.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
  if (const auto* str = message.as<StringObj>()) {
    auto msg = StringImmNode::make(str->data);
    return AssertStmtNode::make(condition, msg, body);
  } else {
    return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
  }
});

Stmt ForNode::make(Var loop_var,
                   PrimExpr min,
                   PrimExpr extent,
                   ForType for_type,
                   DeviceAPI device_api,
                   Stmt body) {
  CHECK(min.defined());
  CHECK(extent.defined());
  CHECK(min.dtype().is_scalar());
  CHECK(extent.dtype().is_scalar());
  CHECK(loop_var.dtype().is_scalar());
  CHECK(body.defined());

  ObjectPtr<ForNode> node = make_object<ForNode>();
  node->loop_var = std::move(loop_var);
  node->min = std::move(min);
  node->extent = std::move(extent);
  node->for_type = for_type;
  node->device_api = device_api;
  node->body = std::move(body);
  return Stmt(node);
}

TVM_REGISTER_GLOBAL("tir.For")
.set_body_typed([](
  Var loop_var, PrimExpr min, PrimExpr extent,
  int for_type, int device_api, Stmt body) {
  return ForNode::make(loop_var,
                       min,
                       extent,
                       static_cast<ForType>(for_type),
                       static_cast<DeviceAPI>(device_api),
                       body);
});


Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
  CHECK(value.defined());
  CHECK(index.defined());
  CHECK(predicate.defined());
  CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
  CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());

  ObjectPtr<StoreNode> node = make_object<StoreNode>();
  node->buffer_var = std::move(buffer_var);
  node->value = std::move(value);
  node->index = std::move(index);
  node->predicate = std::move(predicate);
  return Stmt(node);
}


TVM_REGISTER_GLOBAL("tir.Store")
.set_body([](TVMArgs args,  TVMRetValue *ret) {
    PrimExpr value = args[1];
    if (args.size() == 3) {
      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
    } else {
      *ret = StoreNode::make(args[0], value, args[2], args[3]);
    }
  });


Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
  CHECK(value_index >=0 && value_index < func->num_outputs())
      << "value index output function return value bound";
  CHECK(value.defined()) << "Provide of undefined value\n";

  for (size_t i = 0; i < args.size(); ++i) {
    CHECK(args[i].defined()) << "Provide to undefined location\n";
  }

  ObjectPtr<ProvideNode> node = make_object<ProvideNode>();
  node->func = std::move(func);
  node->value_index = value_index;
  node->value = std::move(value);
  node->args = std::move(args);
  return Stmt(node);
}

TVM_REGISTER_GLOBAL("tir.Provide")
.set_body_typed(ProvideNode::make);


Stmt AllocateNode::make(Var buffer_var,
                        DataType dtype,
                        Array<PrimExpr> extents,
                        PrimExpr condition,
                        Stmt body) {
    for (size_t i = 0; i < extents.size(); ++i) {
      CHECK(extents[i].defined());
      CHECK(extents[i].dtype().is_scalar());
    }
    CHECK(body.defined());
    CHECK(condition.defined());
    CHECK(condition.dtype().is_bool());

    ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
    node->buffer_var = std::move(buffer_var);
    node->dtype = dtype;
    node->extents = std::move(extents);
    node->condition = std::move(condition);
    node->body = std::move(body);
    return Stmt(node);
}

// overloaded, needs special handling
// has default args
TVM_REGISTER_GLOBAL("tir.Allocate")
.set_body_typed([](
    Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
                   ){
  return AllocateNode::make(buffer_var, type, extents, condition, body);
});

int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
  int64_t result = 1;
  for (size_t i = 0; i < extents.size(); ++i) {
    if (const IntImmNode *int_size = extents[i].as<IntImmNode>()) {
      result *= int_size->value;
      if (result > std::numeric_limits<int32_t>::max()) {
        return 0;
      }
    } else {
      return 0;
    }
  }
  return static_cast<int32_t>(result);
}

Stmt FreeNode::make(Var buffer_var) {
  ObjectPtr<FreeNode> node = make_object<FreeNode>();
  node->buffer_var = buffer_var;
  return Stmt(node);
}

TVM_REGISTER_GLOBAL("tir.Free")
.set_body_typed(FreeNode::make);


Stmt RealizeNode::make(FunctionRef func,
                       int value_index,
                       DataType dtype,
                       Region bounds,
                       PrimExpr condition,
                       Stmt body) {
  for (size_t i = 0; i < bounds.size(); ++i) {
    CHECK(bounds[i]->min.defined());
    CHECK(bounds[i]->extent.defined());
    CHECK(bounds[i]->min.dtype().is_scalar());
    CHECK(bounds[i]->extent.dtype().is_scalar());
  }
  CHECK(body.defined());
  CHECK(condition.defined());
  CHECK(condition.dtype().is_bool());

  ObjectPtr<RealizeNode> node = make_object<RealizeNode>();
  node->func = std::move(func);
  node->value_index = value_index;
  node->dtype = dtype;
  node->bounds = std::move(bounds);
  node->condition = std::move(condition);
  node->body = std::move(body);
  return Stmt(node);
}


TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make);


Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
  data_ = make_object<PrefetchNode>(buffer, bounds);
}

TVM_REGISTER_GLOBAL("tir.Prefetch")
.set_body_typed([](Buffer buffer, Array<Range> bounds) {
  return Prefetch(buffer, bounds);
});


SeqStmt::SeqStmt(Array<Stmt> seq) {
  auto node = make_object<SeqStmtNode>();
  node->seq = std::move(seq);
  data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
  return SeqStmt(std::move(seq));
});

Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
  CHECK(condition.defined());
  CHECK(then_case.defined());
  // else_case may be null.

  ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
  node->condition = std::move(condition);
  node->then_case = std::move(then_case);
  node->else_case = std::move(else_case);
  return Stmt(node);
}

TVM_REGISTER_GLOBAL("tir.IfThenElse")
.set_body_typed(IfThenElseNode::make);


Stmt EvaluateNode::make(PrimExpr value) {
  CHECK(value.defined());

  ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
  node->value = std::move(value);
  return Stmt(node);
}

TVM_REGISTER_GLOBAL("tir.Evaluate")
.set_body_typed(EvaluateNode::make);

BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
  ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
  node->buffer = std::move(buffer);
  node->value = std::move(value);
  node->indices = std::move(indices);
  data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.BufferStore")
.set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
  return BufferStore(buffer, value, indices);
});

TVM_REGISTER_NODE_TYPE(BufferStoreNode);


BufferRealize::BufferRealize(Buffer buffer,
                             Array<Range> bounds,
                             PrimExpr condition,
                             Stmt body) {
  data_ = make_object<BufferRealizeNode>(
      buffer, bounds, condition, body);
}

TVM_REGISTER_GLOBAL("tir.BufferRealize")
.set_body_typed([](Buffer buffer,
                   Array<Range> bounds,
                   PrimExpr condition,
                   Stmt body) {
  return BufferRealize(buffer, bounds, condition, body);
});

TVM_REGISTER_NODE_TYPE(BufferRealizeNode);

// Printers

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const LetStmtNode*>(node.get());
    p->PrintIndent();
    p->stream << "let " << op->var << " = ";
    p->Print(op->value);
    p->stream << '\n';
    p->Print(op->body);
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const AttrStmtNode*>(node.get());
    p->PrintIndent();
    p->stream << "// attr [";
    p->Print(op->node);
    p->stream << "] "
              << op->attr_key << " = ";
    p->Print(op->value);
    p->stream << '\n';
    p->Print(op->body);
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const AssertStmtNode*>(node.get());
    p->PrintIndent();
    p->stream << "assert(";
    p->Print(op->condition);
    p->stream << ", ";
    p->Print(op->message);
    p->stream << ")\n";
    p->Print(op->body);
  });

std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
  switch (type) {
    case ForType::Serial:
      out << "for";
      break;
    case ForType::Parallel:
      out << "parallel";
      break;
    case ForType::Unrolled:
      out << "unrolled";
      break;
    case ForType::Vectorized:
      out << "vectorized";
      break;
  }
  return out;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const ForNode*>(node.get());
    p->PrintIndent();
    p->stream << op->for_type << " (" << op->loop_var << ", ";
    p->Print(op->min);
    p->stream << ", ";
    p->Print(op->extent);
    p->stream << ") {\n";

    p->indent += 2;
    p->Print(op->body);
    p->indent -= 2;

    p->PrintIndent();
    p->stream << "}\n";
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const StoreNode*>(node.get());
    p->PrintIndent();
    p->stream << op->buffer_var << "[";
    p->Print(op->index);
    p->stream << "] = ";
    p->Print(op->value);
    if (!is_one(op->predicate)) {
      p->stream << " if ";
      p->Print(op->predicate);
    }
    p->stream << '\n';
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProvideNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const ProvideNode*>(node.get());
    p->PrintIndent();
    p->stream << op->func->func_name() << "(";
    for (size_t i = 0; i < op->args.size(); ++i) {
      p->Print(op->args[i]);
      if (i < op->args.size() - 1) p->stream << ", ";
    }
    p->stream << ")";
    if (op->func->num_outputs() != 1) {
      p->stream << ".value[" << op->value_index << "]";
    }
    p->stream << " =";
    p->Print(op->value);
    p->stream << '\n';
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const BufferStoreNode*>(node.get());
    p->PrintIndent();
    p->stream << op->buffer->name << "[";
    for (size_t i = 0; i < op->indices.size(); ++i) {
      p->Print(op->indices[i]);
      if (i < op->indices.size() - 1) p->stream << ", ";
    }
    p->stream << "]";
    p->stream << " = ";
    p->Print(op->value);
    p->stream << '\n';
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const AllocateNode*>(node.get());
    p->PrintIndent();
    p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
    for (size_t i = 0; i < op->extents.size(); ++i) {
      p->stream << " * ";
      p->Print(op->extents[i]);
    }
    p->stream << "]";
    if (!is_one(op->condition)) {
      p->stream << " if ";
      p->Print(op->condition);
    }
    p->stream << "\n";
    p->Print(op->body);
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const FreeNode*>(node.get());
    p->PrintIndent();
    p->stream << "free " << op->buffer_var;
    p->stream << '\n';
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const BufferRealizeNode*>(node.get());
    p->PrintIndent();
    p->stream << "buffer_realize " << op->buffer->name << "(";
    for (size_t i = 0; i < op->bounds.size(); ++i) {
      p->stream << "[";
      p->Print(op->bounds[i]->min);
      p->stream << ", ";
      p->Print(op->bounds[i]->extent);
      p->stream << "]";
      if (i < op->bounds.size() - 1) p->stream << ", ";
    }
    p->stream << ")";
    if (!is_one(op->condition)) {
      p->stream << " if ";
      p->Print(op->condition);
    }
    p->stream << " {\n";

    p->indent += 2;
    p->Print(op->body);
    p->indent -= 2;

    p->PrintIndent();
    p->stream << "}\n";
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const RealizeNode*>(node.get());
    p->PrintIndent();
    p->stream << "realize " << op->func->func_name() << "(";
    for (size_t i = 0; i < op->bounds.size(); ++i) {
      p->stream << "[";
      p->Print(op->bounds[i]->min);
      p->stream << ", ";
      p->Print(op->bounds[i]->extent);
      p->stream << "]";
      if (i < op->bounds.size() - 1) p->stream << ", ";
    }
    p->stream << ")";
    if (op->func->num_outputs() != 1) {
      p->stream << ".value[" << op->value_index << "]";
    }
    if (!is_one(op->condition)) {
      p->stream << " if ";
      p->Print(op->condition);
    }
    p->stream << " {\n";

    p->indent += 2;
    p->Print(op->body);
    p->indent -= 2;

    p->PrintIndent();
    p->stream << "}\n";
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const PrefetchNode*>(node.get());
    p->PrintIndent();
    p->stream << "prefetch " << op->buffer << "(";
    for (size_t i = 0; i < op->bounds.size(); ++i) {
      p->stream << "[";
      p->Print(op->bounds[i]->min);
      p->stream << ", ";
      p->Print(op->bounds[i]->extent);
      p->stream << "]";
      if (i < op->bounds.size() - 1) p->stream << ", ";
    }
    p->stream << ")";
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const SeqStmtNode*>(node.get());
    for (Stmt stmt : op->seq) {
      p->Print(stmt);
    }
  });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const IfThenElseNode*>(node.get());
    p->PrintIndent();
    while (true) {
      p->stream << "if (" << op->condition << ") {\n";
      p->indent += 2;
      p->Print(op->then_case);
      p->indent -= 2;

      if (!op->else_case.defined()) {
        break;
      }

      if (const IfThenElseNode *nested_if = op->else_case.as<IfThenElseNode>()) {
        p->PrintIndent();
        p->stream << "} else ";
        op = nested_if;
      } else {
        p->PrintIndent();
        p->stream << "} else {\n";
        p->indent += 2;
        p->Print(op->else_case);
        p->indent -= 2;
        break;
      }
    }
    p->PrintIndent();
    p->stream << "}\n";
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const EvaluateNode*>(node.get());
    p->PrintIndent();
    p->Print(op->value);
    p->stream << "\n";
  });

template<typename T>
void PrintList(const Array<T> &exprs, ReprPrinter* p) {
  for (size_t i = 0; i < exprs.size(); ++i) {
    p->Print(exprs[i]);
    if (i < exprs.size() - 1) {
      p->stream << ", ";
    }
  }
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const ShuffleNode*>(node.get());
    p->stream << "shuffle(";
    PrintList(op->vectors, p);
    p->stream << ", ";
    PrintList(op->indices, p);
    p->stream << ")";
  });

TVM_REGISTER_NODE_TYPE(AttrStmtNode);
TVM_REGISTER_NODE_TYPE(PrefetchNode);
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_NODE_TYPE(LetStmtNode);
TVM_REGISTER_NODE_TYPE(AssertStmtNode);
TVM_REGISTER_NODE_TYPE(ForNode);
TVM_REGISTER_NODE_TYPE(StoreNode);
TVM_REGISTER_NODE_TYPE(ProvideNode);
TVM_REGISTER_NODE_TYPE(AllocateNode);
TVM_REGISTER_NODE_TYPE(FreeNode);
TVM_REGISTER_NODE_TYPE(RealizeNode);
TVM_REGISTER_NODE_TYPE(SeqStmtNode);
TVM_REGISTER_NODE_TYPE(IfThenElseNode);
TVM_REGISTER_NODE_TYPE(EvaluateNode);

}  // namespace tir
}  // namespace tvm