Commit d3efd7fc by Yizhi Liu Committed by Tianqi Chen

[WIP][Frontend] Scala/Java package (#176)

* JVM package skeleton

* [JVM] link libtvm.so and list function names

* [JVM] Function & NDArray skeleton

* [JVM] TVMFuncCall in JNI

* [JVM] handle string arg in TVMFuncCall

* [JVM] get module function

* [JVM] entry function for Module

* [JVM] construct Module from function return value

* [JVM] TVMContext, TVMArray attributes

* [JVM] NDArray from / to java array

* [JVM] load so and compute on cpu

* [JVM] move PackedFunc to individual modules

* [JVM] assembly package & native library loader

* [JVM] unit test & codestyle check settings

* [JVM] NDArray from & to different dtypes

* [JVM] NDArray from native double array. Add linux-cpu profile.

* [JVM] modify Makefile

* [JVM] add linux-x86_64-gpu profile

* [tvm4j] delay load libtvm_runtime.so

* [tvm4j] refactor to pure java

* [tvm4j] remove scalastyle-config.xml

* [tvm4j] remove link HalideIR, remove Shape, remove scala binary versions

* [tvm4j] only allow convert from/to same type array

* [tvm4j] make NDArray api more readable

* [tvm4j] refactor for c api

* [tvm4j] add Jenkins tests

* [tvm4j] fix duplicate Dockerfile cmd

* [tvm4j] fix ut script filename

* [tvm4j] add module load tests

* [tvm4j] add javadoc, remove types package

* [tvm4j] fix test script

* [tvm4j] remove ut temp dir

* [tvm4j] fix missing package types

* [tvm4j] java code style check

* [tvm4j] fix java lint

* [tvm4j] downgrade checkstyle plugin for JDK7

* [tvm4j] add stylecheck in jenkins tests

* [tvm4j] specify source file encoding

* [tvm4j] lazy init function; add Function.call() api; allow manully release Module,NDArray,Function

* [tvm4j] fix ModFree

* [tvm4j] cache Function in API
parent 86ff24ab
......@@ -104,6 +104,18 @@ nnvm
## IOS
DerivedData/
## Java
*.class
jvm/*/target/
jvm/*/*/target/
*.worksheet
*.idea
*.iml
*.classpath
*.project
*.settings
*/node_modules/
## Various settings
*.pbxuser
!default.pbxuser
......@@ -119,4 +131,5 @@ xcuserdata/
*.moved-aside
*.xccheckout
*.xcscmblueprint
.DS_Store
\ No newline at end of file
.DS_Store
#!groovy
// -*- mode: groovy -*-
// Jenkins pipeline
// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/
......@@ -183,6 +184,17 @@ stage('Unit Test') {
}
}
}
},
'java': {
node('GPU' && 'linux') {
ws('workspace/tvm/ut-java') {
init_git()
unpack_lib('gpu', tvm_lib)
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} gpu ./tests/scripts/task_java_unittest.sh"
}
}
}
}
}
......
......@@ -120,6 +120,29 @@ ifdef ADD_LDFLAGS
LDFLAGS += $(ADD_LDFLAGS)
endif
ifeq ($(OS),Windows_NT)
JVM_PKG_PROFILE := windows
else
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
JVM_PKG_PROFILE := osx-x86_64
else
JVM_PKG_PROFILE := linux-x86_64
endif
endif
JVM_TEST_ARGS := $(if $(JVM_TEST_ARGS),$(JVM_TEST_ARGS),-DskipTests -Dcheckstyle.skip=true)
ifeq ($(USE_CUDA), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else ifeq ($(USE_OPENCL), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else ifeq ($(USE_METAL), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-cpu
endif
include tests/cpp/unittest.mk
test: $(TEST)
......@@ -176,7 +199,10 @@ pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
lint: cpplint pylint
jnilint:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint: cpplint pylint jnilint
doc:
doxygen docs/Doxyfile
......@@ -194,6 +220,12 @@ cython3:
cyclean:
rm -rf python/tvm/*/*/*.so python/tvm/*/*/*.cpp
jvmpkg:
(cd $(ROOTDIR)/jvm; \
mvn clean package -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" $(JVM_TEST_ARGS))
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
cd HalideIR; make clean; cd $(ROOTDIR)
......
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-linux-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Linux-x86_64 CPU-only</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-linux-x86_64-cpu</artifactId>
<version>${project.version}</version>
<type>so</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.so</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-linux-x86_64-cpu:so</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-linux-x86_64-gpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Linux-x86_64 GPU</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-linux-x86_64-gpu</artifactId>
<version>${project.version}</version>
<type>so</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.so</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-linux-x86_64-gpu:so</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-osx-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full OSX-x86_64 CPU-only</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-osx-x86_64-cpu</artifactId>
<version>${project.version}</version>
<type>jnilib</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.jnilib</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-osx-x86_64-cpu:jnilib</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Parent</name>
<packaging>pom</packaging>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<modules>
<module>osx-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<modules>
<module>linux-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<modules>
<module>linux-x86_64-gpu</module>
</modules>
</profile>
<profile>
<id>release</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>jar-no-fork</goal>
</goals>
<configuration>
<includePom>true</includePom>>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<includeDependencySources>true</includeDependencySources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>
<?xml version="1.0"?>
<!DOCTYPE module PUBLIC
"-//Puppy Crawl//DTD Check Configuration 1.3//EN"
"http://www.puppycrawl.com/dtds/configuration_1_3.dtd">
<!--
Checkstyle configuration that checks the Google coding conventions from:
- Google Java Style
https://google-styleguide.googlecode.com/svn-history/r130/trunk/javaguide.html
Checkstyle is very configurable. Be sure to read the documentation at
http://checkstyle.sf.net (or in your downloaded distribution).
Most Checks are configurable, be sure to consult the documentation.
To completely disable a check, just comment it out or delete it from the file.
Authors: Max Vetrenko, Ruslan Diachenko, Roman Ivanov.
-->
<module name = "Checker">
<property name="charset" value="UTF-8"/>
<property name="severity" value="error"/>
<property name="fileExtensions" value="java, properties, xml"/>
<!-- Checks for whitespace -->
<!-- See http://checkstyle.sf.net/config_whitespace.html -->
<module name="FileTabCharacter">
<property name="eachLine" value="true"/>
</module>
<module name="TreeWalker">
<module name="OuterTypeFilename"/>
<module name="IllegalTokenText">
<property name="tokens" value="STRING_LITERAL, CHAR_LITERAL"/>
<property name="format" value="\\u00(08|09|0(a|A)|0(c|C)|0(d|D)|22|27|5(C|c))|\\(0(10|11|12|14|15|42|47)|134)"/>
<property name="message" value="Avoid using corresponding octal or Unicode escape."/>
</module>
<module name="AvoidEscapedUnicodeCharacters">
<property name="allowEscapesForControlCharacters" value="true"/>
<property name="allowByTailComment" value="true"/>
<property name="allowNonPrintableEscapes" value="true"/>
</module>
<module name="LineLength">
<property name="max" value="100"/>
<property name="ignorePattern" value="^package.*|^import.*|a href|href|http://|https://|ftp://"/>
</module>
<module name="AvoidStarImport"/>
<module name="OneTopLevelClass"/>
<module name="NoLineWrap"/>
<module name="EmptyBlock">
<property name="option" value="TEXT"/>
<property name="tokens" value="LITERAL_TRY, LITERAL_FINALLY, LITERAL_IF, LITERAL_ELSE, LITERAL_SWITCH"/>
</module>
<module name="NeedBraces"/>
<module name="LeftCurly">
<property name="maxLineLength" value="100"/>
</module>
<module name="RightCurly"/>
<module name="RightCurly">
<property name="option" value="alone"/>
<property name="tokens" value="CLASS_DEF, METHOD_DEF, CTOR_DEF, LITERAL_FOR, LITERAL_WHILE, LITERAL_DO, STATIC_INIT, INSTANCE_INIT"/>
</module>
<module name="WhitespaceAround">
<property name="allowEmptyConstructors" value="true"/>
<property name="allowEmptyMethods" value="true"/>
<property name="allowEmptyTypes" value="true"/>
<property name="allowEmptyLoops" value="true"/>
<message key="ws.notFollowed"
value="WhitespaceAround: ''{0}'' is not followed by whitespace. Empty blocks may only be represented as '{}' when not part of a multi-block statement (4.1.3)"/>
<message key="ws.notPreceded"
value="WhitespaceAround: ''{0}'' is not preceded with whitespace."/>
</module>
<module name="OneStatementPerLine"/>
<module name="MultipleVariableDeclarations"/>
<module name="ArrayTypeStyle"/>
<module name="MissingSwitchDefault"/>
<module name="FallThrough"/>
<module name="UpperEll"/>
<module name="ModifierOrder"/>
<module name="EmptyLineSeparator">
<property name="allowNoEmptyLineBetweenFields" value="true"/>
</module>
<module name="SeparatorWrap">
<property name="tokens" value="DOT"/>
<property name="option" value="nl"/>
</module>
<module name="SeparatorWrap">
<property name="tokens" value="COMMA"/>
<property name="option" value="EOL"/>
</module>
<module name="PackageName">
<property name="format" value="^[a-z]+(\.[a-z][a-z0-9]*)*$"/>
<message key="name.invalidPattern"
value="Package name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="TypeName">
<message key="name.invalidPattern"
value="Type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="MemberName">
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9]*$"/>
<message key="name.invalidPattern"
value="Member name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="ParameterName">
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9]*$"/>
<message key="name.invalidPattern"
value="Parameter name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="LocalVariableName">
<property name="tokens" value="VARIABLE_DEF"/>
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9]*$"/>
<property name="allowOneCharVarInForLoop" value="true"/>
<message key="name.invalidPattern"
value="Local variable name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="ClassTypeParameterName">
<property name="format" value="(^[A-Z][0-9]?)$|([A-Z][a-zA-Z0-9]*[T]$)"/>
<message key="name.invalidPattern"
value="Class type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="MethodTypeParameterName">
<property name="format" value="(^[A-Z][0-9]?)$|([A-Z][a-zA-Z0-9]*[T]$)"/>
<message key="name.invalidPattern"
value="Method type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="GenericWhitespace">
<message key="ws.followed"
value="GenericWhitespace ''{0}'' is followed by whitespace."/>
<message key="ws.preceded"
value="GenericWhitespace ''{0}'' is preceded with whitespace."/>
<message key="ws.illegalFollow"
value="GenericWhitespace ''{0}'' should followed by whitespace."/>
<message key="ws.notPreceded"
value="GenericWhitespace ''{0}'' is not preceded with whitespace."/>
</module>
<module name="Indentation">
<property name="basicOffset" value="2"/>
<property name="braceAdjustment" value="0"/>
<property name="caseIndent" value="2"/>
<property name="throwsIndent" value="4"/>
<property name="lineWrappingIndentation" value="4"/>
<property name="arrayInitIndent" value="2"/>
</module>
<module name="AbbreviationAsWordInName">
<property name="ignoreFinal" value="false"/>
<property name="allowedAbbreviationLength" value="5"/>
</module>
<module name="OverloadMethodsDeclarationOrder"/>
<module name="VariableDeclarationUsageDistance"/>
<module name="CustomImportOrder">
<property name="specialImportsRegExp" value="com.google"/>
<property name="sortImportsInGroupAlphabetically" value="true"/>
<property name="customImportOrderRules" value="STATIC###SPECIAL_IMPORTS###THIRD_PARTY_PACKAGE###STANDARD_JAVA_PACKAGE"/>
</module>
<module name="MethodParamPad"/>
<module name="OperatorWrap">
<property name="option" value="NL"/>
<property name="tokens" value="BAND, BOR, BSR, BXOR, DIV, EQUAL, GE, GT, LAND, LE, LITERAL_INSTANCEOF, LOR, LT, MINUS, MOD, NOT_EQUAL, PLUS, QUESTION, SL, SR, STAR "/>
</module>
<module name="AnnotationLocation">
<property name="tokens" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF"/>
</module>
<module name="AnnotationLocation">
<property name="tokens" value="VARIABLE_DEF"/>
<property name="allowSamelineMultipleAnnotations" value="true"/>
</module>
<module name="NonEmptyAtclauseDescription"/>
<module name="JavadocTagContinuationIndentation"/>
<module name="SummaryJavadocCheck">
<property name="forbiddenSummaryFragments" value="^@return the *|^This method returns |^A [{]@code [a-zA-Z0-9]+[}]( is a )"/>
</module>
<module name="JavadocParagraph"/>
<module name="AtclauseOrder">
<property name="tagOrder" value="@param, @return, @throws, @deprecated"/>
<property name="target" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF, VARIABLE_DEF"/>
</module>
<module name="JavadocMethod">
<property name="scope" value="public"/>
<property name="allowMissingParamTags" value="true"/>
<property name="allowMissingThrowsTags" value="true"/>
<property name="allowMissingReturnTag" value="true"/>
<property name="minLineCount" value="2"/>
<property name="allowedAnnotations" value="Override, Test"/>
<property name="allowThrowsTagsForSubclasses" value="true"/>
</module>
<module name="MethodName">
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9_]*$"/>
<message key="name.invalidPattern"
value="Method name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="SingleLineJavadoc">
<property name="ignoreInlineTags" value="false"/>
</module>
<module name="EmptyCatchBlock">
<property name="exceptionVariableName" value="expected"/>
</module>
<module name="CommentsIndentation"/>
</module>
</module>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Core</name>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<properties>
<platform>osx-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<properties>
<platform>linux-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<properties>
<platform>linux-x86_64-gpu</platform>
</properties>
</profile>
</profiles>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>2.17</version>
<dependencies>
<dependency>
<groupId>com.puppycrawl.tools</groupId>
<artifactId>checkstyle</artifactId>
<version>6.12</version>
</dependency>
</dependencies>
<executions>
<execution>
<phase>process-sources</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
<configuration>
<failsOnError>true</failsOnError>
<configLocation>${project.parent.basedir}/conf/google_checks.xml</configLocation>
<consoleOutput>true</consoleOutput>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.7</version>
<configuration>
<forkCount>1</forkCount>
<reuseForks>true</reuseForks>
<threadCount>1</threadCount>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
-Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so
</argLine>
</configuration>
<executions>
<execution>
<id>test</id>
<!--
We put it in the integration-test phase,
because the test suites require the jni library,
which means, in order to run the unit tests,
we have to compile all the modules first.
-->
<phase>integration-test</phase>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
/**
* TVM API functions.
*/
public final class API {
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
/**
* Get a tvm api function according by name.
* @param name function name.
* @return a TVM Function.
*/
public static Function get(final String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction(name);
apiFuncs.get().put(name, func);
}
return func;
}
/**
* Cannot be instantiated.
*/
private API() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
/**
* Internal api functions.
*/
public final class APIInternal {
/**
* Get a tvm api function according by name.
* @param name function name.
* @return a TVM Function.
*/
public static Function get(final String name) {
return API.get(name);
}
/**
* Cannot be instantiated.
*/
private APIInternal() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import ml.dmlc.tvm.NativeLibraryLoader.Action;
import java.io.File;
import java.io.IOException;
/**
* Initializing methods and types.
*/
final class Base {
/**
* Hold Long reference for JNI.
*/
public static class RefLong {
public final long value;
public RefLong(final long value) {
this.value = value;
}
public RefLong() {
this(0L);
}
}
/**
* Hold TVMValue reference for JNI.
*/
public static class RefTVMValue {
public final TVMValue value;
public RefTVMValue(TVMValue value) {
this.value = value;
}
public RefTVMValue() {
this(null);
}
}
public static final LibInfo _LIB = new LibInfo();
static {
try {
try {
tryLoadLibraryOS("tvm4j");
} catch (UnsatisfiedLinkError e) {
System.err.println("[WARN] TVM native library not found in path. "
+ "Copying native library from the archive. "
+ "Consider installing the library somewhere in the path "
+ "(for Windows: PATH, for Linux: LD_LIBRARY_PATH), "
+ "or specifying by Java cmd option -Djava.library.path=[lib path].");
NativeLibraryLoader.loadLibrary("tvm4j");
}
} catch (Throwable e) {
System.err.println("[ERROR] Couldn't find native library tvm4j");
throw new RuntimeException(e);
}
String tvmLibFilename = System.getProperty("libtvm.so.path");
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) {
try {
NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() {
@Override public void invoke(File target) {
System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override public void run() {
_LIB.shutdown();
}
});
}
/**
* Load JNI for different OS.
* @param libname library name.
* @throws UnsatisfiedLinkError if loading fails.
*/
private static void tryLoadLibraryOS(String libname) throws UnsatisfiedLinkError {
try {
System.err.println(String.format("Try loading %s from native path.", libname));
System.loadLibrary(libname);
} catch (UnsatisfiedLinkError e) {
String os = System.getProperty("os.name");
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
tryLoadLibraryXPU(libname, "linux-x86_64");
} else if (os.startsWith("Mac")) {
tryLoadLibraryXPU(libname, "osx-x86_64");
} else {
// TODO(yizhi) support windows later
throw new UnsatisfiedLinkError("Windows not supported currently");
}
}
}
/**
* Load native library for different architectures.
* @param libname library name.
* @param arch architecture.
* @throws UnsatisfiedLinkError if loading fails
*/
private static void tryLoadLibraryXPU(String libname, String arch) throws UnsatisfiedLinkError {
try {
// try gpu first
System.err.println(String.format("Try loading %s-%s-gpu from native path.", libname, arch));
System.loadLibrary(String.format("%s-%s-gpu", libname, arch));
} catch (UnsatisfiedLinkError e) {
System.err.println(String.format("Try loading %s-%s-cpu from native path.", libname, arch));
System.loadLibrary(String.format("%s-%s-cpu", libname, arch));
}
}
// helper function definitions
/**
* Check the return value of C API call
* <p>
* This function will raise exception when error occurs.
* Wrap every API call with this function
* </p>
* @param ret return value from API calls
*/
public static void checkCall(int ret) throws TVMError {
if (ret != 0) {
throw new TVMError(_LIB.tvmGetLastError());
}
}
/**
* TVM Runtime error.
*/
static class TVMError extends RuntimeException {
public TVMError(String err) {
super(err);
}
}
/**
* Cannot be instantiated.
*/
private Base() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Function {
final long handle;
public final boolean isResident;
private boolean isReleased = false;
/**
* Get registered function.
* @param name full function name.
* @return TVM function.
*/
static Function getFunction(final String name) {
for (String fullName : listGlobalFuncNames()) {
if (fullName.equals(name)) {
return getGlobalFunc(fullName, true, false);
}
}
return null;
}
/**
* Get list of global functions registered.
* @return List of global functions names.
*/
private static List<String> listGlobalFuncNames() {
List<String> names = new ArrayList<String>();
Base.checkCall(Base._LIB.tvmFuncListGlobalNames(names));
return Collections.unmodifiableList(names);
}
/**
* Get a global function by name.
* @param name The name of the function.
* @param isResident Whether it is a global 'resident' function.
* @param allowMissing Whether allow missing function or raise an error.
* @return The function to be returned, None if function is missing.
*/
private static Function getGlobalFunc(String name, boolean isResident, boolean allowMissing) {
Base.RefLong handle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmFuncGetGlobal(name, handle));
if (handle.value != 0) {
return new Function(handle.value, isResident);
} else {
if (allowMissing) {
return null;
} else {
throw new IllegalArgumentException("Cannot find global function " + name);
}
}
}
/**
* Initialize the function with handle
* @param handle the handle to the underlying function.
* @param isResident Whether this is a resident function in jvm
*/
public Function(long handle, boolean isResident) {
this.handle = handle;
this.isResident = isResident;
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the Function.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
if (!isResident) {
Base.checkCall(Base._LIB.tvmFuncFree(handle));
isReleased = true;
}
}
}
/**
* Invoke the function.
* @return the result.
*/
public TVMValue invoke() {
Base.RefTVMValue ret = new Base.RefTVMValue();
Base.checkCall(Base._LIB.tvmFuncCall(handle, ret));
return ret.value;
}
/**
* Push argument to the function.
* @param arg int argument.
* @return this
*/
public Function pushArg(int arg) {
Base._LIB.tvmFuncPushArgLong(arg);
return this;
}
/**
* Push argument to the function.
* @param arg long argument.
* @return this
*/
public Function pushArg(long arg) {
Base._LIB.tvmFuncPushArgLong(arg);
return this;
}
/**
* Push argument to the function.
* @param arg float argument.
* @return this
*/
public Function pushArg(float arg) {
Base._LIB.tvmFuncPushArgDouble(arg);
return this;
}
/**
* Push argument to the function.
* @param arg double argument.
* @return this
*/
public Function pushArg(double arg) {
Base._LIB.tvmFuncPushArgDouble(arg);
return this;
}
/**
* Push argument to the function.
* @param arg String argument.
* @return this
*/
public Function pushArg(String arg) {
Base._LIB.tvmFuncPushArgString(arg);
return this;
}
/**
* Push argument to the function.
* @param arg NDArray.
* @return this
*/
public Function pushArg(NDArray arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
return this;
}
/**
* Invoke function with arguments.
* @param args Can be Integer, Long, Float, Double, String, NDArray.
* @return the result.
*/
public TVMValue call(Object... args) {
for (Object arg : args) {
if (arg instanceof Integer) {
pushArg((Integer) arg);
} else if (arg instanceof Long) {
pushArg((Long) arg);
} else if (arg instanceof Float) {
pushArg((Float) arg);
} else if (arg instanceof Double) {
pushArg((Double) arg);
} else if (arg instanceof String) {
pushArg((String) arg);
} else if (arg instanceof NDArray) {
pushArg((NDArray) arg);
} else {
throw new IllegalArgumentException("Invalid argument: " + arg);
}
}
return invoke();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.List;
class LibInfo {
public native int nativeLibInit(String tvmLibFile);
public native int shutdown();
public native String tvmGetLastError();
// Function
public native void tvmFuncPushArgLong(long arg);
public native void tvmFuncPushArgDouble(double arg);
public native void tvmFuncPushArgString(String arg);
public native void tvmFuncPushArgHandle(long arg, int argType);
public native int tvmFuncListGlobalNames(List<String> funcNames);
public native int tvmFuncFree(long handle);
public native int tvmFuncGetGlobal(String name, Base.RefLong handle);
public native int tvmFuncCall(long handle, Base.RefTVMValue retVal);
// Module
public native int tvmModFree(long handle);
public native int tvmModGetFunction(long handle, String name,
int queryImports, Base.RefLong retHandle);
public native int tvmModImport(long mod, long dep);
// NDArray
public native int tvmArrayFree(long handle);
public native int tvmArrayAlloc(long[] shape,
int dtypeCode,
int dtypeBits,
int dtypeLanes,
int deviceType,
int deviceId,
Base.RefLong refHandle);
public native int tvmArrayGetShape(long handle, List<Long> shape);
public native int tvmArrayCopyFromTo(long from, long to);
public native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to);
public native int tvmArrayCopyToJArray(long from, byte[] to);
// TVMContext
public native int tvmSynchronize(int deviceType, int deviceId);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
/**
* Container of compiled functions of TVM.
*/
public class Module {
public final long handle;
private boolean isReleased = false;
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
private static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("module." + name);
apiFuncs.get().put(name, func);
}
return func;
}
public Module(long handle) {
this.handle = handle;
}
private Function entry = null;
private final String entryName = "__tvm_main__";
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the Module.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
Base.checkCall(Base._LIB.tvmModFree(handle));
isReleased = true;
}
}
/**
* Get the entry function.
* @return The entry function if exist
*/
public Function entryFunc() {
if (entry == null) {
entry = getFunction(entryName);
}
return entry;
}
/**
* Get function from the module.
* @param name The name of the function.
* @param queryImports Whether also query modules imported by this module.
* @return The result function.
*/
public Function getFunction(String name, boolean queryImports) {
Base.RefLong retHandle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmModGetFunction(
handle, name, queryImports ? 1 : 0, retHandle));
if (retHandle.value == 0) {
throw new IllegalArgumentException("Module has no function " + name);
}
return new Function(retHandle.value, false);
}
public Function getFunction(String name) {
return getFunction(name, false);
}
/**
* Add module to the import list of current one.
* @param module The other module.
*/
public void importModule(Module module) {
Base.checkCall(Base._LIB.tvmModImport(handle, module.handle));
}
/**
* Load module from file.
* @param path The path to the module file.
* @param fmt The format of the file,
* if not specified it will be inferred from suffix of the file.
* @return The loaded module
*/
public static Module load(String path, String fmt) {
TVMValue ret = getApi("_LoadFromFile").pushArg(path).pushArg(fmt).invoke();
assert ret.typeCode == TypeCode.MODULE_HANDLE;
return ret.asModule();
}
public static Module load(String path) {
return load(path, "");
}
/**
* Whether module runtime is enabled for target,
* e.g., The following code checks if gpu is enabled.
* Module.enabled("gpu")
* @param target The target device type.
* @return Whether runtime is enabled.
*/
public static boolean enabled(String target) {
TVMValue ret = getApi("_Enabled").pushArg(target).invoke();
return ret.asLong() != 0;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
/**
* Lightweight NDArray class of TVM runtime.
*/
public class NDArray {
public final long handle;
private final boolean isView;
private final TVMType dtype;
private boolean isReleased = false;
NDArray(long handle, boolean isView, TVMType dtype) {
this.handle = handle;
this.isView = isView;
this.dtype = dtype;
}
NDArray(long handle) {
this(handle, false, new TVMType("float32", 1));
}
NDArray(long handle, boolean isView) {
this(handle, isView, new TVMType("float32", 1));
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the NDArray memory.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
if (!isView) {
Base.checkCall(Base._LIB.tvmArrayFree(handle));
isReleased = true;
}
}
}
/**
* Copy from a native array.
* The NDArray type must by float64
* @param sourceArray the source data
*/
public void copyFrom(double[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 64) {
throw new IllegalArgumentException("Cannot set double[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putDouble(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by float32
* @param sourceArray the source data
*/
public void copyFrom(float[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 32) {
throw new IllegalArgumentException("Cannot set float[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putFloat(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by int64
* @param sourceArray the source data
*/
public void copyFrom(long[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 64) {
throw new IllegalArgumentException("Cannot set long[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putLong(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by float32
* @param sourceArray the source data
*/
public void copyFrom(int[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 32) {
throw new IllegalArgumentException("Cannot set int[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putInt(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by int16
* @param sourceArray the source data
*/
public void copyFrom(short[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 16) {
throw new IllegalArgumentException("Cannot set short[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putShort(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by int8
* @param sourceArray the source data
*/
public void copyFrom(byte[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 8) {
throw new IllegalArgumentException("Cannot set byte[] for " + dtype.toString() + " array");
}
copyFromRaw(sourceArray);
}
/**
* Copy from a native array.
* The NDArray type must by uint16
* @param sourceArray the source data
*/
public void copyFrom(char[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.UINT || dtype.bits != 16) {
throw new IllegalArgumentException("Cannot set char[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putChar(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
private void checkCopySize(int sourceLength) {
long arrSize = size();
if (arrSize != sourceLength) {
throw new IllegalArgumentException(String.format("Array shape size not match: %d v.s. %d",
sourceLength, size()));
}
}
/**
* Copy from a raw byte array.
* @param sourceArray the source data
*/
public void copyFromRaw(byte[] sourceArray) {
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(sourceArray, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Get shape of current NDArray.
* @return an array representing shape of current ndarray
*/
public long[] shape() {
List<Long> data = new ArrayList<Long>();
Base.checkCall(Base._LIB.tvmArrayGetShape(handle, data));
long[] shapeArr = new long[data.size()];
for (int i = 0; i < shapeArr.length; ++i) {
shapeArr[i] = data.get(i);
}
return shapeArr;
}
/**
* Get total size of current NDArray.
* @return size of current NDArray.
*/
public long size() {
long product = 1L;
long[] shapeArr = shape();
for (int i = 0; i < shapeArr.length; ++i) {
product *= shapeArr[i];
}
return product;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be float64
* @return A copy of array content.
*/
public double[] asDoubleArray() {
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 64) {
throw new IllegalArgumentException(
"Cannot set convert to double[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
double[] array = new double[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getDouble();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be float32
* @return A copy of array content.
*/
public float[] asFloatArray() {
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 32) {
throw new IllegalArgumentException(
"Cannot set convert to float[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
float[] array = new float[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getFloat();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int64
* @return A copy of array content.
*/
public long[] asLongArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 64) {
throw new IllegalArgumentException(
"Cannot set convert to long[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
long[] array = new long[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getLong();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int32
* @return A copy of array content.
*/
public int[] asIntArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 32) {
throw new IllegalArgumentException(
"Cannot set convert to int[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
int[] array = new int[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getInt();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int16
* @return A copy of array content.
*/
public short[] asShortArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 16) {
throw new IllegalArgumentException(
"Cannot set convert to short[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
short[] array = new short[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getShort();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be uint16
* @return A copy of array content.
*/
public char[] asCharArray() {
if (dtype.typeCode != TVMType.UINT || dtype.bits != 16) {
throw new IllegalArgumentException(
"Cannot set convert to char[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
char[] array = new char[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getChar();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int8
* @return A copy of array content.
*/
public byte[] asByteArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 8) {
throw new IllegalArgumentException(
"Cannot set convert to byte[] for " + dtype.toString() + " array");
}
return internal();
}
/**
* Return a copied internal byte array of current array (row-major).
* @return A copy of array content.
*/
public byte[] internal() {
NDArray tmp = NDArray.empty(shape(), dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromTo(handle, tmp.handle));
int arrLength = dtype.numOfBytes * (int) size();
byte[] arr = new byte[arrLength];
Base.checkCall(Base._LIB.tvmArrayCopyToJArray(tmp.handle, arr));
return arr;
}
private byte[][] groupInternalBytes() {
byte[] raw = internal();
int unitSize = dtype.numOfBytes;
if (raw.length <= 0 || raw.length % unitSize != 0) {
throw new IllegalArgumentException(String.format(
"%s size %d cannot divide byte array size %d",
dtype.toString(), unitSize, raw.length));
}
int numOfUnits = raw.length / unitSize;
byte[][] units = new byte[numOfUnits][unitSize];
for (int i = 0; i < numOfUnits; ++i) {
System.arraycopy(raw, i * unitSize, units[i], 0, unitSize);
}
return units;
}
/**
* Create an empty array given shape, type and device.
* @param shape The shape of the array.
* @param dtype The data type of the array.
* @param ctx The context of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape, TVMType dtype, TVMContext ctx) {
Base.RefLong refHandle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmArrayAlloc(
shape, dtype.typeCode, dtype.bits, dtype.lanes,
ctx.deviceType, ctx.deviceId, refHandle));
return new NDArray(refHandle.value, false, dtype);
}
/**
* Create an empty array on cpu given shape and type.
* @param shape The shape of the array.
* @param dtype The data type of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape, TVMType dtype) {
return empty(shape, dtype, new TVMContext(1, 0));
}
/**
* Create an empty float32 array on cpu given shape.
* @param shape The shape of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape) {
return empty(shape, new TVMType("float32", 1), new TVMContext(1, 0));
}
/**
* Create an empty float32 array given shape and device.
* @param shape The shape of the array.
* @param ctx The context of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape, TVMContext ctx) {
return empty(shape, new TVMType("float32", 1), ctx);
}
private static ByteBuffer wrapBytes(byte[] bytes) {
ByteBuffer bb = ByteBuffer.wrap(bytes);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb;
}
private static ByteBuffer wrapBytes(byte[] bytes, int offset, int length) {
ByteBuffer bb = ByteBuffer.wrap(bytes, offset, length);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
class NativeLibraryLoader {
private static final String libPathInJar = "/lib/native/";
private static File tempDir;
static {
try {
tempDir = File.createTempFile("tvm", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
/*
* Different cleanup strategies for Windows and Linux.
* TODO: shutdown hook won't work on Windows
*/
if (!"Windows".equals(getUnifiedOSName())) {
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override public void run() {
for (File f : tempDir.listFiles()) {
System.err.println("Deleting " + f.getAbsolutePath());
if (!f.delete()) {
System.err.println("[WARN] Couldn't delete temporary file " + f.getAbsolutePath());
}
}
System.err.println("Deleting " + tempDir.getAbsolutePath());
if (!tempDir.delete()) {
System.err.println(
"[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath());
}
}
});
} else {
throw new RuntimeException("Windows not supported yet.");
}
} catch (IOException ex) {
System.err.println("Couldn't create temporary directory: " + ex.getMessage());
throw new RuntimeException(ex);
}
}
/**
* Find the library as a resource in jar, copy it to a tempfile
* and load it using System.load(). The name of the library has to be the
* base name, it is mapped to the corresponding system name using
* System.mapLibraryName(). e.g., the library "foo" is called "libfoo.so"
* under Linux and "foo.dll" under Windows, but you just have to pass "foo" to
* the loadLibrary().
*
* @param libname basename of the library
* @throws UnsatisfiedLinkError if library not found.
* @throws IOException if file not found.
*/
public static void loadLibrary(String libname) throws UnsatisfiedLinkError, IOException {
String mappedLibname = System.mapLibraryName(libname);
String loadLibname = mappedLibname;
if (mappedLibname.endsWith("dylib")) {
System.err.println("Replaced .dylib with .jnilib");
loadLibname = mappedLibname.replace(".dylib", ".jnilib");
}
System.err.println("Attempting to load " + loadLibname);
extractResourceFileToTempDir(loadLibname, new Action() {
@Override public void invoke(File target) {
System.err.println("Loading library from " + target.getPath());
System.load(target.getPath());
}
});
}
/**
* Translate all those Windows to "Windows". ("Windows XP", "Windows Vista", "Windows 7", etc.)
*/
private static String unifyOSName(String osname) {
if (osname.startsWith("Windows")) {
return "Windows";
}
return osname;
}
private static String getUnifiedOSName() {
return unifyOSName(System.getProperty("os.name"));
}
private static File createTempFile(String name) throws IOException {
return new File(tempDir + File.separator + name);
}
static interface Action {
public void invoke(File file);
}
/**
* Copies the resource file to a temp file and do an action.
* @param filename source file name (in lib/native).
* @param action callback function to deal with the copied file.
*/
public static void extractResourceFileToTempDir(String filename, Action action)
throws IOException {
final String libFileInJar = libPathInJar + filename;
InputStream is = NativeLibraryLoader.class.getResourceAsStream(libFileInJar);
if (is == null) {
throw new UnsatisfiedLinkError("Couldn't find the resource " + filename);
}
System.err.println(String.format("Loading %s from %s", filename, libPathInJar));
try {
File tempfile = createTempFile(filename);
OutputStream os = new FileOutputStream(tempfile);
final long savedTime = System.currentTimeMillis();
byte[] buf = new byte[8192];
int len = is.read(buf);
while (len > 0) {
os.write(buf, 0, len);
len = is.read(buf);
}
os.flush();
final FileInputStream lock = new FileInputStream(tempfile);
os.close();
double seconds = (double) (System.currentTimeMillis() - savedTime) / 1e3;
System.err.println(String.format("Copying took %.2f seconds.", seconds));
action.invoke(tempfile);
lock.close();
} catch (IOException io) {
System.err.println("[ERROR] Could not create the temp file: " + io.toString());
throw io;
} catch (UnsatisfiedLinkError ule) {
System.err.println("Couldn't load copied link file: " + ule.toString());
throw ule;
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
public class TVMContext {
private static final int RPC_SESS_MASK = 128;
private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>();
private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>();
static {
MASK2STR.put(1, "cpu");
MASK2STR.put(2, "gpu");
MASK2STR.put(4, "opencl");
MASK2STR.put(8, "metal");
MASK2STR.put(9, "vpi");
STR2MASK.put("cpu", 1);
STR2MASK.put("gpu", 2);
STR2MASK.put("cuda", 2);
STR2MASK.put("cl", 4);
STR2MASK.put("opencl", 4);
STR2MASK.put("metal", 8);
STR2MASK.put("vpi", 9);
}
/**
* Construct a CPU device.
* @param devId The device id
* @return The created context
*/
public static TVMContext cpu(int devId) {
return new TVMContext(1, devId);
}
public static TVMContext cpu() {
return cpu(0);
}
/**
* Construct a GPU device.
* @param devId The device id
* @return The created context
*/
public static TVMContext gpu(int devId) {
return new TVMContext(2, devId);
}
public static TVMContext gpu() {
return gpu(0);
}
/**
* Construct a OpenCL device.
* @param devId The device id
* @return The created context
*/
public static TVMContext opencl(int devId) {
return new TVMContext(4, devId);
}
public static TVMContext opencl() {
return opencl(0);
}
/**
* Construct a metal device.
* @param devId The device id
* @return The created context
*/
public static TVMContext metal(int devId) {
return new TVMContext(8, devId);
}
public static TVMContext metal() {
return metal(0);
}
/**
* Construct a VPI simulated device.
* @param devId The device id
* @return The created context
*/
public static TVMContext vpi(int devId) {
return new TVMContext(9, devId);
}
public static TVMContext vpi() {
return vpi(0);
}
public final int deviceType;
public final int deviceId;
public TVMContext(int deviceType, int deviceId) {
this.deviceType = deviceType;
this.deviceId = deviceId;
}
public TVMContext(String deviceType, int deviceId) {
this(STR2MASK.get(deviceType), deviceId);
}
/**
* Whether this device exists.
* @return true if exists.
*/
public boolean exist() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke();
return ((TVMValueLong) ret).value != 0;
}
/**
* Maximum number of threads on each block.
* @return the maximum thread number.
*/
public long maxThreadsPerBlock() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke();
return ((TVMValueLong) ret).value;
}
/**
* Number of threads that executes in concurrent.
* @return the thread number.
*/
public long warpSize() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke();
return ((TVMValueLong) ret).value;
}
/**
* Synchronize until jobs finished at the context.
*/
public void sync() {
Base.checkCall(Base._LIB.tvmSynchronize(deviceType, deviceId));
}
@Override public int hashCode() {
return (deviceType << 16) | deviceId;
}
@Override public boolean equals(Object other) {
if (other != null && other instanceof TVMContext) {
TVMContext obj = (TVMContext) other;
return deviceId == obj.deviceId && deviceType == obj.deviceType;
}
return false;
}
@Override public String toString() {
if (deviceType >= RPC_SESS_MASK) {
int tblId = deviceType / RPC_SESS_MASK - 1;
int devType = deviceType % RPC_SESS_MASK;
return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId);
}
return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMType {
public static final int INT = 0;
public static final int UINT = 1;
public static final int FLOAT = 2;
public static final int HANDLE = 4;
public final int typeCode;
public final int bits;
public final int numOfBytes;
public final int lanes;
/**
* TVMType constructor.
* @param typeStr type name, e.g., "float32", "float64", "uint8", etc.
* @param lanes NDArray lanes.
*/
public TVMType(String typeStr, int lanes) {
this.lanes = lanes;
int bitsTemp = 0;
if (typeStr.startsWith("int")) {
typeCode = 0;
bitsTemp = Integer.parseInt(typeStr.substring(3));
} else if (typeStr.startsWith("uint")) {
typeCode = 1;
bitsTemp = Integer.parseInt(typeStr.substring(4));
} else if (typeStr.startsWith("float")) {
typeCode = 2;
bitsTemp = Integer.parseInt(typeStr.substring(5));
} else if (typeStr.startsWith("handle")) {
typeCode = 4;
bitsTemp = 64;
} else {
throw new IllegalArgumentException("Do not know how to handle type " + typeStr);
}
bits = (bitsTemp == 0) ? 32 : bitsTemp;
if ((bits & (bits - 1)) != 0 || bits < 8) {
throw new IllegalArgumentException("Do not know how to handle type " + typeStr);
}
numOfBytes = bits / 8;
}
public TVMType(String typeStr) {
this(typeStr, 1);
}
@Override public int hashCode() {
return (typeCode << 16) | (bits << 8) | lanes;
}
@Override public boolean equals(Object other) {
if (other != null && other instanceof TVMType) {
TVMType otherInst = (TVMType) other;
return (bits == otherInst.bits)
&& (typeCode == otherInst.typeCode) && (lanes == otherInst.lanes);
}
return false;
}
@Override public String toString() {
String typeCodeStr;
switch (typeCode) {
case 0:
typeCodeStr = "int";
break;
case 1:
typeCodeStr = "uint";
break;
case 2:
typeCodeStr = "float";
break;
case 4:
typeCodeStr = "handle";
break;
default:
typeCodeStr = "Unknown";
break;
}
String str = typeCodeStr + bits;
if (lanes != 1) {
str += lanes;
}
return str;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValue {
public final TypeCode typeCode;
public TVMValue(TypeCode tc) {
typeCode = tc;
}
public long asLong() {
throw new UnsupportedOperationException();
}
public double asDouble() {
throw new UnsupportedOperationException();
}
public Module asModule() {
throw new UnsupportedOperationException();
}
public NDArray asNDArray() {
throw new UnsupportedOperationException();
}
public String asString() {
throw new UnsupportedOperationException();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueDouble extends TVMValue {
public final double value;
public TVMValueDouble(double value) {
super(TypeCode.FLOAT);
this.value = value;
}
@Override public double asDouble() {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueLong extends TVMValue {
public final long value;
public TVMValueLong(long value) {
super(TypeCode.INT);
this.value = value;
}
@Override public long asLong() {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueModuleHandle extends TVMValue {
public final long value;
public TVMValueModuleHandle(long value) {
super(TypeCode.MODULE_HANDLE);
this.value = value;
}
@Override public Module asModule() {
return new Module(value);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueNDArrayHandle extends TVMValue {
public final long value;
public TVMValueNDArrayHandle(long value) {
super(TypeCode.ARRAY_HANDLE);
this.value = value;
}
@Override public NDArray asNDArray() {
return new NDArray(value);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueNull extends TVMValue {
public TVMValueNull() {
super(TypeCode.NULL);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
public class TVMValueString extends TVMValue {
public final String value;
public TVMValueString(String value) {
super(TypeCode.STR);
this.value = value;
}
@Override public String asString() {
return value;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
// Type code used in API calls
public enum TypeCode {
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
FUNC_HANDLE(10), STR(11), BYTES(12);
public final int id;
private TypeCode(int id) {
this.id = id;
}
@Override
public String toString() {
return String.valueOf(id);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
import java.io.File;
import java.util.Random;
public class ModuleTest {
private static String loadingDir;
@BeforeClass
public static void beforeClass() {
loadingDir = System.getProperty("test.tempdir");
}
@Test
public void test_load_add_func_cpu() {
Module fadd = Module.load(loadingDir + File.separator + "add_cpu.so");
TVMContext ctx = new TVMContext("cpu", 0);
long[] shape = new long[]{2};
NDArray arr = NDArray.empty(shape, ctx);
arr.copyFrom(new float[]{3f, 4f});
NDArray res = NDArray.empty(shape, ctx);
fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke();
assertArrayEquals(new float[]{6f, 8f}, res.asFloatArray(), 1e-3f);
// test call() api
fadd.entryFunc().call(arr, arr, res);
assertArrayEquals(new float[]{6f, 8f}, res.asFloatArray(), 1e-3f);
arr.release();
res.release();
fadd.release();
}
@Test
public void test_load_add_func_gpu() {
final Random RND = new Random(0);
Module fadd = Module.load(loadingDir + File.separator + "add_gpu.so");
Module faddDev = Module.load(loadingDir + File.separator + "add_gpu.ptx");
fadd.importModule(faddDev);
TVMContext ctx = new TVMContext("gpu", 0);
final int dim = 100;
long[] shape = new long[]{dim};
NDArray arr = NDArray.empty(shape, ctx);
float[] data = new float[dim];
float[] dataX2 = new float[dim];
for (int i = 0; i < dim; ++i) {
data[i] = RND.nextFloat();
dataX2[i] = data[i] * 2;
}
arr.copyFrom(data);
NDArray res = NDArray.empty(shape, ctx);
fadd.entryFunc().pushArg(arr).pushArg(arr).pushArg(res).invoke();
assertArrayEquals(dataX2, res.asFloatArray(), 1e-3f);
arr.release();
res.release();
faddDev.release();
fadd.release();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import org.junit.Test;
import static org.junit.Assert.*;
public class NDArrayTest {
@Test
public void test_from_float32() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("float32"));
ndarray.copyFrom(new float[]{1, 2, 3, 4});
assertArrayEquals(new float[]{1f, 2f, 3f, 4f}, ndarray.asFloatArray(), 1e-3f);
ndarray.release();
}
@Test
public void test_from_float64() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("float64"));
ndarray.copyFrom(new double[]{1, 2, 3, 4});
assertArrayEquals(new double[]{1.0, 2.0, 3.0, 4.0}, ndarray.asDoubleArray(), 1e-3);
ndarray.release();
}
@Test
public void test_from_int8() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int8"));
ndarray.copyFrom(new byte[]{1, 2, 3, 4});
assertArrayEquals(new byte[]{1, 2, 3, 4}, ndarray.asByteArray());
ndarray.release();
}
@Test
public void test_from_int16() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int16"));
ndarray.copyFrom(new short[]{1, 2, 3, 4});
assertArrayEquals(new short[]{1, 2, 3, 4}, ndarray.asShortArray());
ndarray.release();
}
@Test
public void test_from_int32() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int32"));
ndarray.copyFrom(new int[]{1, 2, 3, 4});
assertArrayEquals(new int[]{1, 2, 3, 4}, ndarray.asIntArray());
ndarray.release();
}
@Test
public void test_from_int64() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("int64"));
ndarray.copyFrom(new long[]{1, 2, 3, 4});
assertArrayEquals(new long[]{1, 2, 3, 4}, ndarray.asLongArray());
ndarray.release();
}
@Test
public void test_from_uint16() {
NDArray ndarray = NDArray.empty(new long[]{2, 2}, new TVMType("uint16"));
ndarray.copyFrom(new char[]{65535, 2, 3, 4});
assertArrayEquals(new char[]{65535, 2, 3, 4}, ndarray.asCharArray());
ndarray.release();
}
}
import os
import tvm
from tvm.contrib import cc_compiler as cc
from tvm.contrib import util
def test_add(target_dir):
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], "llvm", target_host="llvm", name="myadd")
fadd.save(os.path.join(target_dir, "add_cpu.o"))
cc.create_shared(os.path.join(target_dir, "add_cpu.so"),
[os.path.join(target_dir, "add_cpu.o")])
if __name__ == "__main__":
import sys
if len(sys.argv) != 2:
sys.exit(-1)
test_add(sys.argv[1])
import os
import tvm
from tvm.contrib import cc_compiler as cc
from tvm.contrib import util
def test_add(target_dir):
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
[os.path.join(target_dir, "add_gpu.o")])
if __name__ == "__main__":
import sys
if len(sys.argv) != 2:
sys.exit(-1)
test_add(sys.argv[1])
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>libtvm4j-linux-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native Linux-x86_64 CPU-only</name>
<url>http://maven.apache.org</url>
<packaging>so</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<!-- trigger javah -->
<javahOS>linux</javahOS>
<compilerProvider>generic-classic</compilerProvider>
<compilerExecutable>${cxx}</compilerExecutable>
<linkerExecutable>${cxx}</linkerExecutable>
<sources>
<source>
<directory>../src/main/native</directory>
<fileNames>
<fileName>ml_dmlc_tvm_native_c_api.cc</fileName>
</fileNames>
</source>
</sources>
<compilerStartOptions>
<compilerStartOption>-std=c++0x</compilerStartOption>
</compilerStartOptions>
<compilerEndOptions>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>${cflags}</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
<linkerStartOption>-shared</linkerStartOption>
</linkerStartOptions>
<linkerEndOptions>
<linkerEndOption>${ldflags}</linkerEndOption>
</linkerEndOptions>
</configuration>
<executions>
<execution>
<id>javah</id>
<phase>generate-sources</phase>
<configuration>
<javahOS>linux</javahOS>
<javahProvider>default</javahProvider>
<javahOutputDirectory>${project.build.directory}/custom-javah</javahOutputDirectory>
<workingDirectory>${basedir}</workingDirectory>
<javahOutputFileName>ml_dmlc_tvm_native_c_api.h</javahOutputFileName>
<javahClassNames>
<javahClassName>ml.dmlc.tvm.LibInfo</javahClassName>
</javahClassNames>
</configuration>
<goals>
<goal>javah</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>libtvm4j-linux-x86_64-gpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native Linux-x86_64 GPU</name>
<url>http://maven.apache.org</url>
<packaging>so</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<!-- trigger javah -->
<javahOS>linux</javahOS>
<compilerProvider>generic-classic</compilerProvider>
<compilerExecutable>${cxx}</compilerExecutable>
<linkerExecutable>${cxx}</linkerExecutable>
<sources>
<source>
<directory>../src/main/native</directory>
<fileNames>
<fileName>ml_dmlc_tvm_native_c_api.cc</fileName>
</fileNames>
</source>
</sources>
<compilerStartOptions>
<compilerStartOption>-std=c++0x</compilerStartOption>
</compilerStartOptions>
<compilerEndOptions>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>${cflags}</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
<linkerStartOption>-shared</linkerStartOption>
</linkerStartOptions>
<linkerEndOptions>
<linkerEndOption>${ldflags}</linkerEndOption>
</linkerEndOptions>
</configuration>
<executions>
<execution>
<id>javah</id>
<phase>generate-sources</phase>
<configuration>
<javahOS>linux</javahOS>
<javahProvider>default</javahProvider>
<javahOutputDirectory>${project.build.directory}/custom-javah</javahOutputDirectory>
<workingDirectory>${basedir}</workingDirectory>
<javahOutputFileName>ml_dmlc_tvm_native_c_api.h</javahOutputFileName>
<javahClassNames>
<javahClassName>ml.dmlc.tvm.LibInfo</javahClassName>
</javahClassNames>
</configuration>
<goals>
<goal>javah</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>libtvm4j-osx-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native OSX-x86_64 CPU-only</name>
<url>http://maven.apache.org</url>
<packaging>jnilib</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<!-- trigger javah -->
<javahOS>darwin</javahOS>
<compilerProvider>generic-classic</compilerProvider>
<compilerExecutable>${cxx}</compilerExecutable>
<linkerExecutable>${cxx}</linkerExecutable>
<sources>
<source>
<directory>../src/main/native</directory>
<fileNames>
<fileName>ml_dmlc_tvm_native_c_api.cc</fileName>
</fileNames>
</source>
</sources>
<compilerStartOptions>
<compilerStartOption>-std=c++0x</compilerStartOption>
</compilerStartOptions>
<compilerEndOptions>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>${cflags}</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
<linkerStartOption>-shared</linkerStartOption>
</linkerStartOptions>
<linkerMiddleOptions>
<linkerMiddleOption>-framework JavaVM</linkerMiddleOption>
<linkerMiddleOption>-Wl,-exported_symbol,_Java_*</linkerMiddleOption>
<linkerMiddleOption>-undefined dynamic_lookup</linkerMiddleOption>
<linkerMiddleOption>-Wl,-x</linkerMiddleOption>
</linkerMiddleOptions>
<linkerEndOptions>
<linkerEndOption>${ldflags}</linkerEndOption>
</linkerEndOptions>
</configuration>
<executions>
<execution>
<id>javah</id>
<phase>generate-sources</phase>
<configuration>
<javahOS>darwin</javahOS>
<javahProvider>default</javahProvider>
<javahOutputDirectory>${project.build.directory}/custom-javah</javahOutputDirectory>
<workingDirectory>${basedir}</workingDirectory>
<javahOutputFileName>ml_dmlc_tvm_native_c_api.h</javahOutputFileName>
<javahClassNames>
<javahClassName>ml.dmlc.tvm.LibInfo</javahClassName>
</javahClassNames>
</configuration>
<goals>
<goal>javah</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>tvm4j-native-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Native Parent</name>
<packaging>pom</packaging>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<modules>
<module>osx-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<modules>
<module>linux-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<modules>
<module>linux-x86_64-gpu</module>
</modules>
</profile>
</profiles>
</project>
/*!
* Copyright (c) 2017 by Contributors
* \file jni_helper_func.h
* \brief Helper functions for operating JVM objects
*/
#include <jni.h>
#ifndef TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
#define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
// Helper functions for RefXXX getter & setter
jlong getLongField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
jlong ret = env->GetLongField(obj, refFid);
env->DeleteLocalRef(refClass);
return ret;
}
jint getIntField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
jint ret = env->GetIntField(obj, refFid);
env->DeleteLocalRef(refClass);
return ret;
}
void setIntField(JNIEnv *env, jobject obj, jint value) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
env->SetIntField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}
void setLongField(JNIEnv *env, jobject obj, jlong value) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
env->SetLongField(obj, refFid, value);
env->DeleteLocalRef(refClass);
}
void setStringField(JNIEnv *env, jobject obj, const char *value) {
jclass refClass = env->FindClass("ml/dmlc/tvm/Base$RefString");
jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;");
env->SetObjectField(obj, refFid, env->NewStringUTF(value));
env->DeleteLocalRef(refClass);
}
// Helper functions for TVMValue
jlong getTVMValueLongField(JNIEnv *env, jobject obj,
const char *clsname = "ml/dmlc/tvm/TVMValueLong") {
jclass cls = env->FindClass(clsname);
jfieldID fid = env->GetFieldID(cls, "value", "J");
jlong ret = env->GetLongField(obj, fid);
env->DeleteLocalRef(cls);
return ret;
}
jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueDouble");
jfieldID fid = env->GetFieldID(cls, "value", "D");
jdouble ret = env->GetDoubleField(obj, fid);
env->DeleteLocalRef(cls);
return ret;
}
jstring getTVMValueStringField(JNIEnv *env, jobject obj) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueString");
jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;");
jstring ret = static_cast<jstring>(env->GetObjectField(obj, fid));
env->DeleteLocalRef(cls);
return ret;
}
jobject newTVMValueLong(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueLong");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newTVMValueDouble(JNIEnv *env, jdouble value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueDouble");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(D)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newTVMValueModuleHandle(JNIEnv *env, jlong value) {
jclass cls = env->FindClass("ml/dmlc/tvm/TVMValueModuleHandle");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(J)V");
jobject object = env->NewObject(cls, constructor, value);
env->DeleteLocalRef(cls);
return object;
}
jobject newObject(JNIEnv *env, const char *clsname) {
jclass cls = env->FindClass(clsname);
jmethodID constructor = env->GetMethodID(cls, "<init>", "()V");
jobject object = env->NewObject(cls, constructor);
env->DeleteLocalRef(cls);
return object;
}
void fromJavaDType(JNIEnv *env, jobject jdtype, TVMType *dtype) {
jclass tvmTypeClass = env->FindClass("ml/dmlc/tvm/TVMType");
dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I")));
dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I")));
dtype->lanes = (uint16_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "lanes", "I")));
env->DeleteLocalRef(tvmTypeClass);
}
void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
jclass tvmContextClass = env->FindClass("ml/dmlc/tvm/TVMContext");
ctx->device_type = static_cast<DLDeviceType>(env->GetIntField(jctx,
env->GetFieldID(tvmContextClass, "deviceType", "I")));
ctx->device_id = static_cast<int>(env->GetIntField(jctx,
env->GetFieldID(tvmContextClass, "deviceId", "I")));
env->DeleteLocalRef(tvmContextClass);
}
#endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
/*!
* Copyright (c) 2017 by Contributors
* \file ml_dmlc_tvm_native_c_api.cc
* \brief tvm4j jni source file
*/
#include "ml_dmlc_tvm_native_c_api.h" // generated by javah
#include <dlfcn.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>
#include <iostream>
#include <cstring>
#include <vector>
#include <thread>
#include "jni_helper_func.h"
JavaVM *_jvm;
void *_tvmHandle;
struct TVMFuncArgsThreadLocalEntry {
std::vector<TVMValue> tvmFuncArgValues;
std::vector<int> tvmFuncArgTypes;
// for later release
std::vector<std::pair<jstring, const char *> > tvmFuncArgPushedStrs;
};
typedef dmlc::ThreadLocalStore<TVMFuncArgsThreadLocalEntry> TVMFuncArgsThreadLocalStore;
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_nativeLibInit
(JNIEnv *env, jobject obj, jstring jtvmLibFile) {
if (_tvmHandle == NULL) {
const char *tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0);
_tvmHandle = dlopen(tvmLibFile, RTLD_LAZY | RTLD_GLOBAL);
env->ReleaseStringUTFChars(jtvmLibFile, tvmLibFile);
if (!_tvmHandle) {
fprintf(stderr, "%s\n", dlerror());
return 1;
}
}
return env->GetJavaVM(&_jvm);
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_shutdown(JNIEnv *env, jobject obj) {
if (_tvmHandle) {
dlclose(_tvmHandle);
}
return 0;
}
JNIEXPORT jstring JNICALL Java_ml_dmlc_tvm_LibInfo_tvmGetLastError(JNIEnv * env, jobject obj) {
return env->NewStringUTF(TVMGetLastError());
}
// Function
JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgLong(
JNIEnv *env, jobject obj, jlong arg) {
TVMValue value;
value.v_int64 = static_cast<int64_t>(arg);
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kInt);
}
JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
JNIEnv *env, jobject obj, jdouble arg) {
TVMValue value;
value.v_float64 = static_cast<double>(arg);
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kFloat);
}
JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgString(
JNIEnv *env, jobject obj, jstring arg) {
TVMValue value;
jstring garg = reinterpret_cast<jstring>(env->NewGlobalRef(arg));
value.v_str = env->GetStringUTFChars(garg, 0);
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kStr);
// release string args later
e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str));
}
JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgHandle(
JNIEnv *env, jobject obj, jlong arg, jint argType) {
TVMValue value;
value.v_handle = reinterpret_cast<void *>(arg);
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(static_cast<int>(argType));
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncListGlobalNames(
JNIEnv *env, jobject obj, jobject jfuncNames) {
int outSize;
const char **outArray;
int ret = TVMFuncListGlobalNames(&outSize, &outArray);
if (ret) {
return ret;
}
jclass arrayClass = env->FindClass("java/util/List");
jmethodID arrayAppend = env->GetMethodID(arrayClass, "add", "(Ljava/lang/Object;)Z");
// fill names
for (int i = 0; i < outSize; ++i) {
jstring jname = env->NewStringUTF(outArray[i]);
env->CallObjectMethod(jfuncNames, arrayAppend, jname);
env->DeleteLocalRef(jname);
}
env->DeleteLocalRef(arrayClass);
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncFree(
JNIEnv *env, jobject obj, jlong jhandle) {
return TVMFuncFree(reinterpret_cast<TVMFunctionHandle>(jhandle));
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncGetGlobal(
JNIEnv *env, jobject obj, jstring jname, jobject jhandle) {
TVMFunctionHandle handle;
const char *name = env->GetStringUTFChars(jname, 0);
int ret = TVMFuncGetGlobal(name, &handle);
env->ReleaseStringUTFChars(jname, name);
setLongField(env, jhandle, reinterpret_cast<jlong>(handle));
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall(
JNIEnv *env, jobject obj, jlong jhandle, jobject jretVal) {
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
int numArgs = e->tvmFuncArgValues.size();
TVMValue retVal;
int retTypeCode;
int ret = TVMFuncCall(reinterpret_cast<TVMFunctionHandle>(jhandle),
&e->tvmFuncArgValues[0], &e->tvmFuncArgTypes[0], numArgs, &retVal, &retTypeCode);
for (auto iter = e->tvmFuncArgPushedStrs.cbegin();
iter != e->tvmFuncArgPushedStrs.cend(); iter++) {
env->ReleaseStringUTFChars(iter->first, iter->second);
env->DeleteGlobalRef(iter->first);
}
e->tvmFuncArgPushedStrs.clear();
e->tvmFuncArgTypes.clear();
e->tvmFuncArgValues.clear();
// return TVMValue object to Java
jclass refTVMValueCls = env->FindClass("ml/dmlc/tvm/Base$RefTVMValue");
jfieldID refTVMValueFid
= env->GetFieldID(refTVMValueCls, "value", "Lml/dmlc/tvm/TVMValue;");
switch (retTypeCode) {
case kInt:
env->SetObjectField(jretVal, refTVMValueFid,
newTVMValueLong(env, static_cast<jlong>(retVal.v_int64)));
break;
case kFloat:
env->SetObjectField(jretVal, refTVMValueFid,
newTVMValueDouble(env, static_cast<jdouble>(retVal.v_float64)));
break;
case kModuleHandle:
env->SetObjectField(jretVal, refTVMValueFid,
newTVMValueModuleHandle(env, reinterpret_cast<jlong>(retVal.v_handle)));
break;
case kNull:
env->SetObjectField(jretVal, refTVMValueFid,
newObject(env, "ml/dmlc/tvm/TVMValueNull"));
break;
default:
LOG(FATAL) << "Do NOT know how to handle return type code " << retTypeCode;
}
env->DeleteLocalRef(refTVMValueCls);
return ret;
}
// Module
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmModFree(
JNIEnv *env, jobject obj, jlong jhandle) {
return TVMModFree(reinterpret_cast<TVMModuleHandle>(jhandle));
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmModImport(
JNIEnv *env, jobject obj, jlong jmod, jlong jdep) {
return TVMModImport(reinterpret_cast<TVMModuleHandle>(jmod),
reinterpret_cast<TVMModuleHandle>(jdep));
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmModGetFunction(
JNIEnv *env, jobject obj, jlong jhandle, jstring jname, jint jimport, jobject jret) {
TVMFunctionHandle retFunc;
const char *name = env->GetStringUTFChars(jname, 0);
int ret = TVMModGetFunction(reinterpret_cast<TVMFunctionHandle>(jhandle),
name,
reinterpret_cast<int>(jimport),
&retFunc);
env->ReleaseStringUTFChars(jname, name);
setLongField(env, jret, reinterpret_cast<jlong>(retFunc));
return ret;
}
// NDArray
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmArrayFree(
JNIEnv *env, jobject obj, jlong jhandle) {
return TVMArrayFree(reinterpret_cast<TVMArrayHandle>(jhandle));
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmArrayAlloc(
JNIEnv *env, jobject obj, jlongArray jshape, jint jdtypeCode,
jint jdtypeBits, jint jdtypeLanes, jint jdeviceType, jint jdeviceId, jobject jret) {
int ndim = static_cast<int>(env->GetArrayLength(jshape));
TVMArrayHandle out;
jlong *shapeArray = env->GetLongArrayElements(jshape, NULL);
int ret = TVMArrayAlloc(
reinterpret_cast<const tvm_index_t*>(shapeArray),
ndim,
static_cast<int>(jdtypeCode),
static_cast<int>(jdtypeBits),
static_cast<int>(jdtypeLanes),
static_cast<int>(jdeviceType),
static_cast<int>(jdeviceId),
&out);
env->ReleaseLongArrayElements(jshape, shapeArray, 0);
setLongField(env, jret, reinterpret_cast<jlong>(out));
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmArrayGetShape(
JNIEnv *env, jobject obj, jlong jhandle, jobject jshape) {
TVMArray *array = reinterpret_cast<TVMArray *>(jhandle);
int64_t *shape = array->shape;
int ndim = array->ndim;
// fill shape buffer
jclass longClass = env->FindClass("java/lang/Long");
jmethodID newLong = env->GetMethodID(longClass, "<init>", "(J)V");
jclass arrayClass = env->FindClass("java/util/List");
jmethodID arrayAppend = env->GetMethodID(arrayClass, "add", "(Ljava/lang/Object;)Z");
for (int i = 0; i < ndim; ++i) {
jobject data = env->NewObject(longClass, newLong, static_cast<jlong>(shape[i]));
env->CallObjectMethod(jshape, arrayAppend, data);
env->DeleteLocalRef(data);
}
env->DeleteLocalRef(longClass);
env->DeleteLocalRef(arrayClass);
return 0;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmArrayCopyFromTo(
JNIEnv *env, jobject obj, jlong jfrom, jlong jto) {
return TVMArrayCopyFromTo(reinterpret_cast<TVMArrayHandle>(jfrom),
reinterpret_cast<TVMArrayHandle>(jto), NULL);
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmArrayCopyFromJArray(
JNIEnv *env, jobject obj, jbyteArray jarr, jlong jfrom, jlong jto) {
jbyte *data = env->GetByteArrayElements(jarr, NULL);
TVMArray *from = reinterpret_cast<TVMArray *>(jfrom);
from->data = static_cast<void *>(data);
int ret = TVMArrayCopyFromTo(static_cast<TVMArrayHandle>(from),
reinterpret_cast<TVMArrayHandle>(jto), NULL);
from->data = NULL;
env->ReleaseByteArrayElements(jarr, data, 0);
return ret;
}
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmArrayCopyToJArray(
JNIEnv *env, jobject obj, jlong jfrom, jbyteArray jarr) {
TVMArray *from = reinterpret_cast<TVMArray *>(jfrom);
int size = static_cast<int>(env->GetArrayLength(jarr));
jbyte *pdata = env->GetByteArrayElements(jarr, NULL);
int ret = 0;
if (memcpy(static_cast<void *>(pdata), from->data, size) == NULL) {
ret = 1;
}
env->ReleaseByteArrayElements(jarr, pdata, 0); // copy back to java array automatically
return ret;
}
// Context
JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmSynchronize(
JNIEnv *env, jint deviceType, jint deviceId) {
return TVMSynchronize(static_cast<int>(deviceType), static_cast<int>(deviceId), NULL);
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Parent</name>
<url>https://github.com/dmlc/tvm/tree/master/jvm</url>
<description>TVM4J Package</description>
<organization>
<name>Distributed (Deep) Machine Learning Community</name>
<url>http://dmlc.ml</url>
</organization>
<licenses>
<license>
<name>The Apache License, Version 2.0</name>
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
</license>
</licenses>
<scm>
<connection>scm:git:git@github.com:dmlc/tvm.git</connection>
<developerConnection>scm:git:git@github.com:dmlc/tvm.git</developerConnection>
<url>https://github.com/dmlc/tvm</url>
</scm>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<packaging>pom</packaging>
<modules>
<module>core</module>
<module>native</module>
<module>assembly</module>
</modules>
<profiles>
<profile>
<id>release</id>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${project.build.directory}/genjavadoc</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<distributionManagement>
<snapshotRepository>
<id>ossrh</id>
<url>https://oss.sonatype.org/content/repositories/snapshots</url>
</snapshotRepository>
</distributionManagement>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>2.2.1</version>
<executions>
<execution>
<phase>package</phase>
<id>attach-sources</id>
<goals>
<goal>jar-no-fork</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>native-maven-plugin</artifactId>
<version>1.0-alpha-7</version>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>2.7</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<version>2.9</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.0.2</version>
<executions>
<execution>
<id>empty-javadoc-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<includes>
<include>**/*</include>
</includes>
<classifier>javadoc</classifier>
<classesDirectory>${basedir}/javadoc</classesDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>2.5.5</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-deploy-plugin</artifactId>
<version>2.8.2</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-install-plugin</artifactId>
<version>2.5.2</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version>
<configuration>
<source>1.6</source>
<target>1.6</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.9.1</version>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
</dependencies>
</project>
#!/bin/bash
export PYTHONPATH=python:apps/extension/python
export PYTHONPATH=${PYTHONPATH}:apps/graph_executor/python:apps/graph_executor/nnvm/python
export LD_LIBRARY_PATH=lib:${LD_LIBRARY_PATH}
CURR_DIR=$(cd `dirname $0`; pwd)
SCRIPT_DIR=$CURR_DIR/../../jvm/core/src/test/scripts
TEMP_DIR=$(mktemp -d)
python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1
python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1
make jvmpkg || exit -1
make jvmpkg JVM_TEST_ARGS="-DskipTests=false -Dtest.tempdir=$TEMP_DIR" || exit -1
rm -rf $TEMP_DIR
......@@ -3,6 +3,8 @@ echo "Check codestyle of c++ code..."
make cpplint || exit -1
echo "Check codestyle of python code..."
make pylint || exit -1
echo "Check codestyle of jni code..."
make jnilint || exit -1
echo "Check documentations of c++ code..."
make doc 2>log.txt
(cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment