Commit 0153649e by tqchen

Checkin reduce

parent 0068781d
/*!
* Copyright (c) 2016 by Contributors
* \file ir_node.h
* \brief Additional high level nodes in the IR
*/
#ifndef TVM_IR_NODE_H_
#define TVM_IR_NODE_H_
#include <ir/Expr.h>
#include <ir/IR.h>
#include <type_traits>
#include <string>
#include "./base.h"
#include "./domain.h"
namespace tvm {
namespace ir {
using Halide::Internal::ExprNode;
using Halide::Internal::IRNodeType;
/*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> {
/*!
* \brief The binary operator of reduction
*/
std::string op;
/*! \brief The source operand */
Expr source;
/*! \brief The reduction domain */
RDomain rdom;
/*! \brief construct expr from name and rdom */
static Expr make(std::string name, Expr src, RDomain rdom);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
v->Visit("op", &op);
v->Visit("source", &source);
v->Visit("rdom", &rdom);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
static constexpr const char* Add = "Add";
static constexpr const char* Max = "Max";
static constexpr const char* Min = "Min";
};
} // namespace ir
} // namespace tvm
#endif // TVM_IR_NODE_H_
......@@ -4,15 +4,45 @@
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir_node.h>
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace Halide {
namespace Internal {
template<>
void ExprNode<tvm::ir::Reduce>::accept(IRVisitor *v) const {
LOG(FATAL) << "Reduce do not work with IRVisitor yet";
}
} // namespace Internal
} // namespace Halide
namespace tvm {
namespace ir {
// reduce
TVM_REGISTER_NODE_TYPE(Reduce);
Expr make(std::string op, Expr source, RDomain rdom) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
n->type = source.type();
n->source = source;
n->op = op;
n->rdom = rdom;
return Expr(n);
}
// HalideIR node
using namespace Halide::Internal;
TVM_REGISTER_NODE_TYPE(FloatImm);
......@@ -55,5 +85,5 @@ TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
} // namespace ir
} // namespace tvm
......@@ -48,6 +48,10 @@ Tensor TensorNode::make(Array<Expr> shape,
Array<Var> dim_var,
Expr source) {
auto n = std::make_shared<TensorNode>();
if (source.defined()) {
CHECK_EQ(source.type(), dtype);
CHECK_EQ(dim_var.size(), shape.size());
}
n->shape = shape;
n->name = name;
n->dtype = dtype;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment