alpha_equal.cc 12.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
/*!
 *  Copyright (c) 2018 by Contributors
 * \file src/tvm/relay/ir/alpha_equal.cc
 * \brief Alpha equality check by deep comparing two nodes.
 */
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h>
#include <cstring>
#include "type_functor.h"
#include "../../lang/attr_functor.h"

namespace tvm {
namespace relay {

// Alpha equal handler for relay.
/*!
 * \brief Alpha equal handler for relay.
 *
 *  Structurally compares types, expressions and attributes. Binding sites
 *  (function parameters, let-bound variables, type parameters) are paired
 *  positionally and recorded in equal_map_, so terms that differ only in the
 *  identity of their bound variables compare equal.
 *
 *  When map_free_var is true, leaf nodes that have not been paired yet (free
 *  variables, incomplete types, type variables, tvm Variables) may be paired
 *  with a node of the same type_index on first encounter; later encounters
 *  must then agree with that pairing.
 */
class AlphaEqualHandler:
      public AttrsEqualHandler,
      public TypeFunctor<bool(const Type&, const Type&)>,
      public ExprFunctor<bool(const Expr&, const Expr&)> {
 public:
  /*!
   * \param map_free_var Whether unpaired leaf nodes may be paired up with
   *  nodes of the same type_index on first encounter.
   */
  explicit AlphaEqualHandler(bool map_free_var)
      : map_free_var_(map_free_var) {}

  /*!
   * Check equality of two nodes.
   *  Dispatches to TypeEqual / ExprEqual based on the dynamic node kind and
   *  falls back to attribute equality for everything else.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the compare result.
   */
  bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
    if (lhs.same_as(rhs)) return true;
    if (!lhs.defined() || !rhs.defined()) return false;
    if (lhs->derived_from<TypeNode>()) {
      if (!rhs->derived_from<TypeNode>()) return false;
      return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
    }
    if (lhs->derived_from<ExprNode>()) {
      if (!rhs->derived_from<ExprNode>()) return false;
      return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
    }
    return AttrEqual(lhs, rhs);
  }

  /*!
   * Check equality of two attributes.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the compare result.
   */
  bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
    return AttrsEqualHandler::Equal(lhs, rhs);
  }
  /*!
   * Check equality of two types.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the compare result.
   */
  bool TypeEqual(const Type& lhs, const Type& rhs) {
    if (lhs.same_as(rhs)) return true;
    if (!lhs.defined() || !rhs.defined()) return false;
    return this->VisitType(lhs, rhs);
  }
  /*!
   * Check equality of two expressions.
   *
   * \note We run graph structural equality checking when comparing two Exprs.
   *   This means that AlphaEqualHandler can only be used once for each pair.
   *   The equality checker checks data-flow equivalence of the Expr DAG.
   *   This function also runs faster as it memoizes equal_map.
   *
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the compare result.
   */
  bool ExprEqual(const Expr& lhs, const Expr& rhs) {
    if (lhs.same_as(rhs)) return true;
    if (!lhs.defined() || !rhs.defined()) return false;
    auto it = equal_map_.find(lhs);
    if (it != equal_map_.end()) {
      // lhs was already paired (bound var or memoized sub-graph):
      // rhs must be exactly its partner.
      return it->second.same_as(rhs);
    }
    if (this->VisitExpr(lhs, rhs)) {
      equal_map_[lhs] = rhs;
      return true;
    } else {
      return false;
    }
  }

 protected:
  /*!
   * \brief Check if data type equals each other.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the compare result.
   */
  bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
    return lhs == rhs;
  }
  /*!
   * \brief Check Equality of leaf node of the graph.
   *  if map_free_var_ is set to true, try to map via equal node.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the compare result.
   */
  bool LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) {
    if (lhs.same_as(rhs)) return true;
    auto it = equal_map_.find(lhs);
    if (it != equal_map_.end()) {
      return it->second.same_as(rhs);
    } else {
      if (map_free_var_) {
        // First sighting of an unpaired leaf: pair it if the kinds agree.
        if (lhs->type_index() != rhs->type_index()) return false;
        equal_map_[lhs] = rhs;
        return true;
      } else {
        return false;
      }
    }
  }
  using AttrsEqualHandler::VisitAttr_;
  bool VisitAttr_(const Variable* lhs, const NodeRef& other) final {
    return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
  }

  // Type equality
  bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
    if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
      return (lhs->dtype == rhs->dtype &&
              AttrEqual(lhs->shape, rhs->shape));
    } else {
      return false;
    }
  }

  bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
    return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
  }

  bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
    if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
      if (lhs->kind != rhs->kind) return false;
      return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
    } else {
      return false;
    }
  }

  bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
    if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
      if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
      if (lhs->type_params.size() != rhs->type_params.size()) return false;
      if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
          return false;
        }
        // set up type parameter equal
        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
        if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) {
          // map variable
          equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
        }
      }
      for (size_t i = 0; i < lhs->arg_types.size(); i++) {
        if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
      }
      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
      for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
        if (!TypeEqual(lhs->type_constraints[i],
                       rhs->type_constraints[i])) {
          return false;
        }
      }
      return true;
    } else {
      return false;
    }
  }

  bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
    if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
      if (lhs->func->name != rhs->func->name) return false;
      if (lhs->num_inputs != rhs->num_inputs) return false;
      if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
      if (lhs->args.size() != rhs->args.size()) return false;
      for (size_t i = 0; i < lhs->args.size(); ++i) {
        if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
      }
      return true;
    } else {
      return false;
    }
  }

  bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
    if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
      if (lhs->fields.size() != rhs->fields.size()) return false;
      for (size_t i = 0; i < lhs->fields.size(); ++i) {
        if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
      }
      return true;
    } else {
      return false;
    }
  }
  // Expr equal checking.
  /*!
   * \brief Compare the contents of two CPU NDArrays.
   *  Two arrays are equal when dtype, rank, every extent and the raw data
   *  bytes all agree. The shape is checked before memcmp so tensors that
   *  differ only in shape are not reported equal, and memcmp never reads
   *  past the end of a smaller rhs buffer.
   * \param lhs The left hand operand.
   * \param rhs The right hand operand.
   * \return the compare result.
   */
  bool NDArrayEqual(const runtime::NDArray& lhs,
                    const runtime::NDArray& rhs) {
    if (lhs.defined() != rhs.defined()) {
      return false;
    } else if (lhs.same_as(rhs)) {
      return true;
    } else {
      auto ldt = lhs->dtype;
      auto rdt = rhs->dtype;
      CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
      CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
      if (ldt.code != rdt.code || ldt.lanes != rdt.lanes || ldt.bits != rdt.bits) {
        return false;
      }
      if (lhs->ndim != rhs->ndim) return false;
      for (int i = 0; i < lhs->ndim; ++i) {
        if (lhs->shape[i] != rhs->shape[i]) return false;
      }
      // Same dtype and shape => same byte size, so this memcmp stays in bounds.
      size_t data_size = runtime::GetDataSize(*lhs.operator->());
      return std::memcmp(lhs->data, rhs->data, data_size) == 0;
    }
  }
  // merge declaration of two variables together.
  bool MergeVarDecl(const Var& lhs, const Var& rhs) {
    if (lhs.same_as(rhs)) return true;
    if (!lhs.defined() || !rhs.defined()) return false;
    if (!TypeEqual(lhs->type_annotation,
                   rhs->type_annotation)) return false;
    CHECK(!equal_map_.count(lhs))
        << "Duplicated declaration of variable " <<  lhs;
    equal_map_[lhs] = rhs;
    return true;
  }

  bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
    if (const VarNode* rhs = other.as<VarNode>()) {
      if (lhs->name_hint != rhs->name_hint) return false;
      if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
      return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
    if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
      // use name equality for global var for now.
      if (lhs->name_hint != rhs->name_hint) return false;
      return true;
    } else {
      return false;
    }
  }

  bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
    if (const TupleNode* rhs = other.as<TupleNode>()) {
      if (lhs->fields.size() != rhs->fields.size()) return false;
      for (size_t i = 0; i < lhs->fields.size(); ++i) {
        if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
      }
      return true;
    } else {
      return false;
    }
  }

  bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
    if (const FunctionNode* rhs = other.as<FunctionNode>()) {
      if (lhs->params.size() != rhs->params.size()) return false;
      if (lhs->type_params.size() != rhs->type_params.size()) return false;
      // map type parameter to be the same
      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
        // A shape type parameter also carries a variable used inside shape
        // expressions; pair it too, exactly as the FuncTypeNode case does.
        if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) {
          equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
        }
      }
      // check parameter type annotations
      for (size_t i = 0; i < lhs->params.size(); ++i) {
        if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
      }
      // check return types.
      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
      return ExprEqual(lhs->body, rhs->body);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
    if (const CallNode* rhs = other.as<CallNode>()) {
      if (!ExprEqual(lhs->op, rhs->op)) return false;
      if (lhs->args.size() != rhs->args.size()) return false;
      if (lhs->type_args.size() != rhs->type_args.size()) return false;

      for (size_t i = 0; i < lhs->args.size(); ++i) {
        if (!ExprEqual(lhs->args[i], rhs->args[i])) return false;
      }
      for (size_t i = 0; i < lhs->type_args.size(); ++i) {
        if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
      }
      return AttrEqual(lhs->attrs, rhs->attrs);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
    if (const LetNode* rhs = other.as<LetNode>()) {
      // Bind the variables before comparing the values so that a value which
      // refers to its own binder (a recursive binding) sees the pairing
      // instead of failing on an unmapped variable.
      if (!MergeVarDecl(lhs->var, rhs->var)) return false;
      if (!ExprEqual(lhs->value, rhs->value)) return false;
      return ExprEqual(lhs->body, rhs->body);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
    if (const IfNode* rhs = other.as<IfNode>()) {
      return ExprEqual(lhs->cond, rhs->cond) &&
          ExprEqual(lhs->true_branch, rhs->true_branch) &&
          ExprEqual(lhs->false_branch, rhs->false_branch);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const OpNode* op, const Expr& other) final {
    return op == other.get();
  }

  bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
    if (const ConstantNode* rhs = other.as<ConstantNode>()) {
      return NDArrayEqual(lhs->data, rhs->data);
    } else {
      return false;
    }
  }

  bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
    if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
      return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
    } else {
      return false;
    }
  }

 private:
  // whether to map open terms.
  bool map_free_var_{false};
  // renaming of NodeRef to indicate two nodes equals to each other
  std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
};

bool AlphaEqual(const Type& lhs, const Type& rhs) {
  return AlphaEqualHandler(false).TypeEqual(lhs, rhs);
}

bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
  return AlphaEqualHandler(false).ExprEqual(lhs, rhs);
}

// TODO(@jroesch): move to correct namespace?
// Packed-function entry points for the Python frontend. Each call builds a
// fresh handler because a handler's pairing table is only valid for one
// comparison.

// Generic alpha equality over nodes; unbound variables must be identical.
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    AlphaEqualHandler checker(/*map_free_var=*/false);
    *rv = checker.Equal(args[0], args[1]);
  });

// Alpha equality restricted to types.
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    AlphaEqualHandler checker(/*map_free_var=*/false);
    *rv = checker.TypeEqual(args[0], args[1]);
  });

// Graph equality: unbound leaves may be paired up on first encounter.
TVM_REGISTER_API("relay._make._graph_equal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    AlphaEqualHandler checker(/*map_free_var=*/true);
    *rv = checker.Equal(args[0], args[1]);
  });
}  // namespace relay
}  // namespace tvm