extern_op.cc 5.8 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief External computation rule.
 * \file extern_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <unordered_set>
28
#include "op_util.h"
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57

namespace tvm {
using namespace ir;
// ExternOpNode
// Printer hook for ExternOpNode: writes "extern(", the op name, ", ",
// the node pointer and ")" to the printer's stream.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ExternOpNode>([](const ExternOpNode *node, IRPrinter *printer) {
    auto& os = printer->stream;
    os << "extern(" << node->name << ", ";
    os << node << ")";
  });

// Register ExternOpNode through the TVM_REGISTER_NODE_TYPE macro.
// NOTE(review): presumably this makes the type known to the node type
// registry -- confirm against the macro definition.
TVM_REGISTER_NODE_TYPE(ExternOpNode);

/*!
 * \brief Number of outputs of this operation.
 * \return The number of output placeholder buffers, converted to int.
 */
int ExternOpNode::num_outputs() const {
  const auto count = output_placeholders.size();
  return static_cast<int>(count);
}

/*!
 * \brief Root iteration variables of this operation.
 * \return Always an empty array.
 */
Array<IterVar> ExternOpNode::root_iter_vars() const {
  Array<IterVar> no_iter_vars;
  return no_iter_vars;
}

/*!
 * \brief Data type of the i-th output.
 * \param i Index of the output; must be smaller than the number of
 *          output placeholders.
 * \return The dtype stored on the i-th output placeholder buffer.
 */
Type ExternOpNode::output_dtype(size_t i) const {
  // Validate the index up front so an out-of-range i is reported as a
  // check failure rather than reaching the array indexing below.
  CHECK_LT(i, output_placeholders.size());
  return output_placeholders[i]->dtype;
}

/*!
 * \brief Shape of the i-th output.
 * \param i Index of the output; must be smaller than the number of
 *          output placeholders.
 * \return The shape stored on the i-th output placeholder buffer.
 */
Array<Expr> ExternOpNode::output_shape(size_t i) const {
  // Validate the index up front so an out-of-range i is reported as a
  // check failure rather than reaching the array indexing below.
  CHECK_LT(i, output_placeholders.size());
  return output_placeholders[i]->shape;
}


/*!
 * \brief Construct an external operation.
 *
 * Every input tensor must agree with its placeholder buffer in dtype, in
 * number of dimensions and in each shape expression (compared with
 * same_as), and input placeholders must have no strides. Any violation
 * trips a CHECK.
 *
 * \param name Name of the operation.
 * \param tag Tag of the operation.
 * \param attrs Attribute map; an undefined map is replaced by an empty one.
 * \param inputs Input tensors.
 * \param input_placeholders Placeholder buffers, one per input tensor.
 * \param output_placeholders Placeholder buffers for the outputs.
 * \param body Statement stored as the body of the operation.
 * \return Operation wrapping the newly created ExternOpNode.
 */
Operation ExternOpNode::make(std::string name,
                             std::string tag,
                             Map<std::string, NodeRef> attrs,
                             Array<Tensor> inputs,
                             Array<Buffer> input_placeholders,
                             Array<Buffer> output_placeholders,
                             Stmt body) {
  // Validate inputs against their placeholders before building the node.
  CHECK_EQ(inputs.size(), input_placeholders.size());
  const size_t num_inputs = inputs.size();
  for (size_t i = 0; i != num_inputs; ++i) {
    CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
    CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
    const size_t ndim = inputs[i]->shape.size();
    for (size_t dim = 0; dim != ndim; ++dim) {
        CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
    }
    CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
  }

  // All checks passed: populate a fresh node from the arguments.
  auto node = make_node<ExternOpNode>();
  node->name = std::move(name);
  node->tag = std::move(tag);
  if (attrs.defined()) {
    node->attrs = std::move(attrs);
  } else {
    node->attrs = Map<std::string, NodeRef>();
  }
  node->inputs = std::move(inputs);
  node->input_placeholders = std::move(input_placeholders);
  node->output_placeholders = std::move(output_placeholders);
  node->body = std::move(body);
  return Operation(node);
}

/*!
 * \brief Tensors this operation takes as inputs.
 * \return The inputs array stored on this node.
 */
Array<Tensor> ExternOpNode::InputTensors() const {
  return this->inputs;
}

/*!
 * \brief Build a version of this operation with input tensors substituted.
 *
 * The body is rewritten with op::ReplaceTensor and every input that has
 * an entry in rmap is swapped for its replacement.
 *
 * \param self Operation handle that must refer to this node.
 * \param rmap Map from a tensor to the tensor that replaces it.
 * \return self when both body and inputs are unchanged (as judged by
 *         same_as), otherwise an Operation wrapping the modified copy.
 */
Operation ExternOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  // Work on a copy of this node; the original is only used for comparison.
  auto n = make_node<ExternOpNode>(*this);
  n->body = op::ReplaceTensor(this->body, rmap);
  for (size_t i = 0; i < n->inputs.size(); ++i) {
    // Single hash lookup per input instead of count() followed by at().
    auto it = rmap.find(n->inputs[i]);
    if (it != rmap.end()) {
      n->inputs.Set(i, it->second);
    }
  }

  // NOTE(review): the inputs comparison relies on Set() leaving the
  // original array untouched (presumably copy-on-write) -- confirm.
  if (body.same_as(n->body) &&
      inputs.same_as(n->inputs)) {
    return self;
  } else {
    return Operation(n);
  }
}

void ExternOpNode::PropBoundToInputs(
    const Operation& self,
114
    arith::Analyzer* analyzer,
115 116 117 118 119 120 121 122
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  for (Tensor t : this->inputs) {
    auto it = out_dom_map->find(t);
    if (it == out_dom_map->end()) continue;
    TensorDom& dom = it->second;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      dom.data[i].emplace_back(IntSet::range(
123
          Range::make_by_min_extent(
124 125 126 127 128 129 130 131 132 133 134 135
              make_const(t->shape[i].type(), 0), t->shape[i])));
    }
  }
}

/*!
 * \brief Bound gathering hook; a no-op for this operation.
 *
 * The body is empty: none of the arguments is read and nothing is
 * written to out_dom_map.
 * NOTE(review): presumably there is nothing to gather because
 * root_iter_vars() returns an empty array -- confirm.
 *
 * \param self Unused.
 * \param tensor_dom Unused.
 * \param out_dom_map Unused; left unmodified.
 */
void ExternOpNode::GatherBound(
    const Operation& self,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
}

/*!
 * \brief Wrap a statement in Realize nodes for every output tensor.
 *
 * Each output is realized over the range with min 0 and extent shape[i]
 * in every dimension, with a constant-true condition. Outputs are
 * processed in index order, so the last output ends up outermost.
 *
 * \param stage Stage whose op must be this node.
 * \param realize_map Unused here.
 * \param body Statement to be wrapped.
 * \return body nested inside one Realize per output.
 */
Stmt ExternOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  Stmt wrapped = body;
  const int output_count = num_outputs();
  for (int k = 0; k != output_count; ++k) {
    const Tensor out = stage->op.output(k);
    // Region covering the whole tensor along every dimension.
    Region full_region;
    for (const auto& extent : out->shape) {
      full_region.push_back(
          Range::make_by_min_extent(make_const(extent.type(), 0), extent));
    }
    wrapped = ir::Realize::make(
        out->op, out->value_index, out->dtype,
        full_region, const_true(), wrapped);
  }
  return wrapped;
}

/*!
 * \brief Build the statement that provides the outputs of this operation.
 *
 * The stored body is wrapped in an extern_scope attribute and then in one
 * buffer_bind_scope attribute per placeholder buffer. Each bind carries
 * the (buffer, tensor) pair and a tvm_tuple of (0, shape[k]) pairs for
 * every buffer dimension. Outputs are wrapped first and inputs second,
 * both from last to first, so the first input's bind is outermost when
 * there are inputs.
 *
 * \param stage Stage whose op must be this node.
 * \param dom_map Unused here.
 * \param debug_keep_trivial_loop Unused here.
 * \return The body nested inside the attribute statements.
 */
Stmt ExternOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);
  Stmt result = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);

  // Wraps the statement built so far in a bind of `buffer` to `tensor`.
  auto wrap_with_bind = [&result](const Buffer& buffer, const Tensor& tensor) {
    Array<NodeRef> bind_spec;
    bind_spec.push_back(buffer);
    bind_spec.push_back(tensor);
    // Flattened (min, extent) pairs, one pair per buffer dimension.
    Array<Expr> min_extent_pairs;
    for (const auto& extent : buffer->shape) {
      min_extent_pairs.push_back(make_const(extent.type(), 0));
      min_extent_pairs.push_back(extent);
    }
    result = AttrStmt::make(
        bind_spec, attr::buffer_bind_scope,
        Call::make(Handle(), intrinsic::tvm_tuple, min_extent_pairs, Call::Intrinsic),
        result);
  };

  // Walk both lists back to front; each call adds a new outermost layer.
  for (size_t i = output_placeholders.size(); i-- > 0;) {
    wrap_with_bind(output_placeholders[i], stage->op.output(i));
  }
  for (size_t i = inputs.size(); i-- > 0;) {
    wrap_with_bind(input_placeholders[i], inputs[i]);
  }
  return result;
}
}  // namespace tvm