/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tensor.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor_intrin.h>
#include <memory>

namespace tvm {
namespace te {

/*!
 * \brief Build an IterVar of kind kThreadIndex over the given range.
 * \param dom The range handed to IterVarNode::make.
 * \param tag Used twice: to construct the iteration Var and as the
 *            trailing tag argument of IterVarNode::make.
 * \return The IterVar produced by IterVarNode::make.
 */
IterVar thread_axis(Range dom, std::string tag) {
  return IterVarNode::make(
      dom, Var(tag), kThreadIndex, tag);
}

/*!
 * \brief Build an IterVar of kind kCommReduce over the given range.
 * \param dom The range handed to IterVarNode::make.
 * \param name Used to construct the iteration Var.
 * \return The IterVar produced by IterVarNode::make.
 */
IterVar reduce_axis(Range dom, std::string name) {
  return IterVarNode::make(
      dom, Var(name), kCommReduce);
}

/*!
 * \brief Convenience wrapper that constructs a Var.
 * \param name_hint Forwarded as the first constructor argument of Var.
 * \param t Forwarded as the data type of the Var.
 * \return The newly constructed Var.
 */
Var var(std::string name_hint, DataType t) {
  return Var(name_hint, t);
}

// Tensor

/*!
 * \brief Read the tensor at the given variable indices.
 *
 * Converts the Var array to an Array<PrimExpr> and delegates to the
 * PrimExpr overload.
 *
 * \param indices One index variable per accessed dimension.
 * \return The expression built by the PrimExpr overload.
 */
PrimExpr Tensor::operator()(Array<Var> indices) const {
  Array<PrimExpr> arr(indices.begin(), indices.end());
  return operator()(arr);
}

/*!
 * \brief Read the tensor at the given indices.
 *
 * When ndim() is non-zero the number of indices must equal ndim();
 * otherwise the CHECK below fails. No count check is performed when
 * ndim() is zero.
 *
 * \param indices The index expressions.
 * \return A CallNode of call type CallNode::Halide that references this
 *         tensor's op and value_index and carries its dtype.
 */
PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
  using tir::CallNode;
  if (ndim() != 0) {
    // The trailing space after "read" keeps the message from rendering as
    // "...in readndim = ...".
    CHECK_EQ(ndim(), indices.size())
        << "Tensor dimension mismatch in read "
        << "ndim = " << ndim() << ", indices.size=" << indices.size();
  }
  auto n = CallNode::make(
      (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
      (*this)->op, (*this)->value_index);
  return n;
}

/*!
 * \brief Create a Tensor that refers to the i-th output of this operation.
 *
 * The tensor's dtype and shape are taken from output_dtype(i) and
 * output_shape(i) of the underlying operation node.
 *
 * \param i The output index; stored as the tensor's value_index.
 * \return The newly created Tensor.
 */
Tensor Operation::output(size_t i) const {
  auto node = make_object<TensorNode>();
  node->op = *this;
  node->value_index = i;
  node->dtype = (*this)->output_dtype(i);
  node->shape = (*this)->output_shape(i);
  return Tensor(node);
}

/*!
 * \brief Construct a Tensor directly from its fields.
 * \param shape Stored as the tensor's shape.
 * \param dtype Stored as the tensor's dtype.
 * \param op Stored as the tensor's source operation.
 * \param value_index Stored as the tensor's value_index.
 * \return The newly created Tensor.
 */
Tensor TensorNode::make(Array<PrimExpr> shape,
                        DataType dtype,
                        Operation op,
                        int value_index) {
  auto n = make_object<TensorNode>();
  n->shape = std::move(shape);
  n->dtype = dtype;
  // `op` is a by-value parameter; move it like TensorIntrinNode::make does.
  n->op = std::move(op);
  n->value_index = value_index;
  return Tensor(n);
}

// Printer for TensorNode: emits "Tensor(shape=<shape>, op.name=<name>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* t = static_cast<const TensorNode*>(node.get());
    p->stream << "Tensor(shape=" << t->shape
              << ", op.name=" << t->op->name << ')';
  });

TVM_REGISTER_NODE_TYPE(TensorNode);


// TensorIntrin

/*!
 * \brief Construct a TensorIntrin from its fields.
 *
 * Every argument is moved into the identically named field of a fresh
 * TensorIntrinNode.
 *
 * \return The newly created TensorIntrin.
 */
TensorIntrin TensorIntrinNode::make(std::string name,
                                    Operation op,
                                    Array<Tensor> inputs,
                                    Array<Buffer> buffers,
                                    Array<Var> scalar_params,
                                    Stmt body,
                                    Stmt reduce_init,
                                    Stmt reduce_update) {
  auto n = make_object<TensorIntrinNode>();
  n->name = std::move(name);
  n->op = std::move(op);
  n->inputs = std::move(inputs);
  n->buffers = std::move(buffers);
  n->scalar_params = std::move(scalar_params);
  n->body = std::move(body);
  n->reduce_init = std::move(reduce_init);
  n->reduce_update = std::move(reduce_update);
  return TensorIntrin(n);
}

// Printer for TensorIntrinNode: emits "TensorIntrin(name=<name>, <op>)".
// Here `op` is the raw node pointer, so the second field streams the pointer
// itself. NOTE(review): presumably meant as an identity marker -- confirm.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const TensorIntrinNode*>(node.get());
    p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
  });

TVM_REGISTER_NODE_TYPE(TensorIntrinNode);


// TensorIntrinCall

/*!
 * \brief Construct a TensorIntrinCall from its fields.
 *
 * Every argument is moved into the identically named field of a fresh
 * TensorIntrinCallNode.
 *
 * \return The newly created TensorIntrinCall.
 */
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
                                            Array<Tensor> tensors,
                                            Array<Region> regions,
                                            Array<IterVar> reduce_axis,
                                            Array<PrimExpr> scalar_inputs) {
  auto n = make_object<TensorIntrinCallNode>();
  n->intrin = std::move(intrin);
  n->tensors = std::move(tensors);
  n->regions = std::move(regions);
  n->reduce_axis = std::move(reduce_axis);
  n->scalar_inputs = std::move(scalar_inputs);
  return TensorIntrinCall(n);
}

// Printer for TensorIntrinCallNode: emits
// "TensorIntrinCall(intrin=<intrin>, <n>)", where `n` is the raw node pointer.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
    p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
  });

TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);

// Global function registrations exposing the constructors and helpers above
// under the "te.*" names.
TVM_REGISTER_GLOBAL("te.Tensor")
.set_body_typed(TensorNode::make);

TVM_REGISTER_GLOBAL("te.TensorIntrin")
.set_body_typed(TensorIntrinNode::make);

TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
.set_body_typed(TensorIntrinCallNode::make);

TVM_REGISTER_GLOBAL("te.TensorEqual")
.set_body_method(&Tensor::operator==);

// Returns std::hash<Tensor> of the argument, cast to int64_t.
TVM_REGISTER_GLOBAL("te.TensorHash")
.set_body_typed([](Tensor tensor) -> int64_t {
    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
  });

// NOTE(review): a negative `output` wraps to a very large size_t in the cast
// below; presumably rejected by the operation's own bounds handling -- confirm.
TVM_REGISTER_GLOBAL("te.OpGetOutput")
.set_body_typed([](Operation op, int64_t output) {
    return op.output(static_cast<size_t>(output));
  });

TVM_REGISTER_GLOBAL("te.OpNumOutputs")
.set_body_method<Operation>(&OperationNode::num_outputs);

TVM_REGISTER_GLOBAL("te.OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);

}  // namespace te
}  // namespace tvm