Unverified Commit 5da361d3 by Zhi Committed by GitHub

[REFACTOR][IR] Move to runtime::String (#5276)

* Use runtime::String

* move string to tvm namespace

* add const char* constructor

* implicit cast from std::string
parent 48082358
......@@ -107,11 +107,12 @@ class PrimExpr : public BaseExpr {
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
* \brief construct from runtime String.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(std::string str); // NOLINT(*)
TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
......
......@@ -57,6 +57,7 @@
#define TVM_IR_TRANSFORM_H_
#include <tvm/support/with.h>
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
......@@ -95,9 +96,9 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
Array<PrimExpr> required_pass;
Array<runtime::String> required_pass;
/*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass;
Array<runtime::String> disabled_pass;
TraceFunc trace_func;
......@@ -197,7 +198,7 @@ class PassInfoNode : public Object {
std::string name;
/*! \brief The passes that are required to perform the current pass. */
Array<PrimExpr> required;
Array<runtime::String> required;
PassInfoNode() = default;
......@@ -226,7 +227,7 @@ class PassInfo : public ObjectRef {
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
Array<PrimExpr> required);
Array<runtime::String> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
......@@ -346,7 +347,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const Array<PrimExpr>& required);
const Array<runtime::String>& required);
} // namespace transform
} // namespace tvm
......
......@@ -36,6 +36,8 @@
namespace tvm {
using runtime::String;
using runtime::StringObj;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
......
......@@ -35,6 +35,7 @@
#define TVM_NODE_NODE_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
......@@ -62,6 +63,7 @@ using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::String;
} // namespace tvm
#endif // TVM_NODE_NODE_H_
......@@ -24,6 +24,7 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
......@@ -59,7 +60,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const tvm::Array<runtime::String>& required);
/*! \brief Remove expressions which does not effect the program result.
*
......@@ -355,7 +356,7 @@ TVM_DLL Pass Inline();
*
* \return The pass.
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
} // namespace transform
......
......@@ -360,7 +360,15 @@ class String : public ObjectRef {
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
explicit String(std::string other);
String(std::string other); // NOLINT(*)
/*!
* \brief Construct a new String object
*
* \param other a char array.
*/
String(const char* other) // NOLINT(*)
: String(std::string(other)) {}
/*!
* \brief Change the value the reference object points to.
......
......@@ -52,11 +52,11 @@ class TargetNode : public Object {
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<PrimExpr> keys_array;
Array<runtime::String> keys_array;
/*! \brief Options for this target */
Array<PrimExpr> options_array;
Array<runtime::String> options_array;
/*! \brief Collection of imported libs */
Array<PrimExpr> libs_array;
Array<runtime::String> libs_array;
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
......
......@@ -326,7 +326,7 @@ class StmtExprMutator :
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* \param only_enable List of runtime::String.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
......@@ -334,7 +334,7 @@ class StmtExprMutator :
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<PrimExpr>& only_enable = {});
const Array<runtime::String>& only_enable = {});
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
......
......@@ -56,7 +56,7 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
const tvm::Array<runtime::String>& required);
/*!
* \brief Transform the high-level PrimFunc to a low-level version
......@@ -100,7 +100,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
*
* \return The pass.
*/
TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
/*!
......
......@@ -24,6 +24,7 @@ registers the standard task.
import numpy as np
from tvm import target as _target
from tvm import runtime
from tvm.ir import container
from tvm.tir import expr
from tvm.te import tensor, placeholder
......@@ -55,6 +56,8 @@ def serialize_args(args):
return x
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
if isinstance(x, runtime.container.String):
return str(x)
if x is None:
return None
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
......
......@@ -84,8 +84,7 @@ class GraphRuntimeCodegen(object):
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
for name in param_names:
key = name.value
for key in param_names:
arr = self._get_param_by_name(key)
param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
arr.copyto(param)
......
......@@ -16,8 +16,9 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
......@@ -75,18 +76,19 @@ class ADT(Object):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)
self.__init_handle_by_constructor__(_ffi_api.ADT, tag,
*fields)
@property
def tag(self):
return _GetADTTag(self)
return _ffi_api.GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
self, _GetADTFields, len(self), idx)
self, _ffi_api.GetADTFields, len(self), idx)
def __len__(self):
return _GetADTSize(self)
return _ffi_api.GetADTSize(self)
def tuple_object(fields=None):
......@@ -106,7 +108,7 @@ def tuple_object(fields=None):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)
return _ffi_api.Tuple(*fields)
@tvm._ffi.register_object("runtime.String")
......@@ -115,7 +117,7 @@ class String(Object):
Parameters
----------
string : Str
string : str
The string used to construct a runtime String object
Returns
......@@ -124,7 +126,50 @@ class String(Object):
The created object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_String, string)
self.__init_handle_by_constructor__(_ffi_api.String, string)
def __str__(self):
return _ffi_api.GetStdString(self)
def __len__(self):
return _ffi_api.GetStringSize(self)
def __hash__(self):
return _ffi_api.StringHash(self)
def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other
if not isinstance(other, String):
return False
return _ffi_api.CompareString(self, other) == 0
def __ne__(self, other):
return not self.__eq__(other)
def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0
def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0
def __getitem__(self, key):
return self.__str__()[key]
def startswith(self, string):
"""Check if the runtime string starts with a given string
Parameters
----------
string : str
The provided string
tvm._ffi._init_api("tvm.runtime.container")
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
......@@ -19,7 +19,7 @@
from numbers import Number, Integral
from tvm._ffi.base import string_types
from . import _ffi_node_api
from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
......@@ -56,7 +56,7 @@ def convert_to_object(value):
if isinstance(value, Number):
return const(value)
if isinstance(value, string_types):
return _ffi_node_api.String(value)
return _ffi_api.String(value)
if isinstance(value, (list, tuple)):
value = [convert_to_object(x) for x in value]
return _ffi_node_api.Array(*value)
......
......@@ -48,26 +48,26 @@ class Target(Object):
@property
def keys(self):
if not self._keys:
self._keys = [k.value for k in self.keys_array]
self._keys = [str(k) for k in self.keys_array]
return self._keys
@property
def options(self):
if not self._options:
self._options = [o.value for o in self.options_array]
self._options = [str(o) for o in self.options_array]
return self._options
@property
def libs(self):
if not self._libs:
self._libs = [l.value for l in self.libs_array]
self._libs = [str(l) for l in self.libs_array]
return self._libs
@property
def model(self):
for opt in self.options_array:
if opt.value.startswith('-model='):
return opt.value[7:]
if opt.startswith('-model='):
return opt[7:]
return 'unknown'
@property
......
......@@ -252,9 +252,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto var : vars) {
Array<Array<PrimExpr> > feature_row;
ItervarFeature &fea = touch_analyzer.itervar_map[var];
feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});
Array<PrimExpr> attr{std::string("_attr_"),
Array<PrimExpr> attr{tvm::tir::StringImmNode::make("_attr_"),
FloatImm(DataType::Float(32), trans(fea.length)),
IntImm(DataType::Int(32), fea.nest_level),
FloatImm(DataType::Float(32), trans(fea.topdown_product)),
......@@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
feature_row.push_back(attr);
// arithmetic
feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_arith_"),
FloatImm(DataType::Float(32), trans(fea.add_ct)),
FloatImm(DataType::Float(32), trans(fea.mul_ct)),
FloatImm(DataType::Float(32), trans(fea.div_ct)),
......@@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(
Array<PrimExpr>{k,
Array<PrimExpr>{tvm::tir::StringImmNode::make(k),
FloatImm(DataType::Float(32), trans(v.stride)),
FloatImm(DataType::Float(32), trans(v.mod)),
FloatImm(DataType::Float(32), trans(v.count)),
......
......@@ -42,7 +42,7 @@ void DictAttrsNode::InitByPackedArgs(
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kTVMStr) {
dict.Set(key, PrimExpr(val.operator std::string()));
dict.Set(key, val.operator String());
} else {
dict.Set(key, val.operator PrimExpr());
}
......
......@@ -40,8 +40,8 @@ PrimExpr::PrimExpr(int32_t value)
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
PrimExpr::PrimExpr(std::string str)
: PrimExpr(tir::StringImmNode::make(str)) {}
PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {}
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
using runtime::ObjectTypeChecker;
......@@ -51,6 +51,9 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
if (ptr->IsInstance<te::TensorNode>()) {
return te::Tensor(ptr)();
}
if (ptr->IsInstance<runtime::StringObj>()) {
return tir::StringImmNode::make(runtime::String(ptr));
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
......
......@@ -24,6 +24,7 @@
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
......@@ -140,10 +141,9 @@ void OpRegistry::UpdateAttr(const std::string& key,
// Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
.set_body_typed([]() {
Array<tvm::PrimExpr> ret;
for (const std::string& name :
dmlc::Registry<OpRegistry>::ListAllNames()) {
ret.push_back(tvm::PrimExpr(name));
Array<runtime::String> ret;
for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
ret.push_back(name);
}
return ret;
});
......
......@@ -23,6 +23,7 @@
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
#include <tvm/node/repr_printer.h>
#include <tvm/ir/transform.h>
......@@ -212,7 +213,7 @@ class SequentialNode : public PassNode {
PassInfo::PassInfo(int opt_level,
std::string name,
tvm::Array<tvm::PrimExpr> required) {
tvm::Array<runtime::String> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
......@@ -274,12 +275,10 @@ void SequentialNode::ResolveDependency(const IRModule& mod) {
}
// linearly scan the pass array to match pass_name
inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
inline bool PassArrayContains(const Array<runtime::String>& pass_array,
const std::string& pass_name) {
for (auto x : pass_array) {
auto* str_name = x.as<tir::StringImmNode>();
CHECK(str_name) << "pass name must be str";
if (str_name->value == pass_name) return true;
if (x == pass_name) return true;
}
return false;
}
......@@ -324,9 +323,7 @@ IRModule SequentialNode::operator()(const IRModule& module,
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::tir::StringImmNode>();
CHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
mod = GetPass(it)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
......@@ -337,7 +334,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
}
......@@ -345,7 +342,7 @@ Pass CreateModulePass(
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
.set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
return PassInfo(opt_level, name, required);
});
......@@ -363,8 +360,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "opt_level: " << node->opt_level;
p->stream << "required passes: [" << "\n";
for (const auto& it : node->required) {
const auto* str = it.as<tvm::tir::StringImmNode>();
p->stream << str->value << ", ";
p->stream << it << ", ";
}
p->stream << "]\n";
});
......@@ -401,7 +397,7 @@ TVM_REGISTER_GLOBAL("transform.Sequential")
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::PrimExpr> required = args[3];
tvm::Array<runtime::String> required = args[3];
PassInfo pass_info = PassInfo(opt_level, name, required);
*ret = Sequential(passes, pass_info);
});
......@@ -427,8 +423,8 @@ TVM_REGISTER_GLOBAL("transform.PassContext")
auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
tvm::Array<tvm::PrimExpr> required = args[2];
tvm::Array<tvm::PrimExpr> disabled = args[3];
tvm::Array<runtime::String> required = args[2];
tvm::Array<runtime::String> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
......
......@@ -63,7 +63,6 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
static_cast<const runtime::StringObj*>(n)).operator std::string();
});
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
......
......@@ -86,9 +86,10 @@ struct GraphCodegen {
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
for (auto expr : names) {
auto key = expr.as<tir::StringImmNode>()->value;
auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
for (const auto& expr : names) {
// Implicit cast from runtime::String to std::string
std::string key = expr;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
}
return ret;
......@@ -191,12 +192,12 @@ class RelayBuildModule : public runtime::ModuleNode {
/*!
* \brief List all paramter names
*
* \return Array<StringImm> names of params
* \return Array<runtime::String> names of params
*/
Array<tvm::PrimExpr> ListParamNames() {
Array<tvm::PrimExpr> ret;
Array<runtime::String> ListParamNames() {
Array<runtime::String> ret;
for (const auto& kv : params_) {
ret.push_back(tir::StringImmNode::make(kv.first));
ret.push_back(kv.first);
}
return ret;
}
......@@ -272,7 +273,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}
Array<Pass> pass_seqs;
Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
......
......@@ -617,17 +617,18 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = IRModule({}, {});
std::string code_gen_name = code_gen;
if (ext_mods.find(code_gen_name) == ext_mods.end()) {
ext_mods[code_gen_name] = IRModule({}, {});
}
auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(std::string(symbol_name));
ext_mods[code_gen->value]->Add(gv, src_func);
ext_mods[code_gen_name]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}
......@@ -691,10 +692,10 @@ class CompileEngineImpl : public CompileEngineNode {
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined())
<< "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node);
......
......@@ -70,7 +70,7 @@ class CSourceModuleCodegenBase {
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node =
func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node);
}
......
......@@ -419,7 +419,7 @@ class GraphRuntimeCodegen
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
......@@ -482,7 +482,7 @@ class GraphRuntimeCodegen
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
CHECK(op->GetAttr<String>(attr::kCompiler).defined())
<< "Only functions supported by custom codegen";
return {};
}
......@@ -633,10 +633,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
});
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Array<tvm::PrimExpr> ret;
Array<runtime::String> ret;
for (const auto &kv : this->output_.params) {
tvm::PrimExpr name = tir::StringImmNode::make(kv.first);
ret.push_back(name);
ret.push_back(kv.first);
}
*rv = ret;
});
......
......@@ -475,7 +475,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Target target;
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
......@@ -493,7 +493,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine_->Lower(key);
auto op_index = -1;
if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
......@@ -873,7 +873,7 @@ void VMCompiler::Lower(IRModule mod,
IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
......
......@@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
......
......@@ -190,7 +190,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
func = Function(func->params,
VisitExpr(func->body),
......
......@@ -87,11 +87,10 @@ struct CallTracer : ExprVisitor {
* \return The module with dead functions removed.
*/
IRModule RemoveUnusedFunctions(const IRModule& module,
Array<tvm::PrimExpr> entry_funcs) {
Array<runtime::String> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
auto* str_name = entry.as<tir::StringImmNode>();
auto funcs = CallTracer(module).Trace(str_name->value);
auto funcs = CallTracer(module).Trace(entry);
called_funcs.insert(funcs.cbegin(), funcs.cend());
}
auto existing_functions = module->functions;
......@@ -108,7 +107,7 @@ IRModule RemoveUnusedFunctions(const IRModule& module,
namespace transform {
Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
......
......@@ -145,14 +145,14 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
(func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
(func->GetAttr<String>(attr::kCompiler).defined());
}
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPass(pass_func, pass_info);
}
......
......@@ -1177,7 +1177,6 @@ Array<te::Tensor> ArangeCompute(const Attrs& attrs,
te::Tensor start = inputs[0];
te::Tensor stop = inputs[1];
te::Tensor step = inputs[2];
Array<tvm::PrimExpr> empty = {0};
return { DynamicArange(start, stop, step, param->dtype) };
}
......
......@@ -125,8 +125,7 @@ Pass AlterOpLayout() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
};
return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
......
......@@ -59,11 +59,12 @@ class AnnotateTargetWrapper : public ExprMutator {
// handle composite functions
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) {
size_t i = comp_name->value.find('.');
std::string comp_name_str = comp_name;
size_t i = comp_name_str.find('.');
if (i != std::string::npos) {
std::string target = comp_name->value.substr(0, i);
std::string target = comp_name_str.substr(0, i);
if (target == target_) return true;
}
}
......@@ -147,7 +148,7 @@ class AnnotateTargetWrapper : public ExprMutator {
Function func;
Expr new_body;
// don't step into composite functions
if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
if (fn->GetAttr<String>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
......@@ -225,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
{tir::StringImmNode::make("InferType")});
{"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}
......
......@@ -133,8 +133,7 @@ Pass CanonicalizeCast() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
......
......@@ -74,8 +74,7 @@ Pass CanonicalizeOps() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeOps(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
......
......@@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
......
......@@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
......
......@@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name,
batch_op_name,
min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
......
......@@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
return CreateFunctionPass(
pass_func, 3, "ConvertLayout",
{tir::StringImmNode::make("InferType"),
tir::StringImmNode::make("CanonicalizeOps")});
pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
}
TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
......
......@@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
};
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
......
......@@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
......
......@@ -70,8 +70,7 @@ Pass FastMath() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
return CreateFunctionPass(pass_func, 4, "FastMath",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.FastMath")
......
......@@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::ForwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
......@@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::BackwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
......
......@@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
return CreateFunctionPass(pass_func, 1, "FuseOps",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
......
......@@ -131,7 +131,7 @@ class Inliner : ExprMutator {
fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args.
......
......@@ -101,7 +101,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
......
......@@ -159,9 +159,9 @@ class MergeCompositeWrapper : public ExprMutator {
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
auto name_node = func->GetAttr<String>(attr::kComposite);
// don't step into existing composite functions
if (name_node.defined() && name_node->value != "") {
if (name_node.defined() && name_node != "") {
tvm::Array<tvm::relay::Expr> new_args;
for (const auto& arg : call->args) {
auto new_e = this->Mutate(arg);
......@@ -185,7 +185,7 @@ class MergeCompositeWrapper : public ExprMutator {
auto free_vars = FreeVars(extract);
// make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
......@@ -207,16 +207,14 @@ class MergeCompositeWrapper : public ExprMutator {
PackedFunc check_;
};
Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
Expr MergeComposite(const Expr& expr, const Array<runtime::String>& pattern_names,
const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
std::string pattern_name = pattern_names[i]->value;
Expr pattern = patterns[i];
PackedFunc check = checks[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
merged_expr =
MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr);
}
return merged_expr;
}
......@@ -225,7 +223,7 @@ Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names
namespace transform {
Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
......@@ -236,8 +234,9 @@ Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
return func_pass;
}
TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
tvm::Array<tir::StringImm> pattern_names = args[0];
TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
.set_body([](TVMArgs args, TVMRetValue* rv) {
tvm::Array<runtime::String> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
std::vector<PackedFunc> checks;
for (int i = 2; i < args.size(); i++) {
......
......@@ -245,7 +245,7 @@ class Partitioner : public ExprMutator {
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::tir::StringImmNode::make(target));
tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
......
......@@ -204,8 +204,7 @@ Pass SimplifyInference() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyInference(f));
};
return CreateFunctionPass(pass_func, 0, "SimplifyInference",
{tir::StringImmNode::make("InferType")});
return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
......
......@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
......
......@@ -32,14 +32,14 @@ namespace runtime {
using namespace vm;
TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
TVM_REGISTER_GLOBAL("runtime.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});
TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
TVM_REGISTER_GLOBAL("runtime.GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
......@@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
});
TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
TVM_REGISTER_GLOBAL("runtime.GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
......@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
*rv = adt[idx];
});
TVM_REGISTER_GLOBAL("runtime.container._Tuple")
TVM_REGISTER_GLOBAL("runtime.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
......@@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("runtime.container._Tuple")
*rv = ADT::Tuple(fields);
});
TVM_REGISTER_GLOBAL("runtime.container._ADT")
TVM_REGISTER_GLOBAL("runtime.ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
......@@ -76,11 +76,31 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
*rv = ADT(tag, fields);
});
TVM_REGISTER_GLOBAL("runtime.container._String")
TVM_REGISTER_GLOBAL("runtime.String")
.set_body_typed([](std::string str) {
return String(std::move(str));
});
TVM_REGISTER_GLOBAL("runtime.GetStringSize")
.set_body_typed([](String str) {
return static_cast<int64_t>(str.size());
});
TVM_REGISTER_GLOBAL("runtime.GetStdString")
.set_body_typed([](String str) {
return std::string(str);
});
TVM_REGISTER_GLOBAL("runtime.CompareString")
.set_body_typed([](String lhs, String rhs) {
return lhs.compare(rhs);
});
TVM_REGISTER_GLOBAL("runtime.StringHash")
.set_body_typed([](String str) {
return static_cast<int64_t>(std::hash<String>()(str));
});
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
......
......@@ -57,7 +57,7 @@ ExtractFuncInfo(const IRModule& mod) {
info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
}
}
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol)] = info;
}
return fmap;
......
......@@ -22,6 +22,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/node/node.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
......@@ -150,12 +151,12 @@ TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc")
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
Array<PrimExpr> tags = args[2];
Array<runtime::String> tags = args[2];
bool allow_override = args[3];
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value);
tags_vector.push_back(tag);
}
generic_func
......
......@@ -126,7 +126,7 @@ void CodeGenCPU::Init(const std::string& module_name,
void CodeGenCPU::AddFunction(const PrimFunc& f) {
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
......
......@@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
llvm::FunctionType* ftype = llvm::FunctionType::get(
ret_void ? t_void_ : t_int_, param_types, false);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
CHECK(module_->getFunction(static_cast<std::string>(global_symbol)) == nullptr)
......
......@@ -214,7 +214,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
entry_func = global_symbol;
}
......
......@@ -78,7 +78,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
......
......@@ -56,7 +56,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
GetUniqueName("_");
// add to alloc buffer type.
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
......
......@@ -156,7 +156,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) {
arg_kinds.push_back(kind);
}
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
......
......@@ -147,7 +147,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
std::string whole_code = cg.Finish();
// Generate source code for compilation.
Array<Array<PrimExpr> > kernel_info;
Array<Array<runtime::String> > kernel_info;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
......@@ -161,11 +161,10 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
code = (*f)(code).operator std::string();
}
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
std::string func_name = global_symbol;
kernel_info.push_back(Array<PrimExpr>({func_name, code}));
kernel_info.push_back({global_symbol, code});
}
std::string xclbin;
......
......@@ -90,7 +90,7 @@ runtime::Module BuildSPIRV(IRModule mod) {
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
......
......@@ -78,7 +78,7 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
......
......@@ -536,7 +536,7 @@ runtime::Module BuildStackVM(const IRModule& mod) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenStackVM: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
......
......@@ -62,39 +62,39 @@ Target CreateTarget(const std::string& target_name,
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(tir::StringImmNode::make(item));
t->options_array.push_back(item);
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
t->libs_array.push_back(tir::StringImmNode::make(lib_item));
t->libs_array.push_back(lib_item);
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
t->keys_array.push_back(t->device_name);
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(tir::StringImmNode::make(key_item));
t->keys_array.push_back(key_item);
}
}
}
if (t->device_name.length() > 0) {
t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
t->keys_array.push_back(t->device_name);
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") {
t->keys_array.push_back(tir::StringImmNode::make("cpu"));
t->keys_array.push_back("cpu");
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
t->keys_array.push_back(tir::StringImmNode::make("cuda"));
t->keys_array.push_back(tir::StringImmNode::make("gpu"));
t->keys_array.push_back("cuda");
t->keys_array.push_back("gpu");
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
......@@ -104,8 +104,8 @@ Target CreateTarget(const std::string& target_name,
} else {
t->device_type = kDLROCM;
}
t->keys_array.push_back(tir::StringImmNode::make(target_name));
t->keys_array.push_back(tir::StringImmNode::make("gpu"));
t->keys_array.push_back(target_name);
t->keys_array.push_back("gpu");
t->max_num_threads = 256;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
......@@ -116,20 +116,20 @@ Target CreateTarget(const std::string& target_name,
} else {
t->device_type = kDLVulkan;
}
t->keys_array.push_back(tir::StringImmNode::make(target_name));
t->keys_array.push_back(tir::StringImmNode::make("gpu"));
t->keys_array.push_back(target_name);
t->keys_array.push_back("gpu");
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
t->keys_array.push_back(tir::StringImmNode::make("sdaccel"));
t->keys_array.push_back(tir::StringImmNode::make("hls"));
t->keys_array.push_back("sdaccel");
t->keys_array.push_back("hls");
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL;
t->keys_array.push_back(tir::StringImmNode::make("aocl"));
t->keys_array.push_back(tir::StringImmNode::make("hls"));
t->keys_array.push_back("aocl");
t->keys_array.push_back("hls");
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
t->keys_array.push_back(tir::StringImmNode::make("opengl"));
t->keys_array.push_back("opengl");
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
......@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("target.TargetFromString")
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
result.push_back(expr.as<tir::StringImmNode>()->value);
result.push_back(expr);
}
return result;
}
......@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
result.push_back(expr.as<tir::StringImmNode>()->value);
result.push_back(expr);
}
return result;
}
......@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const {
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
result.insert(expr.as<tir::StringImmNode>()->value);
result.insert(expr);
}
return result;
}
......
......@@ -47,7 +47,6 @@ Var::Var(std::string name_hint, Type type_annotation) {
data_ = std::move(n);
}
Var Var::copy_with_suffix(const std::string& suffix) const {
const VarNode* node = get();
ObjectPtr<VarNode> new_ptr;
......@@ -826,20 +825,28 @@ TVM_REGISTER_GLOBAL("tir.Load")
}
});
TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](
DataType type, std::string name,
Array<PrimExpr> args, int call_type,
Array<ObjectRef> args, int call_type,
FunctionRef func, int value_index
) {
Array<PrimExpr> prim_expr_args;
for (const auto& it : args) {
CHECK(it->IsInstance<runtime::StringObj>() ||
it->IsInstance<PrimExprNode>());
if (const auto* str = it.as<runtime::StringObj>()) {
prim_expr_args.push_back(StringImmNode::make(str->data));
} else {
prim_expr_args.push_back(Downcast<PrimExpr>(it));
}
}
return CallNode::make(type,
name,
args,
static_cast<CallNode::CallType>(call_type),
func,
value_index);
name,
prim_expr_args,
static_cast<CallNode::CallType>(call_type),
func,
value_index);
});
} // namespace tir
......
......@@ -120,10 +120,10 @@ class IRTransformer final :
Stmt IRTransform(Stmt ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
const Array<PrimExpr>& only_enable) {
const Array<runtime::String>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
for (PrimExpr s : only_enable) {
only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
for (auto s : only_enable) {
only_type_index.insert(Object::TypeKey2Index(s.c_str()));
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
return transform(std::move(ir_node));
......
......@@ -124,7 +124,7 @@ Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return PrimFuncPass(pass_func, pass_info);
}
......
......@@ -42,7 +42,8 @@ void BinderAddAssert(PrimExpr cond,
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()),
EvaluateNode::make(0)));
}
}
......@@ -173,7 +174,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
// type checks
DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
......@@ -187,7 +189,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
if (!(dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1))) {
asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
}
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
......@@ -245,9 +249,10 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
stride_err_msg << arg_name << ".strides:"
<< " expected to be compact array";
if (conds.size() != 0) {
auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
Stmt check =
AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
stride_err_msg.str(), EvaluateNode::make(0));
stride_msg, EvaluateNode::make(0));
check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
......@@ -269,9 +274,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
} else {
std::ostringstream stride_null_err_msg;
stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
asserts_.emplace_back(
AssertStmtNode::make(
NotNode::make(is_null), stride_null_err_msg.str(), nop));
asserts_.emplace_back(AssertStmtNode::make(
NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop));
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
......
......@@ -159,8 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
}
});
return IRTransform(parent_for_stmt, nullptr, replace_target_for,
{PrimExpr("For")});
return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"});
}
// Remove IfThenElse node from a For node.
......@@ -186,11 +185,9 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
}
});
then_for = IRTransform(for_stmt, nullptr, replace_then_case,
{PrimExpr("IfThenElse")});
then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case,
{PrimExpr("IfThenElse")});
else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"});
}
return std::make_pair(then_for, else_for);
......@@ -411,7 +408,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
*ret = new_for;
}
});
return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")});
return IRTransform(stmt, nullptr, replace_top_for, {runtime::String("For")});
}
Stmt HoistIfThenElse(Stmt stmt) {
......
......@@ -860,7 +860,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
auto it = matrix_abc_.find(simplify_name(node->name));
CHECK(it != matrix_abc_.end())
<< "Cannot find matrix info for " << node->name;
auto matrix_abc = "wmma." + it->second;
auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second);
Stmt body = this->VisitStmt(op->body);
return AttrStmtNode::make(op->node,
op->attr_key,
......
......@@ -47,7 +47,8 @@ class DeviceTypeBinder: public StmtExprMutator {
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmtNode::make(op->value == value, os.str(), body);
return AssertStmtNode::make(op->value == value, tvm::tir::StringImmNode::make(os.str()),
body);
}
}
return StmtExprMutator::VisitStmt_(op);
......
......@@ -41,12 +41,13 @@ namespace tvm {
namespace tir {
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg),
EvaluateNode::make(0));
}
PrimFunc MakePackedAPI(PrimFunc&& func,
int num_unpacked_args) {
auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
std::string name_hint = global_symbol;
......@@ -140,17 +141,19 @@ PrimFunc MakePackedAPI(PrimFunc&& func,
AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle ||
tcode == kTVMNullptr, msg.str(), nop));
tcode == kTVMNullptr,
tvm::tir::StringImmNode::make(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
seq_check.emplace_back(
AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(
AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop));
}
} else {
args.push_back(v_arg);
......
......@@ -76,12 +76,10 @@ class ThreadAxisRewriter : private StmtExprMutator {
};
PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
const StringImmNode* str = kv.first.as<StringImmNode>();
CHECK(str != nullptr);
tmap[str->value] = kv.second;
tmap[kv.first] = kv.second;
}
auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
......@@ -101,7 +99,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
namespace transform {
Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) {
auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
return RemapThreadAxis(std::move(f), thread_map);
};
......
......@@ -272,7 +272,7 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "SplitHostDevice: Require the target attribute";
auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
......
......@@ -261,7 +261,7 @@ TEST(String, empty) {
using namespace std;
String s{"hello"};
CHECK_EQ(s.empty(), false);
s = "";
s = std::string("");
CHECK_EQ(s.empty(), true);
}
......
......@@ -231,7 +231,7 @@ def test_composite_function():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
add_relu = add_relu.with_attr("Composite", "test.add_relu")
# merged function
r = relay.Call(add_relu, [a, b])
......@@ -249,7 +249,7 @@ def test_composite_function():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
add_relu = add_relu.with_attr("Composite", "test.add_relu")
# merged function
cb_1 = relay.annotation.compiler_begin(a, "test")
......
......@@ -134,7 +134,7 @@ def test_recursive_func():
func = relay.Function([i],
sb.get(),
ret_type=relay.TensorType([], 'int32'))
func = func.with_attr("Compiler", tvm.tir.StringImm("a"))
func = func.with_attr("Compiler", "a")
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
......
......@@ -79,9 +79,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol):
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
func = func.with_attr("global_symbol",
runtime.container.String(ext_symbol))
func = func.with_attr("Compiler", compiler)
func = func.with_attr("global_symbol", ext_symbol)
return func
......
......@@ -96,12 +96,14 @@ def test_function():
body = relay.Tuple(tvm.runtime.convert([]))
type_params = tvm.runtime.convert([])
fn = relay.Function(params, body, ret_type, type_params)
fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value"))
fn = fn.with_attr("test_attribute", "value")
fn = fn.with_attr("test_attribute1", "value1")
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.attrs["test_attribute"] == "value"
assert fn.attrs["test_attribute1"] == "value1"
str(fn)
check_json_roundtrip(fn)
......
......@@ -356,7 +356,7 @@ def test_function_attr():
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a"))
func0 = func0.with_attr("FuncName", "a")
x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
......@@ -366,7 +366,7 @@ def test_function_attr():
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b"))
func1 = func1.with_attr("FuncName", "b")
assert not consistent_equal(func0, func1)
......@@ -698,7 +698,7 @@ def test_fn_attribute():
d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1)
add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test"))
add_1_fn = add_1_fn.with_attr("TestAttribute", "test")
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not consistent_equal(add_1_fn, add_fn)
......
......@@ -209,7 +209,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
g11 = relay.GlobalVar("g11")
fn11 = relay.Function([x11], x11)
fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
fn11 = fn11.with_attr("Compiler", "a")
mod[g11] = fn11
x1 = relay.var("x1", shape=(3, 5))
......@@ -244,7 +244,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
x11 = relay.var("x11", shape=(3, 5))
fn11 = relay.Function([x11], x11)
fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
fn11 = fn11.with_attr("Compiler", "a")
x2 = relay.var("x2", shape=(3, 5))
y2 = relay.var("y2", shape=(3, 5))
......@@ -367,7 +367,7 @@ def test_recursive_not_called_extern_compiler():
x1 = relay.var("x1", shape=(2, 2))
fn1 = relay.Function([x1], x1)
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
fn1 = fn1.with_attr("Compiler", "a")
g1 = relay.GlobalVar("g1")
mod[g1] = fn1
mod["main"] = relay.Function([x, y], x + y + g1(x))
......@@ -380,7 +380,7 @@ def test_recursive_not_called_extern_compiler():
x1 = relay.var("x1", shape=(2, 2))
fn1 = relay.Function([x1], x1)
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
fn1 = fn1.with_attr("Compiler", "a")
mod["main"] = relay.Function([x, y], x + y + fn1(x))
return mod
......@@ -446,7 +446,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb.ret(x1 + y1)
fn1 = relay.Function([x1, y1], sb.get())
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
fn1 = fn1.with_attr("Compiler", "a")
g1 = relay.GlobalVar("g1")
mod[g1] = fn1
......@@ -456,7 +456,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb1.ret(x2 - y2)
fn2 = relay.Function([x2, y2], sb1.get())
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
fn2 = fn2.with_attr("Compiler", "b")
g2 = relay.GlobalVar("g2")
mod[g2] = fn2
......@@ -478,7 +478,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb.ret(x1 + y1)
fn1 = relay.Function([x1, y1], sb.get())
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
fn1 = fn1.with_attr("Compiler", "a")
x2 = relay.var("x2", shape=(3, 5))
y2 = relay.var("y2", shape=(3, 5))
......@@ -486,7 +486,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb1.ret(x2 - y2)
fn2 = relay.Function([x2, y2], sb1.get())
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
fn2 = fn2.with_attr("Compiler", "b")
p0 = relay.var("p0", shape=(3, 5))
p1 = relay.var("p1", shape=(3, 5))
......@@ -539,10 +539,10 @@ def test_inline_globalvar_without_args_extern_compiler():
mod = tvm.IRModule({})
fn1 = relay.Function([], relay.const(1))
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
fn1 = fn1.with_attr("Compiler", "a")
fn2 = relay.Function([], relay.const(2))
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
fn2 = fn2.with_attr("Compiler", "b")
g1 = relay.GlobalVar('g1')
g2 = relay.GlobalVar('g2')
mod[g1] = fn1
......@@ -555,10 +555,10 @@ def test_inline_globalvar_without_args_extern_compiler():
mod = tvm.IRModule({})
fn1 = relay.Function([], relay.const(1))
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
fn1 = fn1.with_attr("Compiler", "a")
fn2 = relay.Function([], relay.const(2))
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
fn2 = fn2.with_attr("Compiler", "b")
p = relay.var('p', 'bool')
mod['main'] = relay.Function([p], relay.Call(
relay.If(p, fn1, fn2), []))
......@@ -787,7 +787,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
y0 = relay.var("y0", shape=(3, 5))
fn0 = relay.Function([x0, y0], x0 * y0)
fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
fn0 = fn0.with_attr("Compiler", "aa")
g0 = relay.GlobalVar("g0")
mod[g0] = fn0
......@@ -811,7 +811,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
y0 = relay.var("y0", shape=(3, 5))
fn0 = relay.Function([x0, y0], x0 * y0)
fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
fn0 = fn0.with_attr("Compiler", "aa")
x1 = relay.var("x1", shape=(3, 5))
y1 = relay.var("y1", shape=(3, 5))
......
......@@ -184,7 +184,7 @@ def test_simple_merge():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
add_relu = add_relu.with_attr("Composite", "add_relu")
# merged function
r = relay.Call(add_relu, [a, b])
......@@ -249,8 +249,7 @@ def test_branch_merge():
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
add_sub_mul = add_sub_mul.with_attr("Composite",
tir.StringImm("add_sub_mul"))
add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
# add_sub_mul1 function
in_3 = relay.var('in_3', shape=(10, 10))
......@@ -259,8 +258,7 @@ def test_branch_merge():
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
add_sub_mul_1 = add_sub_mul_1.with_attr("Composite",
tir.StringImm("add_sub_mul"))
add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul")
# merged function
m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
......@@ -319,8 +317,7 @@ def test_reuse_call_merge():
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.with_attr("Composite",
tir.StringImm("add_add_add"))
add_add_add = add_add_add.with_attr("Composite", "add_add_add")
# merged function
sub_node = relay.subtract(a, b)
......@@ -404,7 +401,7 @@ def test_multiple_patterns():
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite",
tir.StringImm("conv2d_bias_relu"))
"conv2d_bias_relu")
# add_relu function
in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
......@@ -412,7 +409,7 @@ def test_multiple_patterns():
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
add_relu = add_relu.with_attr("Composite", "add_relu")
# merged function
conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
......@@ -481,8 +478,7 @@ def test_merge_order():
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
merged_func = merged_func.with_attr('Composite',
tir.StringImm(composite_name))
merged_func = merged_func.with_attr('Composite', composite_name)
ret = relay.Call(merged_func, [input_1, input_2])
return relay.Function([input_1, input_2], ret)
......@@ -547,13 +543,13 @@ def test_parallel_merge():
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul"))
func_1 = func_1.with_attr('Composite', "add_sub_mul")
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul"))
func_2 = func_2.with_attr('Composite', "add_sub_mul")
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
return relay.Function([input_1, input_2], out)
......@@ -632,14 +628,14 @@ def test_multiple_input_subgraphs():
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu')
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu')
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
y2 = relay.var('y2')
......@@ -647,7 +643,7 @@ def test_multiple_input_subgraphs():
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul')
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
......@@ -660,7 +656,7 @@ def test_multiple_input_subgraphs():
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu'))
add_relu = add_relu.with_attr('Composite', 'add_relu')
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)
......@@ -720,7 +716,7 @@ def test_tuple_get_item_merge():
tuple_get_item_node = bn_node[0]
relu_node = relay.nn.relu(tuple_get_item_node)
bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
bn_relu = bn_relu.with_attr("Composite", tir.StringImm("bn_relu"))
bn_relu = bn_relu.with_attr("Composite", "bn_relu")
# merged function
r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])
......
......@@ -24,7 +24,6 @@ import tvm
import tvm.relay.testing
from tvm import relay
from tvm import runtime
from tvm.runtime import container
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end
......@@ -307,8 +306,8 @@ def test_extern_ccompiler_default_ops():
func = relay.Function([x0, y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("global_symbol", container.String("ccompiler_0"))
func = func.with_attr("Compiler", "ccompiler")
func = func.with_attr("global_symbol", "ccompiler_0")
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
......@@ -392,8 +391,8 @@ def test_extern_dnnl():
func = relay.Function([data0, input0, input1], out)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
func = func.with_attr("global_symbol", container.String("dnnl_0"))
func = func.with_attr("Compiler", "dnnl")
func = func.with_attr("global_symbol", "dnnl_0")
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func
......@@ -518,10 +517,8 @@ def test_function_lifting():
bn.astuple())
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("global_symbol",
container.String("test_compiler_0"))
func0 = func0.with_attr("Compiler", "test_compiler")
func0 = func0.with_attr("global_symbol", "test_compiler_0")
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0
......@@ -537,10 +534,8 @@ def test_function_lifting():
func1 = relay.Function([data1, weight1], conv)
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func1 = func1.with_attr("global_symbol",
container.String("test_compiler_1"))
func1 = func1.with_attr("Compiler", "test_compiler")
func1 = func1.with_attr("global_symbol", "test_compiler_1")
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1
......@@ -611,10 +606,8 @@ def test_function_lifting_inline():
bn.astuple())
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("global_symbol",
container.String("test_compiler_0"))
func0 = func0.with_attr("Compiler", "test_compiler")
func0 = func0.with_attr("global_symbol", "test_compiler_0")
# main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
......@@ -648,8 +641,8 @@ def test_constant_propagation():
func = relay.Function([y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("global_symbol", container.String("ccompiler_0"))
func = func.with_attr("Compiler", "ccompiler")
func = func.with_attr("global_symbol", "ccompiler_0")
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [y])
......@@ -748,10 +741,8 @@ def test_multiple_outputs():
bn_mean, bn_var], tuple_o)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("global_symbol",
container.String("test_target_2"))
func0 = func0.with_attr("Compiler", "test_target")
func0 = func0.with_attr("global_symbol", "test_target_2")
gv0 = relay.GlobalVar("test_target_2")
mod[gv0] = func0
......@@ -816,10 +807,8 @@ def test_mixed_single_multiple_outputs():
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func1 = func1.with_attr("global_symbol",
container.String("test_target_1"))
func1 = func1.with_attr("Compiler", "test_target")
func1 = func1.with_attr("global_symbol", "test_target_1")
gv1 = relay.GlobalVar("test_target_1")
mod[gv1] = func1
......@@ -831,10 +820,8 @@ def test_mixed_single_multiple_outputs():
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("global_symbol",
container.String("test_target_0"))
func0 = func0.with_attr("Compiler", "test_target")
func0 = func0.with_attr("global_symbol", "test_target_0")
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0
......
......@@ -41,7 +41,7 @@ def test_dict_attrs():
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
assert dattr.name.value == "xyz"
assert dattr.name == "xyz"
assert isinstance(dattr, tvm.ir.DictAttrs)
assert "name" in dattr
assert dattr["x"].value == 1
......
......@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
PrimExpr("tvm.contrib.cublas.matmul"),
runtime::String("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
......@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
PrimExpr("tvm.contrib.cublas.batch_matmul"),
runtime::String("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
......
......@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
PrimExpr("tvm.contrib.rocblas.matmul"),
runtime::String("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment