# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Amalgamate project sources into a single C++ file.

Positional arguments (read from ``sys.argv`` by this script):

1. a dependency listing naming the candidate source files (see
   ``get_sources``);
2. the entry source file whose ``#include`` directives are expanded;
3. the path of the amalgamated output file.
"""

from __future__ import print_function

import os.path
import re
import sys

try:  # Python 2
    from StringIO import StringIO
except ImportError:  # Python 3
    from io import StringIO

# Unresolved headers that must NOT be re-emitted as ``#include <...>`` lines
# in the amalgamated output.
blacklist = [
    'Windows.h', 'mach/clock.h', 'mach/mach.h',
    'malloc.h',
    'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h',
    'io/s3_filesys.h',
    'sys/stat.h', 'sys/types.h',
    'omp.h', 'execinfo.h', 'packet/sse-inl.h',
]


def get_sources(def_file):
    """Return the project source files listed in ``def_file``.

    The file is read as whitespace-separated tokens (presumably a make-style
    ``.d`` dependency file of the form ``target.o: dep1 dep2 \\``). Tokens
    ending in ``.o:`` and lone ``\\`` continuation markers are skipped.
    Only files located under the project root (the parent of the directory
    containing this script) are kept. They are converted to paths relative
    to the current working directory, de-duplicated, and returned in
    first-seen order.
    """
    mxnet_path = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
    # Compare against "<root>/" so a sibling such as "<root>-old/x.h" is not
    # mistaken for a file under "<root>".
    root_prefix = os.path.join(mxnet_path, '')

    with open(def_file) as deps:
        tokens = [tok for line in deps for tok in line.split()]

    sources = []
    visited = set()
    for tok in tokens:
        if tok.endswith('.o:') or tok == '\\':
            continue
        fn = os.path.relpath(tok)
        if os.path.abspath(tok).startswith(root_prefix) and fn not in visited:
            sources.append(fn)
            visited.add(fn)
    return sources


sources = get_sources(sys.argv[1])


def _second_component(path):
    """Return the second '/'-separated component of ``path``, or None.

    For cwd-relative paths such as '../src/x.cc' this is the first directory
    below the parent (presumably the intended disambiguation key).
    """
    parts = path.split('/')
    return parts[1] if len(parts) > 1 else None


def find_source(name, start):
    """Resolve an ``#include`` target ``name`` to an entry of ``sources``.

    Candidates are sources equal to ``name`` or ending in ``'/' + name``.
    A unique candidate is returned directly. When several match, the first
    candidate whose second path component equals that of ``start`` (the
    including file) is returned. Returns '' when nothing matches or the
    ambiguity cannot be resolved, including when a path has no second
    component (previously an IndexError).
    """
    candidates = [src for src in sources
                  if src == name or src.endswith('/' + name)]
    if not candidates:
        return ''
    if len(candidates) == 1:
        return candidates[0]
    start_key = _second_component(start)
    if start_key is None:
        return ''
    for cand in candidates:
        if _second_component(cand) == start_key:
            return cand
    return ''


# Capture the target of ``#include <...>`` and ``#include "..."`` directives.
re1 = re.compile(r'<([./a-zA-Z0-9_-]*)>')
re2 = re.compile(r'"([./a-zA-Z0-9_-]*)"')

# State shared with the include-expansion pass.
sysheaders = []   # unresolved headers to emit as #include <...>, in order
history = set()   # files that have been fully expanded
out = StringIO()  # accumulated amalgamated source text
# Files exempt from the "already expanded" check: they are inlined again every
# time they are included (original note: "MULTIPLE includes").
_MULTI_INCLUDE_OK = ('mshadow/mshadow/expr_scalar-inl.h',)


def expand(x, pending):
    """Recursively inline source file ``x`` into ``out``.

    ``pending`` is the chain of files currently being expanded. A file
    already in that chain is skipped, which breaks include cycles. A file
    already in ``history`` is skipped unless listed in ``_MULTI_INCLUDE_OK``.

    Lines without ``#include`` are copied verbatim. Lines that contain
    ``#include`` somewhere other than at their start (for example a
    commented-out include) are echoed to stdout and dropped. For each real
    ``#include`` line, ``find_source`` is consulted:

    * a project file is expanded recursively in place;
    * otherwise the header is recorded in ``sysheaders``, unless it is
      blacklisted, already recorded, or contains 'mkl' or 'nnpack'.

    The include line itself is never copied to ``out``. ``x`` is added to
    ``history`` once its expansion finishes.
    """
    if x in history and x not in _MULTI_INCLUDE_OK:
        return
    if x in pending:
        return

    print("//===== EXPANDING: %s =====\n" % x, file=out)
    # Read everything first so the handle is closed before recursing;
    # the output order is unchanged.
    with open(x) as src:
        lines = src.readlines()
    for line in lines:
        if line.find('#include') < 0:
            out.write(line)
            continue
        if line.strip().find('#include') > 0:
            print(line.rstrip())
            continue
        m = re1.search(line) or re2.search(line)
        if not m:
            print(line.rstrip() + ' not found')
            continue
        h = m.group(1).strip('./')
        source = find_source(h, x)
        if source:
            expand(source, pending + [x])
        elif (h not in blacklist and h not in sysheaders
              and 'mkl' not in h and 'nnpack' not in h):
            sysheaders.append(h)
    print("//===== EXPANDED: %s =====\n" % x, file=out)
    history.add(x)


expand(sys.argv[2], [])

# Text mode so print() works on Python 2 and 3; ``with`` guarantees the
# output is flushed and closed.
with open(sys.argv[3], 'w') as f:
    for k in sorted(sysheaders):
        print("#include <%s>" % k, file=f)
    print('', file=f)
    print(out.getvalue(), file=f)

# Report listed sources that were never reached from the entry file.
for x in sources:
    if x not in history and not x.endswith('.o'):
        print('Not processed:', x)