Commit b4a6c0e7 by Tianqi Chen Committed by GitHub

[BUILD] Fix reflection build for gcc-8 (#1304)

parent 80e4bc02
Subproject commit 9b3f9753ae81d657743c555e0cacc4e43f0bed2d Subproject commit e864aa6757cdbe78b1296fe5231fd3050b7802c3
...@@ -375,41 +375,43 @@ class NodeAttrSetter : public AttrVisitor { ...@@ -375,41 +375,43 @@ class NodeAttrSetter : public AttrVisitor {
std::string type_key; std::string type_key;
std::unordered_map<std::string, runtime::TVMArgValue> attrs; std::unordered_map<std::string, runtime::TVMArgValue> attrs;
template<typename T>
void SetValue(const char* key, T* value) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
}
*value = it->second.operator T();
attrs.erase(it);
}
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
SetValue(key, value); *value = GetAttr(key).operator double();
} }
void Visit(const char* key, int64_t* value) final { void Visit(const char* key, int64_t* value) final {
SetValue(key, value); *value = GetAttr(key).operator int64_t();
} }
void Visit(const char* key, uint64_t* value) final { void Visit(const char* key, uint64_t* value) final {
SetValue(key, value); *value = GetAttr(key).operator uint64_t();
} }
void Visit(const char* key, int* value) final { void Visit(const char* key, int* value) final {
SetValue(key, value); *value = GetAttr(key).operator int();
} }
void Visit(const char* key, bool* value) final { void Visit(const char* key, bool* value) final {
SetValue(key, value); *value = GetAttr(key).operator bool();
} }
void Visit(const char* key, std::string* value) final { void Visit(const char* key, std::string* value) final {
SetValue(key, value); *value = GetAttr(key).operator std::string();
} }
void Visit(const char* key, void** value) final { void Visit(const char* key, void** value) final {
SetValue(key, value); *value = GetAttr(key).operator void*();
} }
void Visit(const char* key, Type* value) final { void Visit(const char* key, Type* value) final {
SetValue(key, value); *value = GetAttr(key).operator Type();
} }
void Visit(const char* key, NodeRef* value) final { void Visit(const char* key, NodeRef* value) final {
SetValue(key, value); *value = GetAttr(key).operator NodeRef();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
auto it = attrs.find(key);
if (it == attrs.end()) {
LOG(FATAL) << type_key << ": require field " << key;
}
runtime::TVMArgValue v = it->second;
attrs.erase(it);
return v;
} }
}; };
......
...@@ -29,14 +29,14 @@ class LoopUnroller : public IRMutator { ...@@ -29,14 +29,14 @@ class LoopUnroller : public IRMutator {
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final { Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
if (op->attr_key == "pragma_auto_unroll_max_step") { if (op->attr_key == "pragma_auto_unroll_max_step") {
int value; int value = 0;
CHECK(arith::GetConstInt(op->value, &value)); CHECK(arith::GetConstInt(op->value, &value));
std::swap(value, auto_max_step_); std::swap(value, auto_max_step_);
Stmt ret = this->Mutate(op->body); Stmt ret = this->Mutate(op->body);
std::swap(value, auto_max_step_); std::swap(value, auto_max_step_);
return ret; return ret;
} else if (op->attr_key == "pragma_unroll_explicit") { } else if (op->attr_key == "pragma_unroll_explicit") {
int value; int value = 0;
CHECK(arith::GetConstInt(op->value, &value)); CHECK(arith::GetConstInt(op->value, &value));
bool explicit_unroll = value; bool explicit_unroll = value;
std::swap(explicit_unroll, explicit_unroll_); std::swap(explicit_unroll, explicit_unroll_);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment