Unverified Commit 8c31d0dd by Zhi Committed by GitHub

Remove PrimExpr from String (#5311)

parent 92d0ec14
...@@ -108,12 +108,6 @@ class PrimExpr : public BaseExpr { ...@@ -108,12 +108,6 @@ class PrimExpr : public BaseExpr {
*/ */
TVM_DLL PrimExpr(float value); // NOLINT(*) TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from runtime String.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)
/*! \return the data type of this expression. */ /*! \return the data type of this expression. */
DataType dtype() const { DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype; return static_cast<const PrimExprNode*>(get())->dtype;
......
...@@ -40,9 +40,6 @@ PrimExpr::PrimExpr(int32_t value) ...@@ -40,9 +40,6 @@ PrimExpr::PrimExpr(int32_t value)
PrimExpr::PrimExpr(float value) PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {} : PrimExpr(FloatImm(DataType::Float(32), value)) {}
PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {}
PrimExpr PrimExpr::FromObject_(ObjectRef ref) { PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker; using runtime::ObjectTypeChecker;
if (auto* ptr = ref.as<tir::IterVarNode>()) { if (auto* ptr = ref.as<tir::IterVarNode>()) {
......
...@@ -137,7 +137,7 @@ Target CreateTarget(const std::string& target_name, ...@@ -137,7 +137,7 @@ Target CreateTarget(const std::string& target_name,
} else if (target_name == "hybrid") { } else if (target_name == "hybrid") {
t->device_type = kDLCPU; t->device_type = kDLCPU;
} else if (target_name == "hexagon") { } else if (target_name == "hexagon") {
t->keys_array.push_back(runtime::String("hexagon")); t->keys_array.push_back("hexagon");
t->device_type = kDLHexagon; t->device_type = kDLHexagon;
} else { } else {
LOG(ERROR) << "Unknown target name " << target_name; LOG(ERROR) << "Unknown target name " << target_name;
......
...@@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node, ...@@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node,
TVM_REGISTER_GLOBAL("tir.AttrStmt") TVM_REGISTER_GLOBAL("tir.AttrStmt")
.set_body_typed(AttrStmtNode::make); .set_body_typed(AttrStmtNode::make);
Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined()); CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) || CHECK(message.dtype() == DataType::Int(32) ||
...@@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { ...@@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
} }
TVM_REGISTER_GLOBAL("tir.AssertStmt") TVM_REGISTER_GLOBAL("tir.AssertStmt")
.set_body_typed(AssertStmtNode::make); .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
if (const auto* str = message.as<StringObj>()) {
auto msg = StringImmNode::make(str->data);
return AssertStmtNode::make(condition, msg, body);
} else {
return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
}
});
Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) { Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined()); CHECK(body.defined());
......
...@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, ...@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs }, { { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) { [&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({ return call_packed({
runtime::String("tvm.contrib.cublas.matmul"), StringImmNode::make("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]), pack_buffer(ins[0]),
pack_buffer(ins[1]), pack_buffer(ins[1]),
pack_buffer(outs[0]), pack_buffer(outs[0]),
...@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, ...@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs }, { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) { [&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({ return call_packed({
runtime::String("tvm.contrib.cublas.batch_matmul"), StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]), pack_buffer(ins[0]),
pack_buffer(ins[1]), pack_buffer(ins[1]),
pack_buffer(outs[0]), pack_buffer(outs[0]),
......
...@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, ...@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs }, { { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) { [&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({ return call_packed({
runtime::String("tvm.contrib.rocblas.matmul"), StringImmNode::make("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]), pack_buffer(ins[0]),
pack_buffer(ins[1]), pack_buffer(ins[1]),
pack_buffer(outs[0]), pack_buffer(outs[0]),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment