/*!
 *  Copyright (c) 2018 by Contributors
 * \file tvm/relay/interpreter.h
 * \brief An interpreter for Relay.
 *
 * This file implements a simple reference interpreter for Relay programs.
 * Given a Relay module and a Relay expression, it produces a value.
 *
 * The interpreter's values are a naive representation of the values that
 * can be produced by a Relay program and are exposed via tvm::Node's
 * system to Python for introspection and debugging.
 *
 * The interpreter's intent is to serve as a reference semantics for the Relay IR,
 * as well as for debugging and testing.
 */
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_

#include <tvm/build_module.h>
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

/*!
 * \brief A Relay value.
 *
 * Forward declaration only: the CreateInterpreter signature that follows
 * names Value before the class itself is defined later in this header.
 */
class Value;

31 32 33
/*!
 *\brief Create a Interpreter function that can
 *  evaluate an expression and produce a value.
34 35 36 37 38 39 40 41 42 43
 *
 * The resulting value can be passed to Python, making it easy to use
 * for testing and debugging.
 *
 * The interpreter interprets the program fragments not supported by the
 * TVM runtime, although the interpreter is naively implemented it uses
 * TVM operators for evaluating all operators.
 *
 * Our intent is that this will never be the most efficient implementation of
 * Relay's semantics, but a readable and clear one.
44 45 46 47 48
 *
 * \param mod The function module.
 * \param context The primary context that the interepreter runs on.
 * \param target Compiler target flag to compile the functions on the context.
 * \return A function that takes in an expression and returns a value.
49
 */
50 51
runtime::TypedPackedFunc<Value(Expr)>
CreateInterpreter(Module mod, DLContext context, Target target);
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146

/*!
 * \brief The base container type of Relay values.
 *
 * ClosureNode, TupleValueNode and TensorValueNode in this header all derive
 * from this class.
 */
class ValueNode : public RelayNode {
 public:
  /*! \brief Type key string identifying this node type. */
  static constexpr const char* _type_key = "relay.Value";
  // NOTE(review): presumably generates the type-info members that read
  // _type_key; confirm against the macro's definition.
  TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};

/*!
 * \brief Reference handle to a ValueNode.
 *
 * Acts as the parent reference type for the concrete value references
 * defined in this header (Closure, TupleValue and TensorValue).
 */
class Value : public NodeRef {
 public:
  /*! \brief The node type this reference refers to. */
  using ContainerType = ValueNode;

  /*! \brief Default constructor; defers to NodeRef's default constructor. */
  Value() {}

  /*!
   * \brief Construct a reference from a node pointer.
   * \param n The node pointer handed on to the NodeRef base.
   */
  explicit Value(NodePtr<Node> n) : NodeRef(n) {}

  /*!
   * \brief Access the referenced node as a ValueNode.
   * \return The held node pointer, statically cast to const ValueNode*.
   */
  const ValueNode* operator->() const {
    const auto* held = node_.get();
    return static_cast<const ValueNode*>(held);
  }
};

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
 public:
  /*! \brief The set of free variables in the closure.
   *
   * These are the captured variables which are required for
   * evaluation when we call the closure.
   */
  tvm::Map<Var, Value> env;
  /*! \brief The function which implements the closure.
   *
   * \note May reference the variables contained in the env.
   */
  Function func;

  ClosureNode() {}

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("env", &env);
    v->Visit("func", &func);
  }

  TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);

  static constexpr const char* _type_key = "relay.Closure";
  TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
  tvm::Array<Value> fields;

  TupleValueNode() {}

  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }

  TVM_DLL static TupleValue make(tvm::Array<Value> value);

  static constexpr const char* _type_key = "relay.TupleValue";
  TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value);

/*! \brief A tensor value. */
class TensorValue;

/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
  runtime::NDArray data;

  TensorValueNode() {}

  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }

  /*! \brief Build a value from an NDArray. */
  TVM_DLL static TensorValue make(runtime::NDArray data);

  static constexpr const char* _type_key = "relay.TensorValue";
  TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);


}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_INTERPRETER_H_