/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file quantize.cc
 *
 * \brief Transform a graph to a low-bit graph
 *   for compression and acceleration.
 */
#include <dmlc/thread_local.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <stack>
#include "./quantize.h"

namespace tvm {
namespace relay {
namespace quantize {

TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);

// Type relation for simulated_quantize: dom_scale, clip_min and clip_max must
// be float32 scalars, and the output has the same type as the input data.
bool SimulatedQuantizeRel(const Array<Type>& types,
                          int num_inputs,
                          const Attrs& attrs,
                          const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 5);
  const auto param = attrs.as<SimulatedQuantizeAttrs>();
  CHECK(param != nullptr);

  const auto* data = types[0].as<TensorTypeNode>();
  CHECK(data != nullptr);
  CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";

  reporter->Assign(types[1], TensorTypeNode::make({}, Float(32)));  // dom_scale
  reporter->Assign(types[2], TensorTypeNode::make({}, Float(32)));  // clip_min
  reporter->Assign(types[3], TensorTypeNode::make({}, Float(32)));  // clip_max
  reporter->Assign(types[4], types[0]);                             // output
  return true;
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE)
.set_num_inputs(4)
.add_argument("data", "Tensor", "The input data.")
.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar")
.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
.set_attrs_type<SimulatedQuantizeAttrs>()
.set_support_level(11)
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);

// FFI entry: builds a simulated_quantize call node with the given attributes.
TVM_REGISTER_API("relay._quantize.simulated_quantize")
.set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string)>(
  [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max,
     int kind, bool sign, std::string rounding) {
    auto attrs = make_node<SimulatedQuantizeAttrs>();
    attrs->kind = kind;
    attrs->sign = sign;
    attrs->rounding = rounding;
    static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
    return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
  });
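// Usage sketch (comment only, not compiled as part of this file): the packed
// function registered above can be looked up through the global registry and
// invoked from C++ to build a simulated_quantize call node. Here `data`,
// `dom_scale`, `clip_min` and `clip_max` stand for Exprs the caller has
// already constructed (the three scale/clip arguments are float32 scalars per
// SimulatedQuantizeRel); the kind/sign/rounding values are illustrative.
//
//   #include <tvm/runtime/registry.h>
//
//   const tvm::runtime::PackedFunc* sq =
//       tvm::runtime::Registry::Get("relay._quantize.simulated_quantize");
//   CHECK(sq != nullptr);
//   Expr call = (*sq)(data, dom_scale, clip_min, clip_max,
//                     /*kind=*/1, /*sign=*/true, /*rounding=*/"round");
//
// The simulated computation itself is registered elsewhere (on the Python
// side) and is expected to be roughly
// round(clip(data / dom_scale, clip_min, clip_max)) * dom_scale, which is why
// this type relation only fixes the scalar types and mirrors the input type
// onto the output.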
/*! \brief Entry to hold the QConfig context stack. */
struct TVMQConfigThreadLocalEntry {
  /*! \brief The default qconfig if the stack is empty */
  QConfig default_config;

  /*! \brief The current qconfig context stack */
  std::stack<QConfig> context_stack;

  TVMQConfigThreadLocalEntry() :
    default_config(make_node<QConfigNode>()) {
  }
};

/*! \brief Thread local store to hold the QConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMQConfigThreadLocalEntry> TVMQConfigThreadLocalStore;

void QConfig::EnterQConfigScope(const QConfig& build_config) {
  TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.push(build_config);
}

void QConfig::ExitQConfigScope() {
  TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.pop();
}

QConfig& QConfig::Current() {
  TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
  if (entry->context_stack.size() > 0) {
    return entry->context_stack.top();
  }
  return entry->default_config;
}

TVM_REGISTER_NODE_TYPE(QConfigNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<QConfigNode>([](const QConfigNode *op, IRPrinter *p) {
  p->stream << "qconfig(";
  p->stream << "nbit_input=" << op->nbit_input << ", ";
  p->stream << "nbit_weight=" << op->nbit_weight << ", ";
  p->stream << "nbit_activation=" << op->nbit_activation << ", ";
  p->stream << "global_scale=" << op->global_scale << ", ";
  p->stream << "skip_conv_layers=" << op->skip_conv_layers << ", ";
  p->stream << "do_simulation=" << op->do_simulation << ", ";
  p->stream << "round_for_shift=" << op->round_for_shift << ", ";
  p->stream << "debug_enabled_ops=" << op->debug_enabled_ops << ", ";
  p->stream << "rounding=" << op->rounding;
  p->stream << ")";
});

TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
.set_body_typed(QConfig::Current);

TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
.set_body_typed(QConfig::EnterQConfigScope);

TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope);

}  // namespace quantize
}  // namespace relay
}  // namespace tvm
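// Usage sketch (comment only): the three registrations above expose the scope
// protocol that a Python-side context manager (assumed here to be the
// qconfig() helper in relay.quantize) is expected to drive. The equivalent
// C++ sequence, using only functions defined in this file:
//
//   QConfig cfg(make_node<QConfigNode>());
//   QConfig::EnterQConfigScope(cfg);  // push: QConfig::Current() now returns cfg
//   // ... annotate / calibrate / realize with the scoped config ...
//   QConfig::ExitQConfigScope();      // pop: Current() falls back to default_config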