Skip to content
Snippets Groups Projects
Commit d3efd7fc authored by Yizhi Liu's avatar Yizhi Liu Committed by Tianqi Chen
Browse files

[WIP][Frontend] Scala/Java package (#176)

* JVM package skeleton

* [JVM] link libtvm.so and list function names

* [JVM] Function & NDArray skeleton

* [JVM] TVMFuncCall in JNI

* [JVM] handle string arg in TVMFuncCall

* [JVM] get module function

* [JVM] entry function for Module

* [JVM] construct Module from function return value

* [JVM] TVMContext, TVMArray attributes

* [JVM] NDArray from / to java array

* [JVM] load so and compute on cpu

* [JVM] move PackedFunc to individual modules

* [JVM] assembly package & native library loader

* [JVM] unit test & codestyle check settings

* [JVM] NDArray from & to different dtypes

* [JVM] NDArray from native double array. Add linux-cpu profile.

* [JVM] modify Makefile

* [JVM] add linux-x86_64-gpu profile

* [tvm4j] delay load libtvm_runtime.so

* [tvm4j] refactor to pure java

* [tvm4j] remove scalastyle-config.xml

* [tvm4j] remove link HalideIR, remove Shape, remove scala binary versions

* [tvm4j] only allow convert from/to same type array

* [tvm4j] make NDArray api more readable

* [tvm4j] refactor for c api

* [tvm4j] add Jenkins tests

* [tvm4j] fix duplicate Dockerfile cmd

* [tvm4j] fix ut script filename

* [tvm4j] add module load tests

* [tvm4j] add javadoc, remove types package

* [tvm4j] fix test script

* [tvm4j] remove ut temp dir

* [tvm4j] fix missing package types

* [tvm4j] java code style check

* [tvm4j] fix java lint

* [tvm4j] downgrade checkstyle plugin for JDK7

* [tvm4j] add stylecheck in jenkins tests

* [tvm4j] specify source file encoding

* [tvm4j] lazy init function; add Function.call() api; allow manully release Module,NDArray,Function

* [tvm4j] fix ModFree

* [tvm4j] cache Function in API
parent 86ff24ab
No related branches found
No related tags found
No related merge requests found
Showing
with 2035 additions and 2 deletions
......@@ -104,6 +104,18 @@ nnvm
## IOS
DerivedData/
## Java
*.class
jvm/*/target/
jvm/*/*/target/
*.worksheet
*.idea
*.iml
*.classpath
*.project
*.settings
*/node_modules/
## Various settings
*.pbxuser
!default.pbxuser
......@@ -119,4 +131,5 @@ xcuserdata/
*.moved-aside
*.xccheckout
*.xcscmblueprint
.DS_Store
\ No newline at end of file
.DS_Store
#!groovy
// -*- mode: groovy -*-
// Jenkins pipeline
// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/
......@@ -183,6 +184,17 @@ stage('Unit Test') {
}
}
}
},
'java': {
node('GPU' && 'linux') {
ws('workspace/tvm/ut-java') {
init_git()
unpack_lib('gpu', tvm_lib)
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} gpu ./tests/scripts/task_java_unittest.sh"
}
}
}
}
}
......
......@@ -120,6 +120,29 @@ ifdef ADD_LDFLAGS
LDFLAGS += $(ADD_LDFLAGS)
endif
ifeq ($(OS),Windows_NT)
JVM_PKG_PROFILE := windows
else
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
JVM_PKG_PROFILE := osx-x86_64
else
JVM_PKG_PROFILE := linux-x86_64
endif
endif
JVM_TEST_ARGS := $(if $(JVM_TEST_ARGS),$(JVM_TEST_ARGS),-DskipTests -Dcheckstyle.skip=true)
ifeq ($(USE_CUDA), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else ifeq ($(USE_OPENCL), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else ifeq ($(USE_METAL), 1)
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-gpu
else
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-cpu
endif
include tests/cpp/unittest.mk
test: $(TEST)
......@@ -176,7 +199,10 @@ pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
lint: cpplint pylint
jnilint:
python dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src
lint: cpplint pylint jnilint
doc:
doxygen docs/Doxyfile
......@@ -194,6 +220,12 @@ cython3:
cyclean:
rm -rf python/tvm/*/*/*.so python/tvm/*/*/*.cpp
jvmpkg:
(cd $(ROOTDIR)/jvm; \
mvn clean package -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" $(JVM_TEST_ARGS))
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o */*.d */*/*.d */*/*/*.d
cd HalideIR; make clean; cd $(ROOTDIR)
......
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-linux-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Linux-x86_64 CPU-only</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-linux-x86_64-cpu</artifactId>
<version>${project.version}</version>
<type>so</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.so</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-linux-x86_64-cpu:so</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-linux-x86_64-gpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Linux-x86_64 GPU</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-linux-x86_64-gpu</artifactId>
<version>${project.version}</version>
<type>so</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.so</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-linux-x86_64-gpu:so</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-osx-x86_64-cpu</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full OSX-x86_64 CPU-only</name>
<packaging>jar</packaging>
<dependencies>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>libtvm4j-osx-x86_64-cpu</artifactId>
<version>${project.version}</version>
<type>jnilib</type>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<appendAssemblyId>false</appendAssemblyId>
<descriptors>
<descriptor>src/main/assembly/assembly.xml</descriptor>
</descriptors>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
<assembly>
<id>full</id>
<formats>
<format>jar</format>
</formats>
<includeBaseDirectory>false</includeBaseDirectory>
<files>
<file>
<source>../../../lib/libtvm_runtime.so</source>
<outputDirectory>lib/native</outputDirectory>
<fileMode>0644</fileMode>
</file>
</files>
<dependencySets>
<dependencySet>
<includes>
<include>*:*:jar</include>
</includes>
<outputDirectory>/</outputDirectory>
<useProjectArtifact>true</useProjectArtifact>
<unpack>true</unpack>
<scope>runtime</scope>
</dependencySet>
<dependencySet>
<outputDirectory>lib/native</outputDirectory>
<outputFileNameMapping>libtvm4j.jnilib</outputFileNameMapping>
<unpack>false</unpack>
<useProjectArtifact>false</useProjectArtifact>
<useStrictFiltering>false</useStrictFiltering>
<includes>
<include>ml.dmlc.tvm:libtvm4j-osx-x86_64-cpu:jnilib</include>
</includes>
</dependencySet>
</dependencySets>
</assembly>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-full-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Full Parent</name>
<packaging>pom</packaging>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<modules>
<module>osx-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<modules>
<module>linux-x86_64-cpu</module>
</modules>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<modules>
<module>linux-x86_64-gpu</module>
</modules>
</profile>
<profile>
<id>release</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>jar-no-fork</goal>
</goals>
<configuration>
<includePom>true</includePom>>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<includeDependencySources>true</includeDependencySources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.sonatype.plugins</groupId>
<artifactId>nexus-staging-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<serverId>ossrh</serverId>
<nexusUrl>https://oss.sonatype.org/</nexusUrl>
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>
<?xml version="1.0"?>
<!DOCTYPE module PUBLIC
"-//Puppy Crawl//DTD Check Configuration 1.3//EN"
"http://www.puppycrawl.com/dtds/configuration_1_3.dtd">
<!--
Checkstyle configuration that checks the Google coding conventions from:
- Google Java Style
https://google-styleguide.googlecode.com/svn-history/r130/trunk/javaguide.html
Checkstyle is very configurable. Be sure to read the documentation at
http://checkstyle.sf.net (or in your downloaded distribution).
Most Checks are configurable, be sure to consult the documentation.
To completely disable a check, just comment it out or delete it from the file.
Authors: Max Vetrenko, Ruslan Diachenko, Roman Ivanov.
-->
<module name = "Checker">
<property name="charset" value="UTF-8"/>
<property name="severity" value="error"/>
<property name="fileExtensions" value="java, properties, xml"/>
<!-- Checks for whitespace -->
<!-- See http://checkstyle.sf.net/config_whitespace.html -->
<module name="FileTabCharacter">
<property name="eachLine" value="true"/>
</module>
<module name="TreeWalker">
<module name="OuterTypeFilename"/>
<module name="IllegalTokenText">
<property name="tokens" value="STRING_LITERAL, CHAR_LITERAL"/>
<property name="format" value="\\u00(08|09|0(a|A)|0(c|C)|0(d|D)|22|27|5(C|c))|\\(0(10|11|12|14|15|42|47)|134)"/>
<property name="message" value="Avoid using corresponding octal or Unicode escape."/>
</module>
<module name="AvoidEscapedUnicodeCharacters">
<property name="allowEscapesForControlCharacters" value="true"/>
<property name="allowByTailComment" value="true"/>
<property name="allowNonPrintableEscapes" value="true"/>
</module>
<module name="LineLength">
<property name="max" value="100"/>
<property name="ignorePattern" value="^package.*|^import.*|a href|href|http://|https://|ftp://"/>
</module>
<module name="AvoidStarImport"/>
<module name="OneTopLevelClass"/>
<module name="NoLineWrap"/>
<module name="EmptyBlock">
<property name="option" value="TEXT"/>
<property name="tokens" value="LITERAL_TRY, LITERAL_FINALLY, LITERAL_IF, LITERAL_ELSE, LITERAL_SWITCH"/>
</module>
<module name="NeedBraces"/>
<module name="LeftCurly">
<property name="maxLineLength" value="100"/>
</module>
<module name="RightCurly"/>
<module name="RightCurly">
<property name="option" value="alone"/>
<property name="tokens" value="CLASS_DEF, METHOD_DEF, CTOR_DEF, LITERAL_FOR, LITERAL_WHILE, LITERAL_DO, STATIC_INIT, INSTANCE_INIT"/>
</module>
<module name="WhitespaceAround">
<property name="allowEmptyConstructors" value="true"/>
<property name="allowEmptyMethods" value="true"/>
<property name="allowEmptyTypes" value="true"/>
<property name="allowEmptyLoops" value="true"/>
<message key="ws.notFollowed"
value="WhitespaceAround: ''{0}'' is not followed by whitespace. Empty blocks may only be represented as '{}' when not part of a multi-block statement (4.1.3)"/>
<message key="ws.notPreceded"
value="WhitespaceAround: ''{0}'' is not preceded with whitespace."/>
</module>
<module name="OneStatementPerLine"/>
<module name="MultipleVariableDeclarations"/>
<module name="ArrayTypeStyle"/>
<module name="MissingSwitchDefault"/>
<module name="FallThrough"/>
<module name="UpperEll"/>
<module name="ModifierOrder"/>
<module name="EmptyLineSeparator">
<property name="allowNoEmptyLineBetweenFields" value="true"/>
</module>
<module name="SeparatorWrap">
<property name="tokens" value="DOT"/>
<property name="option" value="nl"/>
</module>
<module name="SeparatorWrap">
<property name="tokens" value="COMMA"/>
<property name="option" value="EOL"/>
</module>
<module name="PackageName">
<property name="format" value="^[a-z]+(\.[a-z][a-z0-9]*)*$"/>
<message key="name.invalidPattern"
value="Package name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="TypeName">
<message key="name.invalidPattern"
value="Type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="MemberName">
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9]*$"/>
<message key="name.invalidPattern"
value="Member name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="ParameterName">
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9]*$"/>
<message key="name.invalidPattern"
value="Parameter name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="LocalVariableName">
<property name="tokens" value="VARIABLE_DEF"/>
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9]*$"/>
<property name="allowOneCharVarInForLoop" value="true"/>
<message key="name.invalidPattern"
value="Local variable name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="ClassTypeParameterName">
<property name="format" value="(^[A-Z][0-9]?)$|([A-Z][a-zA-Z0-9]*[T]$)"/>
<message key="name.invalidPattern"
value="Class type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="MethodTypeParameterName">
<property name="format" value="(^[A-Z][0-9]?)$|([A-Z][a-zA-Z0-9]*[T]$)"/>
<message key="name.invalidPattern"
value="Method type name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="GenericWhitespace">
<message key="ws.followed"
value="GenericWhitespace ''{0}'' is followed by whitespace."/>
<message key="ws.preceded"
value="GenericWhitespace ''{0}'' is preceded with whitespace."/>
<message key="ws.illegalFollow"
value="GenericWhitespace ''{0}'' should followed by whitespace."/>
<message key="ws.notPreceded"
value="GenericWhitespace ''{0}'' is not preceded with whitespace."/>
</module>
<module name="Indentation">
<property name="basicOffset" value="2"/>
<property name="braceAdjustment" value="0"/>
<property name="caseIndent" value="2"/>
<property name="throwsIndent" value="4"/>
<property name="lineWrappingIndentation" value="4"/>
<property name="arrayInitIndent" value="2"/>
</module>
<module name="AbbreviationAsWordInName">
<property name="ignoreFinal" value="false"/>
<property name="allowedAbbreviationLength" value="5"/>
</module>
<module name="OverloadMethodsDeclarationOrder"/>
<module name="VariableDeclarationUsageDistance"/>
<module name="CustomImportOrder">
<property name="specialImportsRegExp" value="com.google"/>
<property name="sortImportsInGroupAlphabetically" value="true"/>
<property name="customImportOrderRules" value="STATIC###SPECIAL_IMPORTS###THIRD_PARTY_PACKAGE###STANDARD_JAVA_PACKAGE"/>
</module>
<module name="MethodParamPad"/>
<module name="OperatorWrap">
<property name="option" value="NL"/>
<property name="tokens" value="BAND, BOR, BSR, BXOR, DIV, EQUAL, GE, GT, LAND, LE, LITERAL_INSTANCEOF, LOR, LT, MINUS, MOD, NOT_EQUAL, PLUS, QUESTION, SL, SR, STAR "/>
</module>
<module name="AnnotationLocation">
<property name="tokens" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF"/>
</module>
<module name="AnnotationLocation">
<property name="tokens" value="VARIABLE_DEF"/>
<property name="allowSamelineMultipleAnnotations" value="true"/>
</module>
<module name="NonEmptyAtclauseDescription"/>
<module name="JavadocTagContinuationIndentation"/>
<module name="SummaryJavadocCheck">
<property name="forbiddenSummaryFragments" value="^@return the *|^This method returns |^A [{]@code [a-zA-Z0-9]+[}]( is a )"/>
</module>
<module name="JavadocParagraph"/>
<module name="AtclauseOrder">
<property name="tagOrder" value="@param, @return, @throws, @deprecated"/>
<property name="target" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF, VARIABLE_DEF"/>
</module>
<module name="JavadocMethod">
<property name="scope" value="public"/>
<property name="allowMissingParamTags" value="true"/>
<property name="allowMissingThrowsTags" value="true"/>
<property name="allowMissingReturnTag" value="true"/>
<property name="minLineCount" value="2"/>
<property name="allowedAnnotations" value="Override, Test"/>
<property name="allowThrowsTagsForSubclasses" value="true"/>
</module>
<module name="MethodName">
<property name="format" value="^[a-z][a-z0-9][a-zA-Z0-9_]*$"/>
<message key="name.invalidPattern"
value="Method name ''{0}'' must match pattern ''{1}''."/>
</module>
<module name="SingleLineJavadoc">
<property name="ignoreInlineTags" value="false"/>
</module>
<module name="EmptyCatchBlock">
<property name="exceptionVariableName" value="expected"/>
</module>
<module name="CommentsIndentation"/>
</module>
</module>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-parent</artifactId>
<version>0.0.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>ml.dmlc.tvm</groupId>
<artifactId>tvm4j-core</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>TVM4J Package - Core</name>
<profiles>
<profile>
<id>osx-x86_64-cpu</id>
<properties>
<platform>osx-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-cpu</id>
<properties>
<platform>linux-x86_64-cpu</platform>
</properties>
</profile>
<profile>
<id>linux-x86_64-gpu</id>
<properties>
<platform>linux-x86_64-gpu</platform>
</properties>
</profile>
</profiles>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>2.17</version>
<dependencies>
<dependency>
<groupId>com.puppycrawl.tools</groupId>
<artifactId>checkstyle</artifactId>
<version>6.12</version>
</dependency>
</dependencies>
<executions>
<execution>
<phase>process-sources</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
<configuration>
<failsOnError>true</failsOnError>
<configLocation>${project.parent.basedir}/conf/google_checks.xml</configLocation>
<consoleOutput>true</consoleOutput>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.7</version>
<configuration>
<forkCount>1</forkCount>
<reuseForks>true</reuseForks>
<threadCount>1</threadCount>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
-Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so
</argLine>
</configuration>
<executions>
<execution>
<id>test</id>
<!--
We put it in the integration-test phase,
because the test suites require the jni library,
which means, in order to run the unit tests,
we have to compile all the modules first.
-->
<phase>integration-test</phase>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
/**
* TVM API functions.
*/
public final class API {
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
/**
* Get a tvm api function according by name.
* @param name function name.
* @return a TVM Function.
*/
public static Function get(final String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction(name);
apiFuncs.get().put(name, func);
}
return func;
}
/**
* Cannot be instantiated.
*/
private API() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
/**
* Internal api functions.
*/
public final class APIInternal {
/**
* Get a tvm api function according by name.
* @param name function name.
* @return a TVM Function.
*/
public static Function get(final String name) {
return API.get(name);
}
/**
* Cannot be instantiated.
*/
private APIInternal() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import ml.dmlc.tvm.NativeLibraryLoader.Action;
import java.io.File;
import java.io.IOException;
/**
* Initializing methods and types.
*/
final class Base {
/**
* Hold Long reference for JNI.
*/
public static class RefLong {
public final long value;
public RefLong(final long value) {
this.value = value;
}
public RefLong() {
this(0L);
}
}
/**
* Hold TVMValue reference for JNI.
*/
public static class RefTVMValue {
public final TVMValue value;
public RefTVMValue(TVMValue value) {
this.value = value;
}
public RefTVMValue() {
this(null);
}
}
public static final LibInfo _LIB = new LibInfo();
static {
try {
try {
tryLoadLibraryOS("tvm4j");
} catch (UnsatisfiedLinkError e) {
System.err.println("[WARN] TVM native library not found in path. "
+ "Copying native library from the archive. "
+ "Consider installing the library somewhere in the path "
+ "(for Windows: PATH, for Linux: LD_LIBRARY_PATH), "
+ "or specifying by Java cmd option -Djava.library.path=[lib path].");
NativeLibraryLoader.loadLibrary("tvm4j");
}
} catch (Throwable e) {
System.err.println("[ERROR] Couldn't find native library tvm4j");
throw new RuntimeException(e);
}
String tvmLibFilename = System.getProperty("libtvm.so.path");
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) {
try {
NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() {
@Override public void invoke(File target) {
System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override public void run() {
_LIB.shutdown();
}
});
}
/**
* Load JNI for different OS.
* @param libname library name.
* @throws UnsatisfiedLinkError if loading fails.
*/
private static void tryLoadLibraryOS(String libname) throws UnsatisfiedLinkError {
try {
System.err.println(String.format("Try loading %s from native path.", libname));
System.loadLibrary(libname);
} catch (UnsatisfiedLinkError e) {
String os = System.getProperty("os.name");
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
tryLoadLibraryXPU(libname, "linux-x86_64");
} else if (os.startsWith("Mac")) {
tryLoadLibraryXPU(libname, "osx-x86_64");
} else {
// TODO(yizhi) support windows later
throw new UnsatisfiedLinkError("Windows not supported currently");
}
}
}
/**
* Load native library for different architectures.
* @param libname library name.
* @param arch architecture.
* @throws UnsatisfiedLinkError if loading fails
*/
private static void tryLoadLibraryXPU(String libname, String arch) throws UnsatisfiedLinkError {
try {
// try gpu first
System.err.println(String.format("Try loading %s-%s-gpu from native path.", libname, arch));
System.loadLibrary(String.format("%s-%s-gpu", libname, arch));
} catch (UnsatisfiedLinkError e) {
System.err.println(String.format("Try loading %s-%s-cpu from native path.", libname, arch));
System.loadLibrary(String.format("%s-%s-cpu", libname, arch));
}
}
// helper function definitions
/**
* Check the return value of C API call
* <p>
* This function will raise exception when error occurs.
* Wrap every API call with this function
* </p>
* @param ret return value from API calls
*/
public static void checkCall(int ret) throws TVMError {
if (ret != 0) {
throw new TVMError(_LIB.tvmGetLastError());
}
}
/**
* TVM Runtime error.
*/
static class TVMError extends RuntimeException {
public TVMError(String err) {
super(err);
}
}
/**
* Cannot be instantiated.
*/
private Base() {
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Function {
final long handle;
public final boolean isResident;
private boolean isReleased = false;
/**
* Get registered function.
* @param name full function name.
* @return TVM function.
*/
static Function getFunction(final String name) {
for (String fullName : listGlobalFuncNames()) {
if (fullName.equals(name)) {
return getGlobalFunc(fullName, true, false);
}
}
return null;
}
/**
* Get list of global functions registered.
* @return List of global functions names.
*/
private static List<String> listGlobalFuncNames() {
List<String> names = new ArrayList<String>();
Base.checkCall(Base._LIB.tvmFuncListGlobalNames(names));
return Collections.unmodifiableList(names);
}
/**
* Get a global function by name.
* @param name The name of the function.
* @param isResident Whether it is a global 'resident' function.
* @param allowMissing Whether allow missing function or raise an error.
* @return The function to be returned, None if function is missing.
*/
private static Function getGlobalFunc(String name, boolean isResident, boolean allowMissing) {
Base.RefLong handle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmFuncGetGlobal(name, handle));
if (handle.value != 0) {
return new Function(handle.value, isResident);
} else {
if (allowMissing) {
return null;
} else {
throw new IllegalArgumentException("Cannot find global function " + name);
}
}
}
/**
* Initialize the function with handle
* @param handle the handle to the underlying function.
* @param isResident Whether this is a resident function in jvm
*/
public Function(long handle, boolean isResident) {
this.handle = handle;
this.isResident = isResident;
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the Function.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
if (!isResident) {
Base.checkCall(Base._LIB.tvmFuncFree(handle));
isReleased = true;
}
}
}
/**
* Invoke the function.
* @return the result.
*/
public TVMValue invoke() {
Base.RefTVMValue ret = new Base.RefTVMValue();
Base.checkCall(Base._LIB.tvmFuncCall(handle, ret));
return ret.value;
}
/**
* Push argument to the function.
* @param arg int argument.
* @return this
*/
public Function pushArg(int arg) {
Base._LIB.tvmFuncPushArgLong(arg);
return this;
}
/**
* Push argument to the function.
* @param arg long argument.
* @return this
*/
public Function pushArg(long arg) {
Base._LIB.tvmFuncPushArgLong(arg);
return this;
}
/**
* Push argument to the function.
* @param arg float argument.
* @return this
*/
public Function pushArg(float arg) {
Base._LIB.tvmFuncPushArgDouble(arg);
return this;
}
/**
* Push argument to the function.
* @param arg double argument.
* @return this
*/
public Function pushArg(double arg) {
Base._LIB.tvmFuncPushArgDouble(arg);
return this;
}
/**
* Push argument to the function.
* @param arg String argument.
* @return this
*/
public Function pushArg(String arg) {
Base._LIB.tvmFuncPushArgString(arg);
return this;
}
/**
* Push argument to the function.
* @param arg NDArray.
* @return this
*/
public Function pushArg(NDArray arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
return this;
}
/**
* Invoke function with arguments.
* @param args Can be Integer, Long, Float, Double, String, NDArray.
* @return the result.
*/
public TVMValue call(Object... args) {
for (Object arg : args) {
if (arg instanceof Integer) {
pushArg((Integer) arg);
} else if (arg instanceof Long) {
pushArg((Long) arg);
} else if (arg instanceof Float) {
pushArg((Float) arg);
} else if (arg instanceof Double) {
pushArg((Double) arg);
} else if (arg instanceof String) {
pushArg((String) arg);
} else if (arg instanceof NDArray) {
pushArg((NDArray) arg);
} else {
throw new IllegalArgumentException("Invalid argument: " + arg);
}
}
return invoke();
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.List;
class LibInfo {
public native int nativeLibInit(String tvmLibFile);
public native int shutdown();
public native String tvmGetLastError();
// Function
public native void tvmFuncPushArgLong(long arg);
public native void tvmFuncPushArgDouble(double arg);
public native void tvmFuncPushArgString(String arg);
public native void tvmFuncPushArgHandle(long arg, int argType);
public native int tvmFuncListGlobalNames(List<String> funcNames);
public native int tvmFuncFree(long handle);
public native int tvmFuncGetGlobal(String name, Base.RefLong handle);
public native int tvmFuncCall(long handle, Base.RefTVMValue retVal);
// Module
public native int tvmModFree(long handle);
public native int tvmModGetFunction(long handle, String name,
int queryImports, Base.RefLong retHandle);
public native int tvmModImport(long mod, long dep);
// NDArray
public native int tvmArrayFree(long handle);
public native int tvmArrayAlloc(long[] shape,
int dtypeCode,
int dtypeBits,
int dtypeLanes,
int deviceType,
int deviceId,
Base.RefLong refHandle);
public native int tvmArrayGetShape(long handle, List<Long> shape);
public native int tvmArrayCopyFromTo(long from, long to);
public native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to);
public native int tvmArrayCopyToJArray(long from, byte[] to);
// TVMContext
public native int tvmSynchronize(int deviceType, int deviceId);
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.util.HashMap;
import java.util.Map;
/**
* Container of compiled functions of TVM.
*/
public class Module {
public final long handle;
private boolean isReleased = false;
private static ThreadLocal<Map<String, Function>> apiFuncs
= new ThreadLocal<Map<String, Function>>() {
@Override
protected Map<String, Function> initialValue() {
return new HashMap<String, Function>();
}
};
private static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("module." + name);
apiFuncs.get().put(name, func);
}
return func;
}
public Module(long handle) {
this.handle = handle;
}
private Function entry = null;
private final String entryName = "__tvm_main__";
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the Module.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
Base.checkCall(Base._LIB.tvmModFree(handle));
isReleased = true;
}
}
/**
* Get the entry function.
* @return The entry function if exist
*/
public Function entryFunc() {
if (entry == null) {
entry = getFunction(entryName);
}
return entry;
}
/**
* Get function from the module.
* @param name The name of the function.
* @param queryImports Whether also query modules imported by this module.
* @return The result function.
*/
public Function getFunction(String name, boolean queryImports) {
Base.RefLong retHandle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmModGetFunction(
handle, name, queryImports ? 1 : 0, retHandle));
if (retHandle.value == 0) {
throw new IllegalArgumentException("Module has no function " + name);
}
return new Function(retHandle.value, false);
}
public Function getFunction(String name) {
return getFunction(name, false);
}
/**
* Add module to the import list of current one.
* @param module The other module.
*/
public void importModule(Module module) {
Base.checkCall(Base._LIB.tvmModImport(handle, module.handle));
}
/**
* Load module from file.
* @param path The path to the module file.
* @param fmt The format of the file,
* if not specified it will be inferred from suffix of the file.
* @return The loaded module
*/
public static Module load(String path, String fmt) {
TVMValue ret = getApi("_LoadFromFile").pushArg(path).pushArg(fmt).invoke();
assert ret.typeCode == TypeCode.MODULE_HANDLE;
return ret.asModule();
}
public static Module load(String path) {
return load(path, "");
}
/**
* Whether module runtime is enabled for target,
* e.g., The following code checks if gpu is enabled.
* Module.enabled("gpu")
* @param target The target device type.
* @return Whether runtime is enabled.
*/
public static boolean enabled(String target) {
TVMValue ret = getApi("_Enabled").pushArg(target).invoke();
return ret.asLong() != 0;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
/**
* Lightweight NDArray class of TVM runtime.
*/
public class NDArray {
public final long handle;
private final boolean isView;
private final TVMType dtype;
private boolean isReleased = false;
NDArray(long handle, boolean isView, TVMType dtype) {
this.handle = handle;
this.isView = isView;
this.dtype = dtype;
}
NDArray(long handle) {
this(handle, false, new TVMType("float32", 1));
}
NDArray(long handle, boolean isView) {
this(handle, isView, new TVMType("float32", 1));
}
@Override protected void finalize() throws Throwable {
release();
super.finalize();
}
/**
* Release the NDArray memory.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
if (!isReleased) {
if (!isView) {
Base.checkCall(Base._LIB.tvmArrayFree(handle));
isReleased = true;
}
}
}
/**
* Copy from a native array.
* The NDArray type must by float64
* @param sourceArray the source data
*/
public void copyFrom(double[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 64) {
throw new IllegalArgumentException("Cannot set double[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putDouble(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by float32
* @param sourceArray the source data
*/
public void copyFrom(float[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 32) {
throw new IllegalArgumentException("Cannot set float[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putFloat(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by int64
* @param sourceArray the source data
*/
public void copyFrom(long[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 64) {
throw new IllegalArgumentException("Cannot set long[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putLong(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by float32
* @param sourceArray the source data
*/
public void copyFrom(int[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 32) {
throw new IllegalArgumentException("Cannot set int[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putInt(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by int16
* @param sourceArray the source data
*/
public void copyFrom(short[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 16) {
throw new IllegalArgumentException("Cannot set short[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putShort(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Copy from a native array.
* The NDArray type must by int8
* @param sourceArray the source data
*/
public void copyFrom(byte[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.INT || dtype.bits != 8) {
throw new IllegalArgumentException("Cannot set byte[] for " + dtype.toString() + " array");
}
copyFromRaw(sourceArray);
}
/**
* Copy from a native array.
* The NDArray type must by uint16
* @param sourceArray the source data
*/
public void copyFrom(char[] sourceArray) {
checkCopySize(sourceArray.length);
if (dtype.typeCode != TVMType.UINT || dtype.bits != 16) {
throw new IllegalArgumentException("Cannot set char[] for " + dtype.toString() + " array");
}
byte[] nativeArr = new byte[sourceArray.length * dtype.numOfBytes];
for (int i = 0; i < sourceArray.length; ++i) {
wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putChar(sourceArray[i]);
}
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
private void checkCopySize(int sourceLength) {
long arrSize = size();
if (arrSize != sourceLength) {
throw new IllegalArgumentException(String.format("Array shape size not match: %d v.s. %d",
sourceLength, size()));
}
}
/**
* Copy from a raw byte array.
* @param sourceArray the source data
*/
public void copyFromRaw(byte[] sourceArray) {
NDArray tmpArr = empty(shape(), this.dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(sourceArray, tmpArr.handle, handle));
Base.checkCall(Base._LIB.tvmArrayFree(tmpArr.handle));
}
/**
* Get shape of current NDArray.
* @return an array representing shape of current ndarray
*/
public long[] shape() {
List<Long> data = new ArrayList<Long>();
Base.checkCall(Base._LIB.tvmArrayGetShape(handle, data));
long[] shapeArr = new long[data.size()];
for (int i = 0; i < shapeArr.length; ++i) {
shapeArr[i] = data.get(i);
}
return shapeArr;
}
/**
* Get total size of current NDArray.
* @return size of current NDArray.
*/
public long size() {
long product = 1L;
long[] shapeArr = shape();
for (int i = 0; i < shapeArr.length; ++i) {
product *= shapeArr[i];
}
return product;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be float64
* @return A copy of array content.
*/
public double[] asDoubleArray() {
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 64) {
throw new IllegalArgumentException(
"Cannot set convert to double[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
double[] array = new double[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getDouble();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be float32
* @return A copy of array content.
*/
public float[] asFloatArray() {
if (dtype.typeCode != TVMType.FLOAT || dtype.bits != 32) {
throw new IllegalArgumentException(
"Cannot set convert to float[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
float[] array = new float[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getFloat();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int64
* @return A copy of array content.
*/
public long[] asLongArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 64) {
throw new IllegalArgumentException(
"Cannot set convert to long[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
long[] array = new long[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getLong();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int32
* @return A copy of array content.
*/
public int[] asIntArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 32) {
throw new IllegalArgumentException(
"Cannot set convert to int[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
int[] array = new int[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getInt();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int16
* @return A copy of array content.
*/
public short[] asShortArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 16) {
throw new IllegalArgumentException(
"Cannot set convert to short[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
short[] array = new short[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getShort();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be uint16
* @return A copy of array content.
*/
public char[] asCharArray() {
if (dtype.typeCode != TVMType.UINT || dtype.bits != 16) {
throw new IllegalArgumentException(
"Cannot set convert to char[] for " + dtype.toString() + " array");
}
byte[][] units = groupInternalBytes();
char[] array = new char[units.length];
for (int i = 0; i < units.length; ++i) {
array[i] = wrapBytes(units[i]).getChar();
}
return array;
}
/**
* Return a copied flat java array of current array (row-major).
* The NDArray dtype must be int8
* @return A copy of array content.
*/
public byte[] asByteArray() {
if (dtype.typeCode != TVMType.INT || dtype.bits != 8) {
throw new IllegalArgumentException(
"Cannot set convert to byte[] for " + dtype.toString() + " array");
}
return internal();
}
/**
* Return a copied internal byte array of current array (row-major).
* @return A copy of array content.
*/
public byte[] internal() {
NDArray tmp = NDArray.empty(shape(), dtype);
Base.checkCall(Base._LIB.tvmArrayCopyFromTo(handle, tmp.handle));
int arrLength = dtype.numOfBytes * (int) size();
byte[] arr = new byte[arrLength];
Base.checkCall(Base._LIB.tvmArrayCopyToJArray(tmp.handle, arr));
return arr;
}
private byte[][] groupInternalBytes() {
byte[] raw = internal();
int unitSize = dtype.numOfBytes;
if (raw.length <= 0 || raw.length % unitSize != 0) {
throw new IllegalArgumentException(String.format(
"%s size %d cannot divide byte array size %d",
dtype.toString(), unitSize, raw.length));
}
int numOfUnits = raw.length / unitSize;
byte[][] units = new byte[numOfUnits][unitSize];
for (int i = 0; i < numOfUnits; ++i) {
System.arraycopy(raw, i * unitSize, units[i], 0, unitSize);
}
return units;
}
/**
* Create an empty array given shape, type and device.
* @param shape The shape of the array.
* @param dtype The data type of the array.
* @param ctx The context of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape, TVMType dtype, TVMContext ctx) {
Base.RefLong refHandle = new Base.RefLong();
Base.checkCall(Base._LIB.tvmArrayAlloc(
shape, dtype.typeCode, dtype.bits, dtype.lanes,
ctx.deviceType, ctx.deviceId, refHandle));
return new NDArray(refHandle.value, false, dtype);
}
/**
* Create an empty array on cpu given shape and type.
* @param shape The shape of the array.
* @param dtype The data type of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape, TVMType dtype) {
return empty(shape, dtype, new TVMContext(1, 0));
}
/**
* Create an empty float32 array on cpu given shape.
* @param shape The shape of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape) {
return empty(shape, new TVMType("float32", 1), new TVMContext(1, 0));
}
/**
* Create an empty float32 array given shape and device.
* @param shape The shape of the array.
* @param ctx The context of the array.
* @return The array tvm supported.
*/
public static NDArray empty(long[] shape, TVMContext ctx) {
return empty(shape, new TVMType("float32", 1), ctx);
}
private static ByteBuffer wrapBytes(byte[] bytes) {
ByteBuffer bb = ByteBuffer.wrap(bytes);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb;
}
private static ByteBuffer wrapBytes(byte[] bytes, int offset, int length) {
ByteBuffer bb = ByteBuffer.wrap(bytes, offset, length);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.tvm;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
class NativeLibraryLoader {
private static final String libPathInJar = "/lib/native/";
private static File tempDir;
static {
try {
tempDir = File.createTempFile("tvm", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
/*
* Different cleanup strategies for Windows and Linux.
* TODO: shutdown hook won't work on Windows
*/
if (!"Windows".equals(getUnifiedOSName())) {
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override public void run() {
for (File f : tempDir.listFiles()) {
System.err.println("Deleting " + f.getAbsolutePath());
if (!f.delete()) {
System.err.println("[WARN] Couldn't delete temporary file " + f.getAbsolutePath());
}
}
System.err.println("Deleting " + tempDir.getAbsolutePath());
if (!tempDir.delete()) {
System.err.println(
"[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath());
}
}
});
} else {
throw new RuntimeException("Windows not supported yet.");
}
} catch (IOException ex) {
System.err.println("Couldn't create temporary directory: " + ex.getMessage());
throw new RuntimeException(ex);
}
}
/**
* Find the library as a resource in jar, copy it to a tempfile
* and load it using System.load(). The name of the library has to be the
* base name, it is mapped to the corresponding system name using
* System.mapLibraryName(). e.g., the library "foo" is called "libfoo.so"
* under Linux and "foo.dll" under Windows, but you just have to pass "foo" to
* the loadLibrary().
*
* @param libname basename of the library
* @throws UnsatisfiedLinkError if library not found.
* @throws IOException if file not found.
*/
public static void loadLibrary(String libname) throws UnsatisfiedLinkError, IOException {
String mappedLibname = System.mapLibraryName(libname);
String loadLibname = mappedLibname;
if (mappedLibname.endsWith("dylib")) {
System.err.println("Replaced .dylib with .jnilib");
loadLibname = mappedLibname.replace(".dylib", ".jnilib");
}
System.err.println("Attempting to load " + loadLibname);
extractResourceFileToTempDir(loadLibname, new Action() {
@Override public void invoke(File target) {
System.err.println("Loading library from " + target.getPath());
System.load(target.getPath());
}
});
}
/**
* Translate all those Windows to "Windows". ("Windows XP", "Windows Vista", "Windows 7", etc.)
*/
private static String unifyOSName(String osname) {
if (osname.startsWith("Windows")) {
return "Windows";
}
return osname;
}
private static String getUnifiedOSName() {
return unifyOSName(System.getProperty("os.name"));
}
private static File createTempFile(String name) throws IOException {
return new File(tempDir + File.separator + name);
}
static interface Action {
public void invoke(File file);
}
/**
* Copies the resource file to a temp file and do an action.
* @param filename source file name (in lib/native).
* @param action callback function to deal with the copied file.
*/
public static void extractResourceFileToTempDir(String filename, Action action)
throws IOException {
final String libFileInJar = libPathInJar + filename;
InputStream is = NativeLibraryLoader.class.getResourceAsStream(libFileInJar);
if (is == null) {
throw new UnsatisfiedLinkError("Couldn't find the resource " + filename);
}
System.err.println(String.format("Loading %s from %s", filename, libPathInJar));
try {
File tempfile = createTempFile(filename);
OutputStream os = new FileOutputStream(tempfile);
final long savedTime = System.currentTimeMillis();
byte[] buf = new byte[8192];
int len = is.read(buf);
while (len > 0) {
os.write(buf, 0, len);
len = is.read(buf);
}
os.flush();
final FileInputStream lock = new FileInputStream(tempfile);
os.close();
double seconds = (double) (System.currentTimeMillis() - savedTime) / 1e3;
System.err.println(String.format("Copying took %.2f seconds.", seconds));
action.invoke(tempfile);
lock.close();
} catch (IOException io) {
System.err.println("[ERROR] Could not create the temp file: " + io.toString());
throw io;
} catch (UnsatisfiedLinkError ule) {
System.err.println("Couldn't load copied link file: " + ule.toString());
throw ule;
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment