Commit 84b29b7b by Tianqi Chen

Fix warnings from Logging, enable Plan memory to take external memory (#93)

* Fix warnings from Logging, enable Plan memory to take external memory

* fix external memory id

* fix graph
parent 30912b1d
...@@ -747,8 +747,8 @@ inline void JSONWriter::BeginArray(bool multi_line) { ...@@ -747,8 +747,8 @@ inline void JSONWriter::BeginArray(bool multi_line) {
} }
inline void JSONWriter::EndArray() { inline void JSONWriter::EndArray() {
CHECK_NE(scope_multi_line_.size(), 0); CHECK_NE(scope_multi_line_.size(), 0U);
CHECK_NE(scope_counter_.size(), 0); CHECK_NE(scope_counter_.size(), 0U);
bool newline = scope_multi_line_.back(); bool newline = scope_multi_line_.back();
size_t nelem = scope_counter_.back(); size_t nelem = scope_counter_.back();
scope_multi_line_.pop_back(); scope_multi_line_.pop_back();
...@@ -764,8 +764,8 @@ inline void JSONWriter::BeginObject(bool multi_line) { ...@@ -764,8 +764,8 @@ inline void JSONWriter::BeginObject(bool multi_line) {
} }
inline void JSONWriter::EndObject() { inline void JSONWriter::EndObject() {
CHECK_NE(scope_multi_line_.size(), 0); CHECK_NE(scope_multi_line_.size(), 0U);
CHECK_NE(scope_counter_.size(), 0); CHECK_NE(scope_counter_.size(), 0U);
bool newline = scope_multi_line_.back(); bool newline = scope_multi_line_.back();
size_t nelem = scope_counter_.back(); size_t nelem = scope_counter_.back();
scope_multi_line_.pop_back(); scope_multi_line_.pop_back();
...@@ -842,7 +842,7 @@ inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) { ...@@ -842,7 +842,7 @@ inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) {
for (std::map<std::string, Entry>::iterator for (std::map<std::string, Entry>::iterator
it = map_.begin(); it != map_.end(); ++it) { it = map_.begin(); it != map_.end(); ++it) {
if (it->second.optional) continue; if (it->second.optional) continue;
CHECK_NE(visited.count(it->first), 0) CHECK_NE(visited.count(it->first), 0U)
<< "JSONReader: Missing field \"" << it->first << "\"\n At " << "JSONReader: Missing field \"" << it->first << "\"\n At "
<< reader->line_info(); << reader->line_info();
} }
...@@ -857,7 +857,7 @@ inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) ...@@ -857,7 +857,7 @@ inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr)
template<typename T> template<typename T>
inline void JSONObjectReadHelper:: inline void JSONObjectReadHelper::
DeclareFieldInternal(const std::string &key, T *addr, bool optional) { DeclareFieldInternal(const std::string &key, T *addr, bool optional) {
CHECK_EQ(map_.count(key), 0) CHECK_EQ(map_.count(key), 0U)
<< "Adding duplicate field " << key; << "Adding duplicate field " << key;
Entry e; Entry e;
e.func = ReaderFunction<T>; e.func = ReaderFunction<T>;
......
/*!
* Copyright (c) 2016 by Contributors
* \file optional.h
* \brief Container to hold optional data.
*/
#ifndef DMLC_OPTIONAL_H_
#define DMLC_OPTIONAL_H_
#include <iostream>
#include <utility>
#include <algorithm>
#include "./base.h"
#include "./logging.h"
#include "./type_traits.h"
namespace dmlc {
/*! \brief dummy type for assign null to optional */
struct nullopt_t {
#if defined(_MSC_VER) && _MSC_VER < 1900
/*! \brief dummy constructor */
explicit nullopt_t(int) {}
#else
/*! \brief dummy constructor */
constexpr nullopt_t(int) {}
#endif
};
/*! Assign null to optional: optional<T> x = nullopt; */
constexpr const nullopt_t nullopt = nullopt_t(0);
/*!
* \brief c++17 compatible optional class.
*
* At any time an optional<T> instance either
* hold no value (string representation "None")
* or hold a value of type T.
*/
template<typename T>
class optional {
public:
/*! \brief construct an optional object that contains no value */
optional() : is_none(true) {}
/*! \brief construct an optional object with value */
explicit optional(const T& value) {
is_none = false;
new (&val) T(value);
}
/*! \brief construct an optional object with another optional object */
optional(const optional<T>& other) {
is_none = other.is_none;
if (!is_none) {
new (&val) T(other.value());
}
}
/*! \brief deconstructor */
~optional() {
if (!is_none) {
reinterpret_cast<T*>(&val)->~T();
}
}
/*! \brief swap two optional */
void swap(optional<T>& other) {
std::swap(val, other.val);
std::swap(is_none, other.is_none);
}
/*! \brief set this object to hold value
* \param value the value to hold
* \return return self to support chain assignment
*/
optional<T>& operator=(const T& value) {
(optional<T>(value)).swap(*this);
return *this;
}
/*! \brief set this object to hold the same value with other
* \param other the other object
* \return return self to support chain assignment
*/
optional<T>& operator=(const optional<T> &other) {
(optional<T>(other)).swap(*this);
return *this;
}
/*! \brief clear the value this object is holding.
* optional<T> x = nullopt;
*/
optional<T>& operator=(nullopt_t) {
(optional<T>()).swap(*this);
return *this;
}
/*! \brief non-const dereference operator */
T& operator*() { // NOLINT(*)
return *reinterpret_cast<T*>(&val);
}
/*! \brief const dereference operator */
const T& operator*() const {
return *reinterpret_cast<const T*>(&val);
}
/*! \brief return the holded value.
* throws std::logic_error if holding no value
*/
const T& value() const {
if (is_none) {
throw std::logic_error("bad optional access");
}
return *reinterpret_cast<const T*>(&val);
}
/*! \brief whether this object is holding a value */
explicit operator bool() const { return !is_none; }
private:
// whether this is none
bool is_none;
// on stack storage of value
typename std::aligned_storage<sizeof(T), alignof(T)>::type val;
};
/*! \brief serialize an optional object to string.
*
* \code
* dmlc::optional<int> x;
* std::cout << x; // None
* x = 0;
* std::cout << x; // 0
* \endcode
*
* \param os output stream
* \param t source optional<T> object
* \return output stream
*/
template<typename T>
std::ostream &operator<<(std::ostream &os, const optional<T> &t) {
if (t) {
os << *t;
} else {
os << "None";
}
return os;
}
/*! \brief parse a string object into optional<T>
*
* \code
* dmlc::optional<int> x;
* std::string s1 = "1";
* std::istringstream is1(s1);
* s1 >> x; // x == optional<int>(1)
*
* std::string s2 = "None";
* std::istringstream is2(s2);
* s2 >> x; // x == optional<int>()
* \endcode
*
* \param is input stream
* \param t target optional<T> object
* \return input stream
*/
template<typename T>
std::istream &operator>>(std::istream &is, optional<T> &t) {
char buf[4];
std::streampos origin = is.tellg();
is.read(buf, 4);
if (is.fail() || buf[0] != 'N' || buf[1] != 'o' ||
buf[2] != 'n' || buf[3] != 'e') {
is.clear();
is.seekg(origin);
T x;
is >> x;
t = x;
} else {
t = nullopt;
}
return is;
}
/*! \brief description for optional int */
DMLC_DECLARE_TYPE_NAME(optional<int>, "int or None");
} // namespace dmlc
#endif // DMLC_OPTIONAL_H_
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "./json.h" #include "./json.h"
#include "./logging.h" #include "./logging.h"
#include "./type_traits.h" #include "./type_traits.h"
#include "./optional.h"
namespace dmlc { namespace dmlc {
// this file is backward compatible with non-c++11 // this file is backward compatible with non-c++11
...@@ -758,7 +759,7 @@ class FieldEntry<int> ...@@ -758,7 +759,7 @@ class FieldEntry<int>
// override print default // override print default
virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*) virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*)
if (is_enum_) { if (is_enum_) {
CHECK_NE(enum_back_map_.count(value), 0) CHECK_NE(enum_back_map_.count(value), 0U)
<< "Value not found in enum declared"; << "Value not found in enum declared";
os << enum_back_map_.at(value); os << enum_back_map_.at(value);
} else { } else {
...@@ -781,6 +782,115 @@ class FieldEntry<int> ...@@ -781,6 +782,115 @@ class FieldEntry<int>
} }
}; };
// specialize define for optional<int>(enum)
template<>
class FieldEntry<optional<int> >
: public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
public:
// construct
FieldEntry<optional<int> >() : is_enum_(false) {}
// parent
typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
// override set
virtual void Set(void *head, const std::string &value) const {
if (is_enum_ && value != "None") {
std::map<std::string, int>::const_iterator it = enum_map_.find(value);
std::ostringstream os;
if (it == enum_map_.end()) {
os << "Invalid Input: \'" << value;
os << "\', valid values are: ";
PrintEnums(os);
throw dmlc::ParamError(os.str());
} else {
os << it->second;
Parent::Set(head, os.str());
}
} else {
Parent::Set(head, value);
}
}
virtual ParamFieldInfo GetFieldInfo() const {
if (is_enum_) {
ParamFieldInfo info;
std::ostringstream os;
info.name = key_;
info.type = type_;
PrintEnums(os);
if (has_default_) {
os << ',' << "optional, default=";
PrintDefaultValueString(os);
} else {
os << ", required";
}
info.type_info_str = os.str();
info.description = description_;
return info;
} else {
return Parent::GetFieldInfo();
}
}
// add enum
inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) {
CHECK_NE(key, "None") << "None is reserved for empty optional<int>";
if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
enum_back_map_.count(value) != 0) {
std::ostringstream os;
os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
os << "Enums: ";
for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
it != enum_map_.end(); ++it) {
os << "(" << it->first << ": " << it->second << "), ";
}
throw dmlc::ParamError(os.str());
}
enum_map_[key] = value;
enum_back_map_[value] = key;
is_enum_ = true;
return this->self();
}
protected:
// enum flag
bool is_enum_;
// enum map
std::map<std::string, int> enum_map_;
// enum map
std::map<int, std::string> enum_back_map_;
// override print behavior
virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
os << '\'';
PrintValue(os, default_value_);
os << '\'';
}
// override print default
virtual void PrintValue(std::ostream &os, optional<int> value) const { // NOLINT(*)
if (is_enum_) {
if (!value) {
os << "None";
} else {
CHECK_NE(enum_back_map_.count(value.value()), 0U)
<< "Value not found in enum declared";
os << enum_back_map_.at(value.value());
}
} else {
os << value;
}
}
private:
inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
os << "{None";
for (std::map<std::string, int>::const_iterator
it = enum_map_.begin(); it != enum_map_.end(); ++it) {
os << ", ";
os << "\'" << it->first << '\'';
}
os << '}';
}
};
// specialize define for string // specialize define for string
template<> template<>
class FieldEntry<std::string> class FieldEntry<std::string>
......
...@@ -75,7 +75,7 @@ class Registry { ...@@ -75,7 +75,7 @@ class Registry {
* \return ref to the registered entry, used to set properties * \return ref to the registered entry, used to set properties
*/ */
inline EntryType &__REGISTER__(const std::string& name) { inline EntryType &__REGISTER__(const std::string& name) {
CHECK_EQ(fmap_.count(name), 0) CHECK_EQ(fmap_.count(name), 0U)
<< name << " already registered"; << name << " already registered";
EntryType *e = new EntryType(); EntryType *e = new EntryType();
e->name = name; e->name = name;
......
...@@ -263,11 +263,11 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -263,11 +263,11 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check. // parameter check.
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1) CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required"; << "Argument " << i << " is a tuple, single value is required";
} }
for (const auto& kv : kwargs) { for (const auto& kv : kwargs) {
CHECK_EQ(kv.second->outputs.size(), 1) CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required"; << "Keyword Argument " << kv.first << " is a tuple, single value is required";
} }
// assign new name // assign new name
...@@ -316,7 +316,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -316,7 +316,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
} }
} }
} else { } else {
CHECK_EQ(kwargs.size(), 0) << "Variable length function do not accept kwargs"; CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(args.size()); n->inputs.reserve(args.size());
for (const Symbol* s : args) { for (const Symbol* s : args) {
n->inputs.push_back(s->outputs[0]); n->inputs.push_back(s->outputs[0]);
...@@ -325,7 +325,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -325,7 +325,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
UpdateNodeVersion(n); UpdateNodeVersion(n);
} else { } else {
// general composition // general composition
CHECK_EQ(args.size(), 0) CHECK_EQ(args.size(), 0U)
<< "General composition only support kwargs for now"; << "General composition only support kwargs for now";
size_t nmatched = 0; size_t nmatched = 0;
size_t arg_counter = 0; size_t arg_counter = 0;
...@@ -395,7 +395,7 @@ Symbol Symbol::operator () (const array_view<const Symbol*>& args, ...@@ -395,7 +395,7 @@ Symbol Symbol::operator () (const array_view<const Symbol*>& args,
} }
void Symbol::AddControlDeps(const Symbol& src) { void Symbol::AddControlDeps(const Symbol& src) {
CHECK_EQ(outputs.size(), 1) CHECK_EQ(outputs.size(), 1U)
<< "AddControlDeps only works for nongrouped symbol"; << "AddControlDeps only works for nongrouped symbol";
Node* n = outputs[0].node.get(); Node* n = outputs[0].node.get();
for (const NodeEntry& sp : src.outputs) { for (const NodeEntry& sp : src.outputs) {
......
...@@ -46,11 +46,11 @@ Graph Gradient(Graph src) { ...@@ -46,11 +46,11 @@ Graph Gradient(Graph src) {
using MirrorFun = std::function<int (const Node& node)>; using MirrorFun = std::function<int (const Node& node)>;
using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>; using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;
CHECK_NE(src.attrs.count("grad_ys"), 0) CHECK_NE(src.attrs.count("grad_ys"), 0U)
<< "Gradient require grad_ys to be presented."; << "Gradient require grad_ys to be presented.";
CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0) CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
<< "Gradient require grad_ys_out_grad to be presented."; << "Gradient require grad_ys_out_grad to be presented.";
CHECK_NE(src.attrs.count("grad_xs"), 0) CHECK_NE(src.attrs.count("grad_xs"), 0U)
<< "Gradient require grad_xs to be presented."; << "Gradient require grad_xs to be presented.";
const std::vector<NodeEntry>& ys = const std::vector<NodeEntry>& ys =
src.GetAttr<std::vector<NodeEntry> >("grad_ys"); src.GetAttr<std::vector<NodeEntry> >("grad_ys");
......
...@@ -75,7 +75,7 @@ inline uint32_t ColorNodeGroup( ...@@ -75,7 +75,7 @@ inline uint32_t ColorNodeGroup(
std::vector<uint32_t> node_importance, std::vector<uint32_t> node_importance,
uint32_t max_ncolor, uint32_t max_ncolor,
std::vector<uint32_t> *color) { std::vector<uint32_t> *color) {
CHECK_NE(max_ncolor, 0); CHECK_NE(max_ncolor, 0U);
CHECK_EQ(graph.num_nodes(), node_importance.size()); CHECK_EQ(graph.num_nodes(), node_importance.size());
color->clear(); color->clear();
......
...@@ -66,7 +66,7 @@ Graph InferAttr(Graph &&ret, ...@@ -66,7 +66,7 @@ Graph InferAttr(Graph &&ret,
if (inode.source->is_variable()) { if (inode.source->is_variable()) {
// Variable node. No operator. Only one output entry. // Variable node. No operator. Only one output entry.
CHECK(inode.source->op() == nullptr); CHECK(inode.source->op() == nullptr);
CHECK_EQ(num_outputs, 1); CHECK_EQ(num_outputs, 1U);
const uint32_t out_ent_id = idx.entry_id(nid, 0); const uint32_t out_ent_id = idx.entry_id(nid, 0);
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
auto it = inode.source->attrs.dict.find(shape_attr_key); auto it = inode.source->attrs.dict.find(shape_attr_key);
...@@ -76,7 +76,7 @@ Graph InferAttr(Graph &&ret, ...@@ -76,7 +76,7 @@ Graph InferAttr(Graph &&ret,
} }
} }
} else if (is_backward.get(inode.source->op(), false)) { } else if (is_backward.get(inode.source->op(), false)) {
CHECK_GE(inode.control_deps.size(), 1) CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op"; << "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
NodePtr fwd_ptr = inode.source->control_deps[0]; NodePtr fwd_ptr = inode.source->control_deps[0];
......
...@@ -15,11 +15,11 @@ namespace { ...@@ -15,11 +15,11 @@ namespace {
// simply logic to place device according to device_group hint // simply logic to place device according to device_group hint
// insert copy node when there is // insert copy node when there is
Graph PlaceDevice(Graph src) { Graph PlaceDevice(Graph src) {
CHECK_NE(src.attrs.count("device_group_attr_key"), 0) CHECK(src.attrs.count("device_group_attr_key"))
<< "Need graph attribute \"device_group_attr_key\" in PlaceDevice"; << "Need graph attribute \"device_group_attr_key\" in PlaceDevice";
CHECK_NE(src.attrs.count("device_assign_map"), 0) CHECK(src.attrs.count("device_assign_map"))
<< "Need graph attribute \"device_assign_map\" in PlaceDevice"; << "Need graph attribute \"device_assign_map\" in PlaceDevice";
CHECK_NE(src.attrs.count("device_copy_op"), 0) CHECK(src.attrs.count("device_copy_op"))
<< "Need graph attribute \"device_copy_op\" in PlaceDevice"; << "Need graph attribute \"device_copy_op\" in PlaceDevice";
std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key"); std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key");
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op")); const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
......
...@@ -21,6 +21,8 @@ class GraphAllocator { ...@@ -21,6 +21,8 @@ class GraphAllocator {
using StorageID = int; using StorageID = int;
// bad storage id // bad storage id
static const StorageID kBadStorageID = -1; static const StorageID kBadStorageID = -1;
// external storage id
static const StorageID kExternalStorageID = -2;
// request a free storage // request a free storage
StorageID Request(int dev_id, int dtype, TShape shape, uint32_t node_id) { StorageID Request(int dev_id, int dtype, TShape shape, uint32_t node_id) {
if (shape.ndim() == 0) return kBadStorageID; if (shape.ndim() == 0) return kBadStorageID;
...@@ -62,6 +64,7 @@ class GraphAllocator { ...@@ -62,6 +64,7 @@ class GraphAllocator {
// release a memory space. // release a memory space.
void Release(StorageID id, uint32_t node_id) { void Release(StorageID id, uint32_t node_id) {
CHECK_NE(id, kBadStorageID); CHECK_NE(id, kBadStorageID);
if (id == kExternalStorageID) return;
StorageEntry *e = data_[id].get(); StorageEntry *e = data_[id].get();
e->released_by_node = node_id; e->released_by_node = node_id;
free_.insert({e->max_bytes, e}); free_.insert({e->max_bytes, e});
...@@ -161,7 +164,14 @@ Graph PlanMemory(Graph ret) { ...@@ -161,7 +164,14 @@ Graph PlanMemory(Graph ret) {
++ref_count[idx.entry_id(e)]; ++ref_count[idx.entry_id(e)];
} }
// step 2: allocate memory. // step 2: allocate memory.
StorageVector storage(idx.num_node_entries(), -1); StorageVector storage;
if (ret.attrs.count("storage") != 0) {
storage = ret.MoveCopyAttr<StorageVector>("storage");
} else {
storage.resize(idx.num_node_entries(), -1);
}
std::vector<int> storage_inplace_index(idx.num_node_entries(), -1); std::vector<int> storage_inplace_index(idx.num_node_entries(), -1);
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape"); const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype"); const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype");
......
...@@ -155,7 +155,7 @@ struct JSONGraph { ...@@ -155,7 +155,7 @@ struct JSONGraph {
// Load a graph from JSON file. // Load a graph from JSON file.
Graph LoadJSON(Graph src) { Graph LoadJSON(Graph src) {
CHECK_NE(src.attrs.count("json"), 0) CHECK_NE(src.attrs.count("json"), 0U)
<< "Load JSON require json to be presented."; << "Load JSON require json to be presented.";
const std::string &json_str = const std::string &json_str =
nnvm::get<std::string>(*src.attrs.at("json")); nnvm::get<std::string>(*src.attrs.at("json"));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment