Commit d3efd7fc by Yizhi Liu Committed by Tianqi Chen

[WIP][Frontend] Scala/Java package (#176)

* JVM package skeleton

* [JVM] link libtvm.so and list function names

* [JVM] Function & NDArray skeleton

* [JVM] TVMFuncCall in JNI

* [JVM] handle string arg in TVMFuncCall

* [JVM] get module function

* [JVM] entry function for Module

* [JVM] construct Module from function return value

* [JVM] TVMContext, TVMArray attributes

* [JVM] NDArray from / to java array

* [JVM] load so and compute on cpu

* [JVM] move PackedFunc to individual modules

* [JVM] assembly package & native library loader

* [JVM] unit test & codestyle check settings

* [JVM] NDArray from & to different dtypes

* [JVM] NDArray from native double array. Add linux-cpu profile.

* [JVM] modify Makefile

* [JVM] add linux-x86_64-gpu profile

* [tvm4j] delay load libtvm_runtime.so

* [tvm4j] refactor to pure java

* [tvm4j] remove scalastyle-config.xml

* [tvm4j] remove link HalideIR, remove Shape, remove scala binary versions

* [tvm4j] only allow convert from/to same type array

* [tvm4j] make NDArray api more readable

* [tvm4j] refactor for c api

* [tvm4j] add Jenkins tests

* [tvm4j] fix duplicate Dockerfile cmd

* [tvm4j] fix ut script filename

* [tvm4j] add module load tests

* [tvm4j] add javadoc, remove types package

* [tvm4j] fix test script

* [tvm4j] remove ut temp dir

* [tvm4j] fix missing package types

* [tvm4j] java code style check

* [tvm4j] fix java lint

* [tvm4j] downgrade checkstyle plugin for JDK7

* [tvm4j] add stylecheck in jenkins tests

* [tvm4j] specify source file encoding

* [tvm4j] lazy init function; add Function.call() api; allow manully release Module,NDArray,Function

* [tvm4j] fix ModFree

* [tvm4j] cache Function in API
parent 86ff24ab
......@@ -104,6 +104,18 @@ nnvm
## IOS
DerivedData/
## Java
*.class
jvm/*/target/
jvm/*/*/target/
*.worksheet
*.idea
*.iml
*.classpath
*.project
*.settings
*/node_modules/
## Various settings
*.pbxuser
!default.pbxuser
......@@ -119,4 +131,5 @@ xcuserdata/
*.moved-aside
*.xccheckout
*.xcscmblueprint
.DS_Store
\ No newline at end of file
.DS_Store
#!groovy
// -*- mode: groovy -*-
// Jenkins pipeline
// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/
......@@ -183,6 +184,17 @@ stage('Unit Test') {
}
}
}
},
'java': {
node('GPU' && 'linux') {
ws('workspace/tvm/ut-java') {
init_git()
unpack_lib('gpu', tvm_lib)
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} gpu ./tests/scripts/task_java_unittest.sh"
}
}
}
}
}
......
......@@ -120,6 +120,29 @@ ifdef ADD_LDFLAGS
LDFLAGS += $(ADD_LDFLAGS)
endif
ifeq ($(OS),Windows_NT)
JVM_PKG_PROFILE := windows
else
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
JVM_PKG_PROFILE := osx-x86_64
else
JVM_PKG_PROFILE := linux-x86_64
endif
endif
JVM_TEST_ARGS := $(if $(JVM_TEST_ARGS),$(JVM_TEST_ARGS),-DskipTests -Dcheckstyle.skip=true)
ifeq ($(USE_CUDA), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else ifeq ($(USE_OPENCL), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else ifeq ($(USE_METAL), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-cpu
endif
include tests/cpp/unittest.mk
test: $(TEST)
......@@ -176,7 +199,10 @@ pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
lint: cpplint pylint
jnilint:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint: cpplint pylint jnilint
doc:
doxygen docs/Doxyfile
......@@ -194,6 +220,12 @@ cython3:
cyclean:
rm -rf python/tvm/*/*/*.so python/tvm/*/*/*.cpp
jvmpkg:
(cd $(ROOTDIR)/jvm; \
mvn clean package -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" $(JVM_TEST_ARGS))
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
cd HalideIR; make clean; cd $(ROOTDIR)
......
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-linux-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Linux-x86_64 CPU-only</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-linux-x86_64-cpu</artifactId>
<version>${project.version}</version>
<type>so</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.so</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-linux-x86_64-cpu:so</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-linux-x86_64-gpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Linux-x86_64 GPU</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-linux-x86_64-gpu</artifactId>
<version>${project.version}</version>
<type>so</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.so</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-linux-x86_64-gpu:so</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-osx-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full OSX-x86_64 CPU-only</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-osx-x86_64-cpu</artifactId>
<version>${project.version}</version>
<type>jnilib</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.jnilib</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-osx-x86_64-cpu:jnilib</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Parent</name>
<packaging>pom</packaging>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<modules>
<module>osx-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<modules>
<module>linux-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<modules>
<module>linux-x86_64-gpu</module>
</modules>
</profile>
<profile>
<id>release</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>jar-no-fork</goal>
</goals>
<configuration>
<includePom>true</includePom>>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<includeDependencySources>true</includeDependencySources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Core</name>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<properties>
<platform>osx-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<properties>
<platform>linux-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<properties>
<platform>linux-x86_64-gpu</platform>
</properties>
</profile>
</profiles>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>2.17</version>
<dependencies>
<dependency>
<groupId>com.puppycrawl.tools</groupId>
<artifactId>checkstyle</artifactId>
<version>6.12</version>
</dependency>
</dependencies>
<executions>
<execution>
<phase>process-sources</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
<configuration>
<failsOnError>true</failsOnError>
<configLocation>${project.parent.basedir}/conf/google_checks.xml</configLocation>
<consoleOutput>true</consoleOutput>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.7</version>
<configuration>
<forkCount>1</forkCount>
<reuseForks>true</reuseForks>
<threadCount>1</threadCount>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
-Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so
</argLine>
</configuration>
<executions>
<execution>
<id>test</id>
<!--
We put it in the integration-test phase,
because the test suites require the jni library,
which means, in order to run the unit tests,
we have to compile all the modules first.
-->
<phase>integration-test</phase>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
/**
* TVM API functions.
*/
public final class API {
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
/**
* Get a tvm api function according by name.
* @param name function name.
* @return a TVM Function.
*/
public static Function get(final String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction(name);
apiFuncs.get().put(name, func);
}
return func;
}
/**
* Cannot be instantiated.
*/
private API() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
/**
* Internal api functions.
*/
public final class APIInternal {
/**
* Get a tvm api function according by name.
* @param name function name.
* @return a TVM Function.
*/
public static Function get(final String name) {
return API.get(name);
}
/**
* Cannot be instantiated.
*/
private APIInternal() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import ml.dmlc.tvm.NativeLibraryLoader.Action;
import java.io.File;
import java.io.IOException;
/**
* Initializing methods and types.
*/
final class Base {
/**
* Hold Long reference for JNI.
*/
public static class RefLong {
public final long value;
public RefLong(final long value) {
this.value = value;
}
public RefLong() {
this(0L);
}
}
/**
* Hold TVMValue reference for JNI.
*/
public static class RefTVMValue {
public final TVMValue value;
public RefTVMValue(TVMValue value) {
this.value = value;
}
public RefTVMValue() {
this(null);
}
}
public static final LibInfo _LIB = new LibInfo();
static {
try {
try {
tryLoadLibraryOS("tvm4j");
} catch (UnsatisfiedLinkError e) {
System.err.println("[WARN] TVM native library not found in path. "
+ "Copying native library from the archive. "
+ "Consider installing the library somewhere in the path "
+ "(for Windows: PATH, for Linux: LD_LIBRARY_PATH), "
+ "or specifying by Java cmd option -Djava.library.path=[lib path].");
NativeLibraryLoader.loadLibrary("tvm4j");
}
} catch (Throwable e) {
System.err.println("[ERROR] Couldn't find native library tvm4j");
throw new RuntimeException(e);
}
String tvmLibFilename = System.getProperty("libtvm.so.path");
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) {
try {
NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() {
@Override public void invoke(File target) {
System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override public void run() {
_LIB.shutdown();
}
});
}
/**
* Load JNI for different OS.
* @param libname library name.
* @throws UnsatisfiedLinkError if loading fails.
*/
private static void tryLoadLibraryOS(String libname) throws UnsatisfiedLinkError {
try {
System.err.println(String.format("Try loading %s from native path.", libname));
System.loadLibrary(libname);
} catch (UnsatisfiedLinkError e) {
String os = System.getProperty("os.name");
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
tryLoadLibraryXPU(libname, "linux-x86_64");
} else if (os.startsWith("Mac")) {
tryLoadLibraryXPU(libname, "osx-x86_64");
} else {
// TODO(yizhi) support windows later
throw new UnsatisfiedLinkError("Windows not supported currently");
}
}
}
/**
* Load native library for different architectures.
* @param libname library name.
* @param arch architecture.
* @throws UnsatisfiedLinkError if loading fails
*/
private static void tryLoadLibraryXPU(String libname, String arch) throws UnsatisfiedLinkError {
try {
// try gpu first
System.err.println(String.format("Try loading %s-%s-gpu from native path.", libname, arch));
System.loadLibrary(String.format("%s-%s-gpu", libname, arch));
} catch (UnsatisfiedLinkError e) {
System.err.println(String.format("Try loading %s-%s-cpu from native path.", libname, arch));
System.loadLibrary(String.format("%s-%s-cpu", libname, arch));
}
}
// helper function definitions
/**
* Check the return value of C API call
* <p>
* This function will raise exception when error occurs.
* Wrap every API call with this function
* </p>
* @param ret return value from API calls
*/
public static void checkCall(int ret) throws TVMError {
if (ret != 0) {
throw new TVMError(_LIB.tvmGetLastError());
}
}
/**
* TVM Runtime error.
*/
static class TVMError extends RuntimeException {
public TVMError(String err) {
super(err);
}
}
/**
* Cannot be instantiated.
*/
private Base() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Function {
final long handle;
public final boolean isResident;
private boolean isReleased = false;
/**
* Get registered function.
* @param name full function name.
* @return TVM function.
*/
static Function getFunction(final String name) {
for (String fullName : listGlobalFuncNames()) {
if (fullName.equals(name)) {
return getGlobalFunc(fullName, true, false);
}
}
return null;
}
/**
* Get list of global functions registered.
* @return List of global functions names.
*/
private static List<String> listGlobalFuncNames() {
List<String> names = new ArrayList<String>();
Base.checkCall(Base._LIB.tvmFuncListGlobalNames(names));
return Collections.unmodifiableList(names);
}
/**
* Get a global function by name.
* @param name The name of the function.
* @param isResident Whether it is a global 'resident' function.
* @param allowMissing Whether allow missing function or raise an error.
* @return The function to be returned, None if function is missing.
*/
private static Function getGlobalFunc(String name, boolean isResident, boolean allowMissing) {
Base.RefLong handle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmFuncGetGlobal(name, handle));
if (handle.value != 0) {
return new Function(handle.value, isResident);
} else {
if (allowMissing) {
return null;
} else {
throw new IllegalArgumentException("Cannot find global function " + name);
}
}
}
/**
* Initialize the function with handle
* @param handle the handle to the underlying function.
* @param isResident Whether this is a resident function in jvm
*/
public Function(long handle, boolean isResident) {
this.handle = handle;
this.isResident = isResident;
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the Function.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
if (!isResident) {
Base.checkCall(Base._LIB.tvmFuncFree(handle));
isReleased = true;
}
}
}
/**
* Invoke the function.
* @return the result.
*/
public TVMValue invoke() {
Base.RefTVMValue ret = new Base.RefTVMValue();
Base.checkCall(Base._LIB.tvmFuncCall(handle, ret));
return ret.value;
}
/**
* Push argument to the function.
* @param arg int argument.
* @return this
*/
public Function pushArg(int arg) {
Base._LIB.tvmFuncPushArgLong(arg);
return this;
}
/**
* Push argument to the function.
* @param arg long argument.
* @return this
*/
public Function pushArg(long arg) {
Base._LIB.tvmFuncPushArgLong(arg);
return this;
}
/**
* Push argument to the function.
* @param arg float argument.
* @return this
*/
public Function pushArg(float arg) {
Base._LIB.tvmFuncPushArgDouble(arg);
return this;
}
/**
* Push argument to the function.
* @param arg double argument.
* @return this
*/
public Function pushArg(double arg) {
Base._LIB.tvmFuncPushArgDouble(arg);
return this;
}
/**
* Push argument to the function.
* @param arg String argument.
* @return this
*/
public Function pushArg(String arg) {
Base._LIB.tvmFuncPushArgString(arg);
return this;
}
/**
* Push argument to the function.
* @param arg NDArray.
* @return this
*/
public Function pushArg(NDArray arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
return this;
}
/**
* Invoke function with arguments.
* @param args Can be Integer, Long, Float, Double, String, NDArray.
* @return the result.
*/
public TVMValue call(Object... args) {
for (Object arg : args) {
if (arg instanceof Integer) {
pushArg((Integer) arg);
} else if (arg instanceof Long) {
pushArg((Long) arg);
} else if (arg instanceof Float) {
pushArg((Float) arg);
} else if (arg instanceof Double) {
pushArg((Double) arg);
} else if (arg instanceof String) {
pushArg((String) arg);
} else if (arg instanceof NDArray) {
pushArg((NDArray) arg);
} else {
throw new IllegalArgumentException("Invalid argument: " + arg);
}
}
return invoke();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.List;
class LibInfo {
public native int nativeLibInit(String tvmLibFile);
public native int shutdown();
public native String tvmGetLastError();
// Function
public native void tvmFuncPushArgLong(long arg);
public native void tvmFuncPushArgDouble(double arg);
public native void tvmFuncPushArgString(String arg);
public native void tvmFuncPushArgHandle(long arg, int argType);
public native int tvmFuncListGlobalNames(List<String> funcNames);
public native int tvmFuncFree(long handle);
public native int tvmFuncGetGlobal(String name, Base.RefLong handle);
public native int tvmFuncCall(long handle, Base.RefTVMValue retVal);
// Module
public native int tvmModFree(long handle);
public native int tvmModGetFunction(long handle, String name,
int queryImports, Base.RefLong retHandle);
public native int tvmModImport(long mod, long dep);
// NDArray
public native int tvmArrayFree(long handle);
public native int tvmArrayAlloc(long[] shape,
int dtypeCode,
int dtypeBits,
int dtypeLanes,
int deviceType,
int deviceId,
Base.RefLong refHandle);
public native int tvmArrayGetShape(long handle, List<Long> shape);
public native int tvmArrayCopyFromTo(long from, long to);
public native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to);
public native int tvmArrayCopyToJArray(long from, byte[] to);
// TVMContext
public native int tvmSynchronize(int deviceType, int deviceId);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
/**
* Container of compiled functions of TVM.
*/
public class Module {
public final long handle;
private boolean isReleased = false;
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
private static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("module." + name);
apiFuncs.get().put(name, func);
}
return func;
}
public Module(long handle) {
this.handle = handle;
}
private Function entry = null;
private final String entryName = "__tvm_main__";
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the Module.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
Base.checkCall(Base._LIB.tvmModFree(handle));
isReleased = true;
}
}
/**
* Get the entry function.
* @return The entry function if exist
*/
public Function entryFunc() {
if (entry == null) {
entry = getFunction(entryName);
}
return entry;
}
/**
* Get function from the module.
* @param name The name of the function.
* @param queryImports Whether also query modules imported by this module.
* @return The result function.
*/
public Function getFunction(String name, boolean queryImports) {
Base.RefLong retHandle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmModGetFunction(
handle, name, queryImports ? 1 : 0, retHandle));
if (retHandle.value == 0) {
throw new IllegalArgumentException("Module has no function " + name);
}
return new Function(retHandle.value, false);
}
public Function getFunction(String name) {
return getFunction(name, false);
}
/**
* Add module to the import list of current one.
* @param module The other module.
*/
public void importModule(Module module) {
Base.checkCall(Base._LIB.tvmModImport(handle, module.handle));
}
/**
* Load module from file.
* @param path The path to the module file.
* @param fmt The format of the file,
* if not specified it will be inferred from suffix of the file.
* @return The loaded module
*/
public static Module load(String path, String fmt) {
TVMValue ret = getApi("_LoadFromFile").pushArg(path).pushArg(fmt).invoke();
assert ret.typeCode == TypeCode.MODULE_HANDLE;
return ret.asModule();
}
public static Module load(String path) {
return load(path, "");
}
/**
* Whether module runtime is enabled for target,
* e.g., The following code checks if gpu is enabled.
* Module.enabled("gpu")
* @param target The target device type.
* @return Whether runtime is enabled.
*/
public static boolean enabled(String target) {
TVMValue ret = getApi("_Enabled").pushArg(target).invoke();
return ret.asLong() != 0;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
class NativeLibraryLoader {
private static final String libPathInJar = "/lib/native/";
private static File tempDir;
static {
try {
tempDir = File.createTempFile("tvm", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
/*
* Different cleanup strategies for Windows and Linux.
* TODO: shutdown hook won't work on Windows
*/
if (!"Windows".equals(getUnifiedOSName())) {
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override public void run() {
for (File f : tempDir.listFiles()) {
System.err.println("Deleting " + f.getAbsolutePath());
if (!f.delete()) {
System.err.println("[WARN] Couldn't delete temporary file " + f.getAbsolutePath());
}
}
System.err.println("Deleting " + tempDir.getAbsolutePath());
if (!tempDir.delete()) {
System.err.println(
"[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath());
}
}
});
} else {
throw new RuntimeException("Windows not supported yet.");
}
} catch (IOException ex) {
System.err.println("Couldn't create temporary directory: " + ex.getMessage());
throw new RuntimeException(ex);
}
}
/**
* Find the library as a resource in jar, copy it to a tempfile
* and load it using System.load(). The name of the library has to be the
* base name, it is mapped to the corresponding system name using
* System.mapLibraryName(). e.g., the library "foo" is called "libfoo.so"
* under Linux and "foo.dll" under Windows, but you just have to pass "foo" to
* the loadLibrary().
*
* @param libname basename of the library
* @throws UnsatisfiedLinkError if library not found.
* @throws IOException if file not found.
*/
public static void loadLibrary(String libname) throws UnsatisfiedLinkError, IOException {
String mappedLibname = System.mapLibraryName(libname);
String loadLibname = mappedLibname;
if (mappedLibname.endsWith("dylib")) {
System.err.println("Replaced .dylib with .jnilib");
loadLibname = mappedLibname.replace(".dylib", ".jnilib");
}
System.err.println("Attempting to load " + loadLibname);
extractResourceFileToTempDir(loadLibname, new Action() {
@Override public void invoke(File target) {
System.err.println("Loading library from " + target.getPath());
System.load(target.getPath());
}
});
}
/**
* Translate all those Windows to "Windows". ("Windows XP", "Windows Vista", "Windows 7", etc.)
*/
private static String unifyOSName(String osname) {
if (osname.startsWith("Windows")) {
return "Windows";
}
return osname;
}
private static String getUnifiedOSName() {
return unifyOSName(System.getProperty("os.name"));
}
private static File createTempFile(String name) throws IOException {
return new File(tempDir + File.separator + name);
}
static interface Action {
public void invoke(File file);
}
/**
* Copies the resource file to a temp file and do an action.
* @param filename source file name (in lib/native).
* @param action callback function to deal with the copied file.
*/
public static void extractResourceFileToTempDir(String filename, Action action)
throws IOException {
final String libFileInJar = libPathInJar + filename;
InputStream is = NativeLibraryLoader.class.getResourceAsStream(libFileInJar);
if (is == null) {
throw new UnsatisfiedLinkError("Couldn't find the resource " + filename);
}
System.err.println(String.format("Loading %s from %s", filename, libPathInJar));
try {
File tempfile = createTempFile(filename);
OutputStream os = new FileOutputStream(tempfile);
final long savedTime = System.currentTimeMillis();
byte[] buf = new byte[8192];
int len = is.read(buf);
while (len > 0) {
os.write(buf, 0, len);
len = is.read(buf);
}
os.flush();
final FileInputStream lock = new FileInputStream(tempfile);
os.close();
double seconds = (double) (System.currentTimeMillis() - savedTime) / 1e3;
System.err.println(String.format("Copying took %.2f seconds.", seconds));
action.invoke(tempfile);
lock.close();
} catch (IOException io) {
System.err.println("[ERROR] Could not create the temp file: " + io.toString());
throw io;
} catch (UnsatisfiedLinkError ule) {
System.err.println("Couldn't load copied link file: " + ule.toString());
throw ule;
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
public class TVMContext {
private static final int RPC_SESS_MASK = 128;
private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>();
private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>();
static {
MASK2STR.put(1, "cpu");
MASK2STR.put(2, "gpu");
MASK2STR.put(4, "opencl");
MASK2STR.put(8, "metal");
MASK2STR.put(9, "vpi");
STR2MASK.put("cpu", 1);
STR2MASK.put("gpu", 2);
STR2MASK.put("cuda", 2);
STR2MASK.put("cl", 4);
STR2MASK.put("opencl", 4);
STR2MASK.put("metal", 8);
STR2MASK.put("vpi", 9);
}
/**
* Construct a CPU device.
* @param devId The device id
* @return The created context
*/
public static TVMContext cpu(int devId) {
return new TVMContext(1, devId);
}
public static TVMContext cpu() {
return cpu(0);
}
/**
* Construct a GPU device.
* @param devId The device id
* @return The created context
*/
public static TVMContext gpu(int devId) {
return new TVMContext(2, devId);
}
public static TVMContext gpu() {
return gpu(0);
}
/**
* Construct a OpenCL device.
* @param devId The device id
* @return The created context
*/
public static TVMContext opencl(int devId) {
return new TVMContext(4, devId);
}
public static TVMContext opencl() {
return opencl(0);
}
/**
* Construct a metal device.
* @param devId The device id
* @return The created context
*/
public static TVMContext metal(int devId) {
return new TVMContext(8, devId);
}
public static TVMContext metal() {
return metal(0);
}
/**
* Construct a VPI simulated device.
* @param devId The device id
* @return The created context
*/
public static TVMContext vpi(int devId) {
return new TVMContext(9, devId);
}
public static TVMContext vpi() {
return vpi(0);
}
public final int deviceType;
public final int deviceId;
public TVMContext(int deviceType, int deviceId) {
this.deviceType = deviceType;
this.deviceId = deviceId;
}
public TVMContext(String deviceType, int deviceId) {
this(STR2MASK.get(deviceType), deviceId);
}
/**
* Whether this device exists.
* @return true if exists.
*/
public boolean exist() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke();
return ((TVMValueLong) ret).value != 0;
}
/**
* Maximum number of threads on each block.
* @return the maximum thread number.
*/
public long maxThreadsPerBlock() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke();
return ((TVMValueLong) ret).value;
}
/**
* Number of threads that executes in concurrent.
* @return the thread number.
*/
public long warpSize() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke();
return ((TVMValueLong) ret).value;
}
/**
* Synchronize until jobs finished at the context.
*/
public void sync() {
Base.checkCall(Base._LIB.tvmSynchronize(deviceType, deviceId));
}
@Override public int hashCode() {
return (deviceType << 16) | deviceId;
}
@Override public boolean equals(Object other) {
if (other != null && other instanceof TVMContext) {
TVMContext obj = (TVMContext) other;
return deviceId == obj.deviceId && deviceType == obj.deviceType;
}
return false;
}
@Override public String toString() {
if (deviceType >= RPC_SESS_MASK) {
int tblId = deviceType / RPC_SESS_MASK - 1;
int devType = deviceType % RPC_SESS_MASK;
return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId);
}
return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMType {
public static final int INT = 0;
public static final int UINT = 1;
public static final int FLOAT = 2;
public static final int HANDLE = 4;
public final int typeCode;
public final int bits;
public final int numOfBytes;
public final int lanes;
/**
* TVMType constructor.
* @param typeStr type name, e.g., "float32", "float64", "uint8", etc.
* @param lanes NDArray lanes.
*/
public TVMType(String typeStr, int lanes) {
this.lanes = lanes;
int bitsTemp = 0;
if (typeStr.startsWith("int")) {
typeCode = 0;
bitsTemp = Integer.parseInt(typeStr.substring(3));
} else if (typeStr.startsWith("uint")) {
typeCode = 1;
bitsTemp = Integer.parseInt(typeStr.substring(4));
} else if (typeStr.startsWith("float")) {
typeCode = 2;
bitsTemp = Integer.parseInt(typeStr.substring(5));
} else if (typeStr.startsWith("handle")) {
typeCode = 4;
bitsTemp = 64;
} else {
throw new IllegalArgumentException("Do not know how to handle type " + typeStr);
}
bits = (bitsTemp == 0) ? 32 : bitsTemp;
if ((bits & (bits - 1)) != 0 || bits < 8) {
throw new IllegalArgumentException("Do not know how to handle type " + typeStr);
}
numOfBytes = bits / 8;
}
public TVMType(String typeStr) {
this(typeStr, 1);
}
@Override public int hashCode() {
return (typeCode << 16) | (bits << 8) | lanes;
}
@Override public boolean equals(Object other) {
if (other != null && other instanceof TVMType) {
TVMType otherInst = (TVMType) other;
return (bits == otherInst.bits)
&& (typeCode == otherInst.typeCode) && (lanes == otherInst.lanes);
}
return false;
}
@Override public String toString() {
String typeCodeStr;
switch (typeCode) {
case 0:
typeCodeStr = "int";
break;
case 1:
typeCodeStr = "uint";
break;
case 2:
typeCodeStr = "float";
break;
case 4:
typeCodeStr = "handle";
break;
default:
typeCodeStr = "Unknown";
break;
}
String str = typeCodeStr + bits;
if (lanes != 1) {
str += lanes;
}
return str;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValue {
public final TypeCode typeCode;
public TVMValue(TypeCode tc) {
typeCode = tc;
}
public long asLong() {
throw new UnsupportedOperationException();
}
public double asDouble() {
throw new UnsupportedOperationException();
}
public Module asModule() {
throw new UnsupportedOperationException();
}
public NDArray asNDArray() {
throw new UnsupportedOperationException();
}
public String asString() {
throw new UnsupportedOperationException();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueDouble extends TVMValue {
public final double value;
public TVMValueDouble(double value) {
super(TypeCode.FLOAT);
this.value = value;
}
@Override public double asDouble() {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueLong extends TVMValue {
public final long value;
public TVMValueLong(long value) {
super(TypeCode.INT);
this.value = value;
}
@Override public long asLong() {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueModuleHandle extends TVMValue {
public final long value;
public TVMValueModuleHandle(long value) {
super(TypeCode.MODULE_HANDLE);
this.value = value;
}
@Override public Module asModule() {
return new Module(value);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueNDArrayHandle extends TVMValue {
public final long value;
public TVMValueNDArrayHandle(long value) {
super(TypeCode.ARRAY_HANDLE);
this.value = value;
}
@Override public NDArray asNDArray() {
return new NDArray(value);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueNull extends TVMValue {
public TVMValueNull() {
super(TypeCode.NULL);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueString extends TVMValue {
public final String value;
public TVMValueString(String value) {
super(TypeCode.STR);
this.value = value;
}
@Override public String asString() {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
// Type code used in API calls
public enum TypeCode {
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
FUNC_HANDLE(10), STR(11), BYTES(12);
public final int id;
private TypeCode(int id) {
this.id = id;
}
@Override
public String toString() {
return String.valueOf(id);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
import java.io.File;
import java.util.Random;
public class ModuleTest {
private static String loadingDir;
@BeforeClass
public static void beforeClass() {
loadingDir = System.getProperty("test.tempdir");
}
@Test
public void test_load_add_func_cpu() {
Module fadd = Module.load(loadingDir + File.separator + "add_cpu.so");
TVMContext ctx = new TVMContext("cpu", 0);
long[] shape = new long[]{2};
NDArray arr = NDArray.empty(shape, ctx);
arr.copyFrom(new float[]{3f, 4f});
NDArray res = NDArray.empty(shape, ctx);
fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke();
assertArrayEquals(new float[]{6f, 8f}, res.asFloatArray(), 1e-3f);
// test call() api
fadd.entryFunc().call(arr, arr, res);
assertArrayEquals(new float[]{6f, 8f}, res.asFloatArray(), 1e-3f);
arr.release();
res.release();
fadd.release();
}
@Test
public void test_load_add_func_gpu() {
final Random RND = new Random(0);
Module fadd = Module.load(loadingDir + File.separator + "add_gpu.so");
Module faddDev = Module.load(loadingDir + File.separator + "add_gpu.ptx");
fadd.importModule(faddDev);
TVMContext ctx = new TVMContext("gpu", 0);
final int dim = 100;
long[] shape = new long[]{dim};
NDArray arr = NDArray.empty(shape, ctx);
float[] data = new float[dim];
float[] dataX2 = new float[dim];
for (int i = 0; i < dim; ++i) {
data[i] = RND.nextFloat();
dataX2[i] = data[i] * 2;
}
arr.copyFrom(data);
NDArray res = NDArray.empty(shape, ctx);
fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke();
assertArrayEquals(dataX2, res.asFloatArray(), 1e-3f);
arr.release();
res.release();
faddDev.release();
fadd.release();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import org.junit.Test;
import static org.junit.Assert.*;
public class NDArrayTest {
@Test
public void test_from_float32() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("float32"));
ndarray.copyFrom(new float[]{1, 2, 3, 4});
assertArrayEquals(new float[]{1f, 2f, 3f, 4f}, ndarray.asFloatArray(), 1e-3f);
ndarray.release();
}
@Test
public void test_from_float64() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("float64"));
ndarray.copyFrom(new double[]{1, 2, 3, 4});
assertArrayEquals(new double[]{1.0, 2.0, 3.0, 4.0}, ndarray.asDoubleArray(), 1e-3);
ndarray.release();
}
@Test
public void test_from_int8() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int8"));
ndarray.copyFrom(new byte[]{1, 2, 3, 4});
assertArrayEquals(new byte[]{1, 2, 3, 4}, ndarray.asByteArray());
ndarray.release();
}
@Test
public void test_from_int16() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int16"));
ndarray.copyFrom(new short[]{1, 2, 3, 4});
assertArrayEquals(new short[]{1, 2, 3, 4}, ndarray.asShortArray());
ndarray.release();
}
@Test
public void test_from_int32() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int32"));
ndarray.copyFrom(new int[]{1, 2, 3, 4});
assertArrayEquals(new int[]{1, 2, 3, 4}, ndarray.asIntArray());
ndarray.release();
}
@Test
public void test_from_int64() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int64"));
ndarray.copyFrom(new long[]{1, 2, 3, 4});
assertArrayEquals(new long[]{1, 2, 3, 4}, ndarray.asLongArray());
ndarray.release();
}
@Test
public void test_from_uint16() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("uint16"));
ndarray.copyFrom(new char[]{65535, 2, 3, 4});
assertArrayEquals(new char[]{65535, 2, 3, 4}, ndarray.asCharArray());
ndarray.release();
}
}
import os
import tvm
from tvm.contrib import cc_compiler as cc
from tvm.contrib import util
def test_add(target_dir):
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], "llvm", target_host="llvm", name="myadd")
fadd.save(os.path.join(target_dir, "add_cpu.o"))
cc.create_shared(os.path.join(target_dir, "add_cpu.so"),
[os.path.join(target_dir, "add_cpu.o")])
if __name__ == "__main__":
import sys
if len(sys.argv) != 2:
sys.exit(-1)
test_add(sys.argv[1])
import os
import tvm
from tvm.contrib import cc_compiler as cc
from tvm.contrib import util
def test_add(target_dir):
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
[os.path.join(target_dir, "add_gpu.o")])
if __name__ == "__main__":
import sys
if len(sys.argv) != 2:
sys.exit(-1)
test_add(sys.argv[1])
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>libtvm4j-linux-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native Linux-x86_64 CPU-only</name>
<url>http://maven.apache.org</url>
<packaging>so</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<!-- trigger javah -->
<javahOS>linux</javahOS>
<compilerProvider>generic-classic</compilerProvider>
<compilerExecutable>${cxx}</compilerExecutable>
<linkerExecutable>${cxx}</linkerExecutable>
<sources>
<source>
<directory>../src/main/native</directory>
<fileNames>
<fileName>ml_dmlc_tvm_native_c_api.cc</fileName>
</fileNames>
</source>
</sources>
<compilerStartOptions>
<compilerStartOption>-std=c++0x</compilerStartOption>
</compilerStartOptions>
<compilerEndOptions>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>${cflags}</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
<linkerStartOption>-shared</linkerStartOption>
</linkerStartOptions>
<linkerEndOptions>
<linkerEndOption>${ldflags}</linkerEndOption>
</linkerEndOptions>
</configuration>
<executions>
<execution>
<id>javah</id>
<phase>generate-sources</phase>
<configuration>
<javahOS>linux</javahOS>
<javahProvider>default</javahProvider>
<javahOutputDirectory>${project.build.directory}/custom-javah</javahOutputDirectory>
<workingDirectory>${basedir}</workingDirectory>
<javahOutputFileName>ml_dmlc_tvm_native_c_api.h</javahOutputFileName>
<javahClassNames>
<javahClassName>ml.dmlc.tvm.LibInfo</javahClassName>
</javahClassNames>
</configuration>
<goals>
<goal>javah</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>libtvm4j-linux-x86_64-gpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native Linux-x86_64 GPU</name>
<url>http://maven.apache.org</url>
<packaging>so</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<!-- trigger javah -->
<javahOS>linux</javahOS>
<compilerProvider>generic-classic</compilerProvider>
<compilerExecutable>${cxx}</compilerExecutable>
<linkerExecutable>${cxx}</linkerExecutable>
<sources>
<source>
<directory>../src/main/native</directory>
<fileNames>
<fileName>ml_dmlc_tvm_native_c_api.cc</fileName>
</fileNames>
</source>
</sources>
<compilerStartOptions>
<compilerStartOption>-std=c++0x</compilerStartOption>
</compilerStartOptions>
<compilerEndOptions>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>${cflags}</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
<linkerStartOption>-shared</linkerStartOption>
</linkerStartOptions>
<linkerEndOptions>
<linkerEndOption>${ldflags}</linkerEndOption>
</linkerEndOptions>
</configuration>
<executions>
<execution>
<id>javah</id>
<phase>generate-sources</phase>
<configuration>
<javahOS>linux</javahOS>
<javahProvider>default</javahProvider>
<javahOutputDirectory>${project.build.directory}/custom-javah</javahOutputDirectory>
<workingDirectory>${basedir}</workingDirectory>
<javahOutputFileName>ml_dmlc_tvm_native_c_api.h</javahOutputFileName>
<javahClassNames>
<javahClassName>ml.dmlc.tvm.LibInfo</javahClassName>
</javahClassNames>
</configuration>
<goals>
<goal>javah</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>libtvm4j-osx-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native OSX-x86_64 CPU-only</name>
<url>http://maven.apache.org</url>
<packaging>jnilib</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<!-- trigger javah -->
<javahOS>darwin</javahOS>
<compilerProvider>generic-classic</compilerProvider>
<compilerExecutable>${cxx}</compilerExecutable>
<linkerExecutable>${cxx}</linkerExecutable>
<sources>
<source>
<directory>../src/main/native</directory>
<fileNames>
<fileName>ml_dmlc_tvm_native_c_api.cc</fileName>
</fileNames>
</source>
</sources>
<compilerStartOptions>
<compilerStartOption>-std=c++0x</compilerStartOption>
</compilerStartOptions>
<compilerEndOptions>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>${cflags}</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
<linkerStartOption>-shared</linkerStartOption>
</linkerStartOptions>
<linkerMiddleOptions>
<linkerMiddleOption>-framework JavaVM</linkerMiddleOption>
<linkerMiddleOption>-Wl,-exported_symbol,_Java_*</linkerMiddleOption>
<linkerMiddleOption>-undefined dynamic_lookup</linkerMiddleOption>
<linkerMiddleOption>-Wl,-x</linkerMiddleOption>
</linkerMiddleOptions>
<linkerEndOptions>
<linkerEndOption>${ldflags}</linkerEndOption>
</linkerEndOptions>
</configuration>
<executions>
<execution>
<id>javah</id>
<phase>generate-sources</phase>
<configuration>
<javahOS>darwin</javahOS>
<javahProvider>default</javahProvider>
<javahOutputDirectory>${project.build.directory}/custom-javah</javahOutputDirectory>
<workingDirectory>${basedir}</workingDirectory>
<javahOutputFileName>ml_dmlc_tvm_native_c_api.h</javahOutputFileName>
<javahClassNames>
<javahClassName>ml.dmlc.tvm.LibInfo</javahClassName>
</javahClassNames>
</configuration>
<goals>
<goal>javah</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native Parent</name>
<packaging>pom</packaging>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<modules>
<module>osx-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<modules>
<module>linux-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<modules>
<module>linux-x86_64-gpu</module>
</modules>
</profile>
</profiles>
</project>
/*!
* Copyright (c) 2017 by Contributors
* \file jni_helper_func.h
* \brief Helper functions for operating JVM objects
*/
#include <jni.h>
#ifndef TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
#define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
// Helper functions for RefXXX getter & setter
jlong getLongField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
jlong ret = env->GetLongField(obj, refFid);
env->DeleteLocalRef(refClass);
return ret;
}
jint getIntField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
jint ret = env->GetIntField(obj, refFid);
env->DeleteLocalRef(refClass);
return ret;
}
void setIntField(JNIEnv *env, jobject obj, jint value) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
env->SetIntField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}
void setLongField(JNIEnv *env, jobject obj, jlong value) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
env->SetLongField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}
void setStringField(JNIEnv *env, jobject obj, const char *value) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefString");
jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;");
env->SetObjectField(obj, refFid, env->NewStringUTF(value));
env->DeleteLocalRef(refClass);
}
// Helper functions for TVMValue
jlong getTVMValueLongField(JNIEnv *env, jobject obj,
const char *clsname = "ml/dmlc/tvm/TVMValueLong") {
jclass cls = env->FindClass(clsname);
jfieldID fid = env->GetFieldID(cls, "value", "J");
jlong ret = env->GetLongField(obj, fid);
env->DeleteLocalRef(cls);
return ret;
}
jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueDouble");
jfieldID fid = env->GetFieldID(cls, "value", "D");
jdouble ret = env->GetDoubleField(obj, fid);
env->DeleteLocalRef(cls);
return ret;
}
jstring getTVMValueStringField(JNIEnv *env, jobject obj) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueString");
jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;");
jstring ret = static_cast<jstring>(env->GetObjectField(obj, fid));
env->DeleteLocalRef(cls);
return ret;
}
jobject newTVMValueLong(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueLong");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newTVMValueDouble(JNIEnv *env, jdouble value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueDouble");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(D)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newTVMValueModuleHandle(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueModuleHandle");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newObject(JNIEnv *env, const char *clsname) {
jclass cls = env->FindClass(clsname);
jmethodID constructor = env->GetMethodID(cls, "<init>", "()V");
jobject object = env->NewObject(cls, constructor);
env->DeleteLocalRef(cls);
return object;
}
void fromJavaDType(JNIEnv *env, jobject jdtype, TVMType *dtype) {
jclass tvmTypeClass = env->FindClass("ml/dmlc/tvm/TVMType");
dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I")));
dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I")));
dtype->lanes = (uint16_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "lanes", "I")));
env->DeleteLocalRef(tvmTypeClass);
}
void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
jclass tvmContextClass = env->FindClass("ml/dmlc/tvm/TVMContext");
ctx->device_type = static_cast<DLDeviceType>(env->GetIntField(jctx,
env->GetFieldID(tvmContextClass, "deviceType", "I")));
ctx->device_id = static_cast<int>(env->GetIntField(jctx,
env->GetFieldID(tvmContextClass, "deviceId", "I")));
env->DeleteLocalRef(tvmContextClass);
}
#endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Parent</name>
<url>https://github.com/dmlc/tvm/tree/master/jvm</url>
<description>TVM4J Package</description>
<organization>
<name>Distributed (Deep) Machine Learning Community</name>
<url>http://dmlc.ml</url>
</organization>
<licenses>
<license>
<name>The Apache License, Version 2.0</name>
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
</license>
</licenses>
<scm>
<connection>scm:git:git@github.com:dmlc/tvm.git</connection>
<developerConnection>scm:git:git@github.com:dmlc/tvm.git</developerConnection>
<url>https://github.com/dmlc/tvm</url>
</scm>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<packaging>pom</packaging>
<modules>
<module>core</module>
<module>native</module>
<module>assembly</module>
</modules>
<profiles>
<profile>
<id>release</id>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${project.build.directory}/genjavadoc</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<distributionManagement>
<snapshotRepository>
<id>ossrh</id>
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
</snapshotRepository>
</distributionManagement>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>2.2.1</version>
<executions>
<execution>
<phase>package</phase>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<version>1.0-alpha-7</version>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>2.7</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<version>2.9</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.0.2</version>
<executions>
<execution>
<id>empty-javadoc-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<includes>
<include>**/*</include>
</includes>
<classifier>javadoc</classifier>
<classesDirectory>${basedir}/javadoc</classesDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>2.5.5</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-deploy-plugin</artifactId>
<version>2.8.2</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-install-plugin</artifactId>
<version>2.5.2</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version>
<configuration>
<source>1.6</source>
<target>1.6</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.9.1</version>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
</dependencies>
</project>
#!/bin/bash
export PYTHONPATH=python:apps/extension/python
export PYTHONPATH=${PYTHONPATH}:apps/graph_executor/python:apps/graph_executor/nnvm/python
export LD_LIBRARY_PATH=lib:${LD_LIBRARY_PATH}
CURR_DIR=$(cd `dirname $0`; pwd)
SCRIPT_DIR=$CURR_DIR/../../jvm/core/src/test/scripts
TEMP_DIR=$(mktemp -d)
python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1
python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1
make jvmpkg || exit -1
make jvmpkg JVM_TEST_ARGS="-DskipTests=false -Dtest.tempdir=$TEMP_DIR" || exit -1
rm -rf $TEMP_DIR
......@@ -3,6 +3,8 @@ echo "Check codestyle of c++ code..."
make cpplint || exit -1
echo "Check codestyle of python code..."
make pylint || exit -1
echo "Check codestyle of jni code..."
make jnilint || exit -1
echo "Check documentations of c++ code..."
make doc 2>log.txt
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment