Commit cc5a3cf0 by Yida Wang Committed by Tianqi Chen

[RELAY][PASS]use attribute registration style in the mac count pass (#2645)

parent aac5837f
...@@ -16,50 +16,30 @@ ...@@ -16,50 +16,30 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace { namespace mac_count {
bool IsConv2DNode(const ExprNode* node) { inline int64_t GetCartesianProd(Array<IndexExpr> arr) {
const auto* call_node = dynamic_cast<const CallNode*>(node); int64_t ret = 1;
return call_node != nullptr && call_node->attrs.as<Conv2DAttrs>(); for (size_t i = 0; i < arr.size(); i++) {
} const auto* intImm = arr[i].as<IntImm>();
ret *= static_cast<int64_t>(intImm->value);
bool IsDenseNode(const ExprNode* node) {
const auto* call_node = dynamic_cast<const CallNode*>(node);
return call_node != nullptr && call_node->attrs.as<DenseAttrs>();
}
} // namespace
class MacCounter : private ExprVisitor {
public:
MacCounter() {
count_ = 0;
}
static int64_t GetTotalMacNumber(const Expr& expr) {
LOG(INFO) << "This pass only counts MACs in direct CONV 2D and Dense ops";
MacCounter counter;
counter(expr);
return counter.count_;
}
private:
void VisitExpr_(const CallNode* call_node) final {
if (IsConv2DNode(call_node)) {
count_ += ComputeConv2DMacs(call_node);
} else if (IsDenseNode(call_node)) {
count_ += ComputeDenseMacs(call_node);
}
ExprVisitor::VisitExpr_(call_node);
} }
return ret;
}
/* /*
* \brief Get the number of MACs of a CONV 2D node. * \brief Preparation function for MAC count.
* \param call_node The CONV 2D call node. * \param call_node The call node.
* \return The number of MACs. * \return The number of MACs.
*/ */
int64_t ComputeConv2DMacs(const CallNode* call_node) { using FMacCount = runtime::TypedPackedFunc<
CHECK(IsConv2DNode(call_node)) int64_t(const Call& call_node)>;
<< "The input call node must be a CONV 2D node.";
//----------------------------------------------
// Per operator defs for MAC count
//----------------------------------------------
int64_t ConvMacCount(const Call& call_node) {
if (!call_node->checked_type_.defined()) { if (!call_node->checked_type_.defined()) {
LOG(WARNING) << "The infer type pass should be called before the mac count pass"; LOG(WARNING) << "The infer type pass should be called before the mac count pass";
return 0; return 0;
...@@ -87,16 +67,9 @@ class MacCounter : private ExprVisitor { ...@@ -87,16 +67,9 @@ class MacCounter : private ExprVisitor {
<< "The dimension of the output tensor in Conv 2D should be 4 or 5."; << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
return count; return count;
} }
/* int64_t DenseMacCount(const Call& call_node) {
* \brief Get the number of MACs of a Dense node.
* \param call_node The Dense call node.
* \return The number of MACs.
*/
int64_t ComputeDenseMacs(const CallNode* call_node) {
CHECK(IsDenseNode(call_node))
<< "The input call node must be a Dense node.";
if (!call_node->checked_type_.defined()) { if (!call_node->checked_type_.defined()) {
LOG(WARNING) << "The infer type pass should be called before the mac count pass"; LOG(WARNING) << "The infer type pass should be called before the mac count pass";
return 0; return 0;
...@@ -118,15 +91,33 @@ class MacCounter : private ExprVisitor { ...@@ -118,15 +91,33 @@ class MacCounter : private ExprVisitor {
<< "The dimensions of input arguments do not match."; << "The dimensions of input arguments do not match.";
int64_t count = d1 * d2 * d3; int64_t count = d1 * d2 * d3;
return count; return count;
} }
int64_t GetCartesianProd(Array<IndexExpr> arr) { RELAY_REGISTER_OP("nn.conv2d")
int64_t ret = 1; .set_attr<FMacCount>("FMacCount", ConvMacCount);
for (size_t i = 0; i < arr.size(); i++) {
const auto* intImm = arr[i].as<IntImm>(); RELAY_REGISTER_OP("nn.dense")
ret *= static_cast<int64_t>(intImm->value); .set_attr<FMacCount>("FMacCount", DenseMacCount);
class MacCounter : private ExprVisitor {
public:
MacCounter() {
count_ = 0;
} }
return ret; static int64_t GetTotalMacNumber(const Expr& expr) {
LOG(INFO) << "This pass only counts MACs in direct CONV 2D and Dense ops";
MacCounter counter;
counter(expr);
return counter.count_;
}
private:
void VisitExpr_(const CallNode* call_node) final {
static const auto& fprep =
Op::GetAttr<FMacCount>("FMacCount");
auto f = fprep.get(call_node->op, nullptr);
if (f != nullptr) count_ += f(GetRef<Call>(call_node));
ExprVisitor::VisitExpr_(call_node);
} }
int64_t count_; int64_t count_;
...@@ -141,5 +132,6 @@ TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber") ...@@ -141,5 +132,6 @@ TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
*ret = GetTotalMacNumber(args[0]); *ret = GetTotalMacNumber(args[0]);
}); });
} // namespace mac_count
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
"""Unit tests for MAC counter.""" """Unit tests for MAC counter."""
import tvm import tvm
from tvm import relay from tvm import relay
import sys
def test_gemm(): def test_gemm():
n = 512 n = 512
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment