Commit 1cb602f1 by Siva Committed by Tianqi Chen

[RUNTIME][GOLANG] TVM runtime for golang v0.1 (#1470)

parent 069aa381
.PHONY: clean all
TVM_BASE = $(CURDIR)/../
TARGET = gotvm
LIBS = -lm -ldl
NATIVE_SRC = tvm_runtime_pack.cc
GOPATH=$(CURDIR)/gopath
GOPATHDIR=${GOPATH}/src/${TARGET}/
CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/"
CGO_CXXFLAGS="-std=c++11"
CGO_CFLAGS="-I${TVM_BASE}"
CGO_LDFLAGS="-ldl -lm"
all:
@mkdir gopath 2>/dev/null || true
@mkdir gopath/src 2>/dev/null || true
@mkdir gopath/src/$(TARGET) 2>/dev/null || true
@cp src/$(TARGET).cc gopath/src/$(TARGET)
@cp src/$(TARGET).h gopath/src/$(TARGET)
@cp src/$(NATIVE_SRC) gopath/src/$(TARGET)
@cp src/*.go gopath/src/$(TARGET)
@export GOPATH=$(GOPATH); \
export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \
export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \
export CGO_CFLAGS=$(CGO_CFLAGS); \
export CGO_LDFLAGS=$(CGO_LDFLAGS); \
(cd $(GOPATHDIR) && go clean -cache \
&& golint && go build -o $(TARGET).a \
&& go install)
@find . -name gotvm.a
@#mkdir gopath/doc 2>/dev/null || true
@#godoc -html -goroot gopath/ gotvm | grep -v "for documentation on the gotvm command" > gopath/doc/gotvm.html
@#echo "Run 'godoc -http=:6060 -goroot=./gopath' for documentation"
samples: all
cp gopath/pkg/linux_amd64/gotvm.a sample/ -rfa
make -C sample
tests: all
@(cd sample; python3 deploy.py)
@export GOPATH=$(GOPATH); \
export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \
export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \
export CGO_CFLAGS=$(CGO_CFLAGS); \
export CGO_LDFLAGS=$(CGO_LDFLAGS); \
(cd $(GOPATHDIR) \
&& cp ../../../sample/deploy.so . \
&& go test -v)
clean:
@if [ -d $(GOPATHDIR) ] ; then \
export GOPATH=$(GOPATH); \
export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \
export CGO_CFLAGS=$(CGO_CFLAGS); \
export CGO_LDFLAGS=$(CGO_LDFLAGS); \
(cd $(GOPATHDIR) && go clean -cache); fi
@rm -rf gopath
@make -C sample clean
lint:
@(cd src; golint)
@python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.cc
@python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.h
# gotvm - Golang Frontend for TVM Runtime
This folder contain golang interface for TVM runtime. It brings TVM runtime to Golang.
- It enable c runtime api of tvm exposed to golang.
- It enables module loading (lib, graph and params) and inference operations.
## Installation
### Requirements
- go compiler (https://golang.org/) version 0.10 or above.
### Modules
- src
Module that generates golang package corresponding to the c runtime api exposed from tvm source tree.
This process build golang package _gotvm.a_
- samples
Sample golang reference application to inference through gotvm package.
### Build
Once the Requirements are installed
To build _gotvm_ package
```bash
make
```
To build and run internal tests
```bash
make tests
```
To build sample apps.
```bash
make samples
```
## Run
To Demonstrates sample TVM module compilation using python and deploy via golang.
```bash
./simple
```
To deploy a realtime module with lib, graph and param.
```bash
./complex
```
To demonstrate go function closure conversion to packed function handle.
```bash
./pack_func_convert
```
To demonstrate a packed function handle given as an argument.
```bash
pack_func_handle_arg
```
To register go function with runtime as a global function.
```bash
pack_func_register
```
To demonstrate function closure passed as argument to a function call.
```bash
./pack_func_closure_arg
```
To demonstrate function closure returned from a packed function.
```bash
./pack_func_closure_return
```
## Documentation
gotvm.go is documented with sufficient information about gotvm package.
A html version documentation can be accessed by running below command after building runtime.
```bash
godoc -http=:6060 -goroot=./gopath
```
After above command try http://127.0.0.1:6060 from any browser.
Also please refer to the sample applications under sample folder.
## Docker
Docker setup may need below additions for dependencies and environment preparation.
Please refer ```docker/install/ubuntu_install_golang.sh``` for the packages dependencies.
go compiler 1.10 on ubuntu doesn't install on standard path, hence an explicit export may be needed as shown below.
```bash
export PATH="/usr/lib/go-1.10/bin:$PATH"```
```
.PHONY: clean all
SOURCES=$(wildcard *.go)
EXECUTABLE=$(patsubst %.go, %, $(SOURCES))
all: $(EXECUTABLE)
@golint
@python3 deploy.py
%: %.o
@go tool link -linkmode external -extld "g++" -extldflags "-ldl" -o $@ $<
%.o: %.go
@go tool compile -pack -o $@ $<
clean:
@rm -f $(EXECUTABLE) *.so *.o *.a
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application deployment over tvm.
* \file complex.go
*/
package main
import (
"fmt"
"io/ioutil"
"math/rand"
"./gotvm"
"runtime"
)
// NNVM compiled model paths.
const (
modLib = "./mobilenet.so"
modJSON = "./mobilenet.json"
modParams = "./mobilenet.params"
)
// main
func main() {
defer runtime.GC()
// Welcome
fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion)
fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion)
// Query global functions available
funcNames, err := gotvm.FuncListGlobalNames()
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Global Functions:%v\n", funcNames)
// Import tvm module (so)
modp, err := gotvm.LoadModuleFromFile(modLib)
if err != nil {
fmt.Print(err)
fmt.Printf("Please copy tvm compiled modules here and update the sample.go accordingly.\n")
fmt.Printf("You may need to update modLib, modJSON, modParams, tshapeIn, tshapeOut\n")
return
}
fmt.Printf("Module Imported:%p\n", modp)
bytes, err := ioutil.ReadFile(modJSON)
if err != nil {
fmt.Print(err)
return
}
jsonStr := string(bytes)
// Load module on tvm runtime - call tvm.graph_runtime.create
funp, err := gotvm.GetGlobalFunction("tvm.graph_runtime.create")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Calling tvm.graph_runtime.create\n")
// Call function
graphrt, err := funp.Invoke(jsonStr, modp, (int64)(gotvm.KDLCPU), (int64)(0))
if err != nil {
fmt.Print(err)
return
}
graphmod := graphrt.AsModule()
fmt.Printf("Graph runtime Created\n")
// Array allocation attributes
tshapeIn := []int64{1, 224, 224, 3}
tshapeOut := []int64{1, 1001}
// Allocate input Array
inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0))
if err != nil {
fmt.Print(err)
return
}
// Allocate output Array
out, err := gotvm.Empty(tshapeOut)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Input and Output Arrays allocated\n")
// Get module function from graph runtime : load_params
// Read params
bytes, err = ioutil.ReadFile(modParams)
if err != nil {
fmt.Print(err)
}
// Load Params
funp, err = graphmod.GetFunction("load_params")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Func load_params:%p\n", funp)
// Call function
_, err = funp.Invoke(bytes)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Module params loaded\n")
// Set some data in input Array
inSlice := make([]float32, (244 * 244 * 3))
rand.Seed(10)
rand.Shuffle(len(inSlice), func(i, j int) {inSlice[i],
inSlice[j] = rand.Float32(),
rand.Float32() })
inX.CopyFrom(inSlice)
// Set Input
funp, err = graphmod.GetFunction("set_input")
if err != nil {
fmt.Print(err)
return
}
// Call function
_, err = funp.Invoke("input", inX)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Module input is set\n")
// Run
funp, err = graphmod.GetFunction("run")
if err != nil {
fmt.Print(err)
return
}
// Call function
_, err = funp.Invoke()
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Module Executed \n")
// Call runtime function get_output
funp, err = graphmod.GetFunction("get_output")
if err != nil {
fmt.Print(err)
return
}
// Call function
_, err = funp.Invoke(int64(0), out)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Got Module Output \n")
// Print results
outIntf, _ := out.AsSlice()
outSlice := outIntf.([]float32)
fmt.Printf("Result:%v\n", outSlice[:10])
}
"""
Get Started with TVM Go
=======================
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
# Global declarations of environment.
tgt_host="llvm"
tgt="llvm"
######################################################################
# Describe the Computation
# ------------------------
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
######################################################################
# Schedule the Computation
# ------------------------
s = tvm.create_schedule(C.op)
######################################################################
# Compilation
# -----------
fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
######################################################################
# Save Compiled Module
# --------------------
from tvm.contrib import cc
from tvm.contrib import util
fadd.save("deploy.o")
cc.create_shared("deploy.so", ["deploy.o"])
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate go-closure given to a packed function argument.
* \file pack_func_closure_arg.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleFunctionArg receives a Packed Function handle and calls it.
func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
// Call Packed Function
retVal, err = pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64())
return
}
// main
func main() {
// Not passing a function name implicitely
// picks the name from reflection as "main.sampleDunctionArg"
gotvm.RegisterFunction(sampleFunctionArg);
fmt.Printf("Registered: sampleFunctionArg\n")
// Get registered global function.
funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("GetGlobalFunction: main.sampleFunctionArg - Success\n")
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// Call function
result, err := funp.Invoke(funccall, 30, 50)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Invoked sampleFunctionArg with function closure arg : Result:%v\n", result.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate go-closure returned from a callback function.
* \file pack_func_closure_return.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue.
func sampleFunctionCb(args ...*gotvm.Value) (retVal interface{}, err error) {
funccall := func (cargs ...*gotvm.Value) (fret interface{}, ferr error) {
for _, v := range cargs {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := cargs[0].AsInt64()
val2 := cargs[1].AsInt64()
fret = int64(val1+val2)
return
}
retVal = funccall
return
}
// main
func main() {
// Not passing a function name implicitely
// picks the name from reflection as "main.sampleDunctionCb"
gotvm.RegisterFunction(sampleFunctionCb);
fmt.Printf("Registered: sampleFunctionCb\n")
// Get registered global function
funp, err := gotvm.GetGlobalFunction("main.sampleFunctionCb")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("GetGlobalFunction: main.sampleFunctionCb - Success\n")
// Call function
result, err := funp.Invoke()
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Invoked main.sampleFunctionCb via Function handle\n")
pfunc := result.AsFunction()
fmt.Printf("Function Handle received via Packed Function call:%T - %v \n", pfunc, pfunc)
pfuncRet, err := pfunc.Invoke(30, 40)
fmt.Printf("Invoked closure inside sampleFunctionCb result:%v\n", pfuncRet.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate function conversion to packed function.
* \file pack_func_convert.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleCb is a simple golang callback function like C = A + B.
func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// main
func main() {
// Welcome
// Simple convert to a packed function
fhandle, err := gotvm.ConvertFunction(sampleCb)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Converted function\n")
retVal, err := fhandle.Invoke(10, 20)
fmt.Printf("Invoke Completed\n")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Result:%v\n", retVal.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate converted packed
* function handle passed to another packed function.
* \file pack_func_handle_arg.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleCb is a simple golang callback function like C = A + B.
func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// sampleFunctionArg receives a Packed Function handle and calls it.
func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
// Call Packed Function
retVal, err = pfunc.Invoke(args[1], args[2])
return
}
// main
func main() {
// Simple convert to a packed function
fhandle, err := gotvm.ConvertFunction(sampleCb)
if err != nil {
fmt.Print(err)
return
}
gotvm.RegisterFunction(sampleFunctionArg);
fmt.Printf("Registered: sampleFunctionArg\n")
funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg")
if err != nil {
fmt.Print(err)
return
}
retVal, err := funp.Invoke(fhandle, 10, 20)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Result:%v\n", retVal.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate function register into TVM global functions.
* \file pack_func_register.go
*/
package main
import (
"fmt"
"./gotvm"
"strings"
)
// sampleCb is a simple golang callback function like C = A + B.
func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// main
func main() {
// Register sampleCb with TVM packed function system and call and check Global Function List.
gotvm.RegisterFunction(sampleCb, "sampleCb");
// Query global functions available
funcNames, err := gotvm.FuncListGlobalNames()
if err != nil {
fmt.Print(err)
return
}
found := 0
for ii := range (funcNames) {
if strings.Compare(funcNames[ii], "sampleCb") == 0 {
found = 1
}
}
if found == 0 {
fmt.Printf("Function registerd but, not listed\n")
return
}
// Get "sampleCb" and verify the call.
funp, err := gotvm.GetGlobalFunction("sampleCb")
if err != nil {
fmt.Print(err)
return
}
// Call function
result, err := funp.Invoke((int64)(10), (int64)(20))
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("sampleCb result: %v\n", result.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application deployment over tvm.
* \file simple.go
*/
package main
import (
"fmt"
"runtime"
"./gotvm"
"math/rand"
)
// NNVM compiled model paths.
const (
modLib = "./deploy.so"
)
// main
func main() {
// Welcome
defer runtime.GC()
fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion)
fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion)
// Import tvm module (so)
modp, _ := gotvm.LoadModuleFromFile(modLib)
fmt.Printf("Module Imported\n")
// Allocate Array for inputs and outputs.
// Allocation by explicit type and context.
tshapeIn := []int64{4}
inX, _ := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0))
// Default allocation on CPU
inY, _ := gotvm.Empty(tshapeIn, "float32")
// Default allocation to type "float32" and on CPU
out, _ := gotvm.Empty(tshapeIn)
fmt.Printf("Input and Output Arrays allocated\n")
// Fill Input Data : inX , inY
inXSlice := make([]float32, 4)
inYSlice := make([]float32, 4)
for i := range inXSlice {
inXSlice[i] = rand.Float32()
inYSlice[i] = rand.Float32()
}
// Copy the data on target memory through runtime CopyFrom api.
inX.CopyFrom(inXSlice)
inY.CopyFrom(inYSlice)
fmt.Printf("X: %v\n", inXSlice)
fmt.Printf("Y: %v\n", inYSlice)
// Get function "myadd"
funp, _ := modp.GetFunction("myadd")
// Call function
funp.Invoke(inX, inY, out)
fmt.Printf("Module function myadd executed\n")
// Get the output tensor as an interface holding a slice through runtime CopyTo api.
outSlice, _ := out.AsSlice()
// Print results
fmt.Printf("Result:%v\n", outSlice.([]float32))
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMByteArray interface.
* \file bytearray.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
)
// ByteArray type wraps the TVMByteArray of C runtime API.
//
// This can be used to hold raw data like params of a model.
type ByteArray uintptr
// nativeCPtr returns the type freed unitptr for ByteArray.
func (tbytearray ByteArray) nativeCPtr() (retVal uintptr) {
retVal = (uintptr)(tbytearray)
return
}
// SetData is used to intialize ByteArray from a golang string object.
//
// This method initialize both data and data size of the underlaying object.
// This function handles freeing old data object if any before allocating new.
//
// `val` is the golang string object from which the ByteArray is initialized.
func (tbytearray ByteArray) setData(val string) {
bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data
if bufPtr == (*_Ctype_char)(C.NULL) {
C.free(unsafe.Pointer(bufPtr))
}
((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data = C.CString(val)
((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size = C.ulong(len(val))
}
// getData returns the golang byte slice corresponding to the ByteArray.
func (tbytearray ByteArray) getData() (retVal []byte) {
val := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data
blen := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size
retVal = C.GoBytes(unsafe.Pointer(val), C.int(blen))
return
}
// newByteArray initilizes the native TVMByteArray object with given byte slice
//
//`val` is the golang byte array used to initialize.
//
// returns newly created ByteArray.
func newByteArray(val []byte) (retVal ByteArray) {
handle := ByteArray(C.malloc(C.sizeof_TVMByteArray))
((*C.TVMByteArray)(unsafe.Pointer(handle))).data = (*_Ctype_char)(C.NULL)
((*C.TVMByteArray)(unsafe.Pointer(handle))).size = 0
handle.setData(string(val))
retVal = handle
return
}
// deleteTVMByteArray releases the allocated native object of ByteArray.
//
// This delete handles freeing of underlaying native data object too.
func (tbytearray ByteArray) deleteTVMByteArray() {
bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data
C.free(unsafe.Pointer(bufPtr))
C.free(unsafe.Pointer(tbytearray.nativeCPtr()))
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file bytearray_test.go
*/
package gotvm
import (
"testing"
"math/rand"
)
// Check ByteArray creation from byte slice and verify the data.
func TestByteArrayGet(t *testing.T) {
data := make([]byte, 1024)
rand.Read(data)
barr := newByteArray(data)
dataRet := barr.getData()
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v at : %v\n", data[i], dataRet[i], i)
return
}
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMContext interface
* \file context.go
*/
package gotvm
//#include "gotvm.h"
import "C"
// KDLCPU is golang enum correspond to TVM device type kDLCPU.
var KDLCPU = int32(C.kDLCPU)
// KDLGPU is golang enum correspond to TVM device type kDLGPU.
var KDLGPU = int32(C.kDLGPU)
// KDLCPUPinned is golang enum correspond to TVM device type kDLCPUPinned.
var KDLCPUPinned = int32(C.kDLCPUPinned)
// KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL.
var KDLOpenCL = int32(C.kDLOpenCL)
// KDLMetal is golang enum correspond to TVM device type kDLMetal.
var KDLMetal = int32(C.kDLMetal)
// KDLVPI is golang enum correspond to TVM device type kDLVPI.
var KDLVPI = int32(C.kDLVPI)
// KDLROCM is golang enum correspond to TVM device type kDLROCM.
var KDLROCM = int32(C.kDLROCM)
// KDLSDAccel is golang enum correspond to TVM device type kDLSDAccel.
var KDLSDAccel = int32(C.kDLSDAccel)
// KDLVulkan is golang enum correspond to TVM device type kDLVulkan.
var KDLVulkan = int32(C.kDLVulkan)
// KOpenGL is golang enum correspond to TVM device type kOpenGL.
var KOpenGL = int32(C.kOpenGL)
// KExtDev is golang enum correspond to TVM device type kDLExtDev.
var KExtDev = int32(C.kDLExtDev)
// Context dtype corresponding to TVMContext aka DLContext
type Context struct {
DeviceType int32
DeviceID int32
}
// CPU returns the Context object for CPU target on given index
func CPU(index int32) Context {
return Context{KDLCPU, index}
}
// GPU returns the Context object for GPU target on given index
func GPU(index int32) Context {
return Context{KDLGPU, index}
}
// CPUPinned returns the Context object for CPUPinned target on given index
func CPUPinned(index int32) Context {
return Context{KDLCPUPinned, index}
}
// OpenCL returns the Context object for OpenCL target on given index
func OpenCL(index int32) Context {
return Context{KDLOpenCL, index}
}
// Metal returns the Context object for Metal target on given index
func Metal(index int32) Context {
return Context{KDLMetal, index}
}
// VPI returns the Context object for VPI target on given index
func VPI(index int32) Context {
return Context{KDLVPI, index}
}
// ROCM returns the Context object for ROCM target on given index
func ROCM(index int32) Context {
return Context{KDLROCM, index}
}
// SDAccel returns the Context object for SDAccel target on given index
func SDAccel(index int32) Context {
return Context{KDLSDAccel, index}
}
// Vulkan returns the Context object for Vulkan target on given index
func Vulkan(index int32) Context {
return Context{KDLVulkan, index}
}
// OpenGL returns the Context object for OpenGL target on given index
func OpenGL(index int32) Context {
return Context{KOpenGL, index}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for error related API interface.
* \file error.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
)
// getTVMLastError returns the detailed error string for any api called in TVM runtime.
//
// This is useful when any api returns non zero value.
//
// Returns golang string for the corresponding native error message.
func getTVMLastError() (retVal string) {
errStr := C.TVMGetLastError()
retVal = C.GoString(errStr)
return
}
func setTVMLastError(errStr string) {
cstr := C.CString(errStr)
C.TVMAPISetLastError(cstr)
C.free(unsafe.Pointer(cstr))
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file error_test.go
*/
package gotvm
import (
"testing"
"strings"
)
// Check err receiving from TVM global function.
func TestErrorTest(t *testing.T) {
_, err := LoadModuleFromFile("dummy.so")
if err == nil {
t.Error("Expected an error, but not received\n")
return
}
errStr := err.Error()
if !(strings.Contains(errStr, string("cannot open shared object"))) {
t.Error("Ah! TVM didn't report an error\n")
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file function_test.go
*/
package gotvm
import (
"testing"
"reflect"
"math/rand"
"strings"
"fmt"
)
// Check global function list API
func TestFunctionGlobals(t *testing.T) {
funcNames, err := FuncListGlobalNames()
if err != nil {
t.Error(err.Error())
return
}
if len(funcNames) < 1 {
t.Errorf("Global Function names received:%v\n", funcNames)
}
}
// Check GetFunction API
func TestFunctionGlobalGet(t *testing.T) {
funp, err := GetGlobalFunction("tvm.graph_runtime.create")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(funp).Kind() != reflect.Ptr {
t.Error("Function type mis matched\n")
return
}
}
func TestFunctionModuleGet(t *testing.T) {
modp, err := LoadModuleFromFile("./deploy.so")
if err != nil {
t.Error(err.Error())
return
}
funp, err := modp.GetFunction("myadd")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(funp).Kind() != reflect.Ptr {
t.Error("Function type mis matched\n")
return
}
dlen := int64(1024)
shape := []int64{dlen}
inX, _ := Empty(shape)
inY, _ := Empty(shape)
out, _ := Empty(shape)
dataX := make([]float32, (dlen))
dataY := make([]float32, (dlen))
outExpected := make([]float32, (dlen))
for i := range dataX {
dataX[i] = rand.Float32()
dataY[i] = rand.Float32()
outExpected[i] = dataX[i] + dataY[i]
}
inX.CopyFrom(dataX)
inY.CopyFrom(dataY)
funp.Invoke(inX, inY, out)
outi, _ := out.AsSlice()
outSlice := outi.([]float32)
if len(outSlice) != len(outExpected) {
t.Errorf("Data expected Len: %v Got :%v\n", len(outExpected), len(outSlice))
return
}
for i := range outSlice {
if outExpected[i] != outSlice[i] {
t.Errorf("Data expected: %v Got :%v at index %v\n", outExpected[i], outSlice[i], i)
return
}
}
}
// Check FunctionConvert API
func TestFunctionConvert(t *testing.T) {
sampleCb := func (args ...*Value) (retVal interface{}, err error) {
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
fhandle, err := ConvertFunction(sampleCb)
if err != nil {
t.Error(err.Error())
return
}
retVal, err := fhandle.Invoke(10, 20)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsInt64() != int64(30) {
t.Errorf("Expected result :30 got:%v\n", retVal.AsInt64())
return
}
}
func TestFunctionError(t *testing.T) {
sampleCb := func (args ...*Value) (retVal interface{}, err error) {
err = fmt.Errorf("Sample Error XYZABC");
return
}
fhandle, err := ConvertFunction(sampleCb)
if err != nil {
t.Error(err.Error())
return
}
_, err = fhandle.Invoke()
if err == nil {
t.Error("Expected error but didn't received\n")
return
}
if !strings.Contains(err.Error(), string("Sample Error XYZABC")) {
t.Errorf("Expected Error should contain :\"Sample Error XYZABC\" got :%v\n", err.Error())
}
}
// Check FunctionRegister
func TestFunctionRegister(t *testing.T) {
sampleCb := func (args ...*Value) (retVal interface{}, err error) {
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
RegisterFunction(sampleCb, "TestFunctionRegister.sampleCb");
// Query global functions available
funcNames, err := FuncListGlobalNames()
if err != nil {
t.Error(err.Error())
return
}
found := 0
for ii := range (funcNames) {
if strings.Compare(funcNames[ii], "TestFunctionRegister.sampleCb") == 0 {
found = 1
}
}
if found == 0 {
t.Error("Registered function not found in global function list.")
return
}
// Get "sampleCb" and verify the call.
funp, err := GetGlobalFunction("TestFunctionRegister.sampleCb")
if err != nil {
t.Error(err.Error())
return
}
// Call function
result, err := funp.Invoke((int64)(10), (int64)(20))
if err != nil {
t.Error(err.Error())
return
}
if result.AsInt64() != int64(30) {
t.Errorf("Expected result :30 got:%v\n", result.AsInt64())
return
}
}
// Check packed function receiving go-closure as argument.
func TestFunctionClosureArg(t *testing.T) {
// sampleFunctionArg receives a Packed Function handle and calls it.
sampleFunctionArg := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
// Call Packed Function by Value
ret, err := pfunc.Invoke(args[1], args[2])
if err != nil {
return
}
// Call Packed Function with extracted values
ret1, err := pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64())
if err != nil {
return
}
if ret1.AsInt64() != ret.AsInt64() {
err = fmt.Errorf("Invoke with int64 didn't match with Value\n")
return
}
retVal = ret
return
}
RegisterFunction(sampleFunctionArg, "TestFunctionClosureArg.sampleFunctionArg");
funp, err := GetGlobalFunction("TestFunctionClosureArg.sampleFunctionArg")
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// Call function
result, err := funp.Invoke(funccall, 30, 50)
if err != nil {
t.Error(err.Error())
return
}
if result.AsInt64() != int64(80) {
t.Errorf("Expected result :80 got:%v\n", result.AsInt64())
return
}
}
// Check packed function returning a go-closure.
func TestFunctionClosureReturn(t *testing.T) {
// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue.
sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) {
funccall := func (cargs ...*Value) (fret interface{}, ferr error) {
val1 := cargs[0].AsInt64()
val2 := cargs[1].AsInt64()
fret = int64(val1+val2)
return
}
retVal = funccall
return
}
RegisterFunction(sampleFunctionCb, "TestFunctionClosureReturn.sampleFunctionCb");
funp, err := GetGlobalFunction("TestFunctionClosureReturn.sampleFunctionCb")
if err != nil {
t.Error(err.Error())
return
}
// Call function
result, err := funp.Invoke()
if err != nil {
t.Error(err.Error())
return
}
pfunc := result.AsFunction()
pfuncRet, err := pfunc.Invoke(30, 40)
if err != nil {
t.Error(err.Error())
return
}
if pfuncRet.AsInt64() != int64(70) {
t.Errorf("Expected result :70 got:%v\n", pfuncRet.AsInt64())
return
}
}
// Check packed function with no arguments and no return values.
func TestFunctionNoArgsReturns(t *testing.T) {
sampleFunction := func (args ...*Value) (retVal interface{}, err error) {
return
}
fhandle, err := ConvertFunction(sampleFunction)
if err != nil {
t.Error(err.Error())
return
}
_, err = fhandle.Invoke()
if err != nil {
t.Error(err.Error())
return
}
}
// Check packed function returning a go-closure with no arg and returns.
func TestFunctionNoArgsReturns2(t *testing.T) {
// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue.
sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) {
funccall := func (cargs ...*Value) (fret interface{}, ferr error) {
return
}
retVal = funccall
return
}
funp, err := ConvertFunction(sampleFunctionCb)
if err != nil {
t.Error(err.Error())
return
}
// Call function
result, err := funp.Invoke()
if err != nil {
t.Error(err.Error())
return
}
pfunc := result.AsFunction()
_, err = pfunc.Invoke()
if err != nil {
t.Error(err.Error())
return
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm native interface definition
* \file gotvm.cxx
*/
// Standard includes
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <stdint.h>
// golang string compatible definition
typedef struct { char *p; int n; } _gostring_;
#include <string>
#ifdef __cplusplus
extern "C" {
#endif
// TVM runtime C interface
#include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h>
/*!
* \brief Convert native char array to _gostring_ structure.
* _gostring_ structure represents the same memory footprint as golang string object.
*
* \param p is char pointer to a char array.
* \param l is the size of the char array. this method exclusively need length as
* its possible to have a bytearray in a string.
*
* \return _gostring_ object corresponding to native char array.
* Caller is responsible to free the memory block allocated here.
*/
static _gostring_ _native_to_gostring(const char *p, size_t l) {
_gostring_ ret;
ret.p = reinterpret_cast<char*>(malloc(l));
if (NULL == ret.p) {
ret.n = 0;
return ret;
}
memcpy(ret.p, p, l);
ret.n = l;
return ret;
}
/*!
* \brief embeds a 64bit uint value inside a string to serialize the data.
*
* \param s is string object.
* \param off is the offset in the string object.
* \param v is the uint64_t value which need to embed into given string.
*/
static void putuint64(std::string *s, size_t off, uint64_t v) {
for (int i = 0; i < 8; i++) {
(*s)[off + i] = (v >> (i * 8)) & 0xff;
}
}
// TVM runtime C interface wrappers
/*!
* \brief Native interface to query TVM_VERSION in golang string format.
*
* \return char pointer to TVM-VERSION
*/
const char* _TVM_VERSION(void) {
const char *version = TVM_VERSION;
return version;
}
/*!
* \brief Native interface for getting TVMGlobal function list.
*
* \param names return by argument to return the function names.
* We wrap all strings into single string joined by (len+string)
* which is unpacked and processed in golang.
*
* \return c_runtime_api return status.
*/
int _TVMFuncListGlobalNames(_gostring_* names) {
int names_size;
char **names_array;
int result;
result = TVMFuncListGlobalNames(&names_size, (char const ***)&names_array);
if (result) {
return result;
}
size_t tot = 8;
for (int ii = 0; ii < names_size ; ++ii) {
tot += 8 + strlen(names_array[ii]);
}
std::string str;
str.resize(tot);
putuint64(&str, 0, names_size);
size_t off = 8;
for (int64_t ii = 0; ii < names_size ; ++ii) {
putuint64(&str, off, strlen(names_array[ii]));
off += 8;
str.replace(off, strlen(names_array[ii]), names_array[ii]);
off += strlen(names_array[ii]);
}
*names = _native_to_gostring(str.data(), str.size());
if (str.size() != names->n) {
TVMAPISetLastError("malloc failed during _native_to_gostring");
result = 1;
}
return result;
}
// Helpers for TVMValue
/*!
* \brief Native helper to copy TVMValue from golang slice to native array.
* this helper is need as underlying momory for golang slice is not continueous.
*
* \param to_ptr is the native pointer of TVMValue array.
* \param from_ptr pointer to TVMValue in golang slice.
* \param array index in native array.
*/
void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) {
TVMValue *from_p = reinterpret_cast<TVMValue*>(from_ptr);
TVMValue *to_p = reinterpret_cast<TVMValue*>(to_ptr);
memcpy(to_p+ind, from_p, sizeof(TVMValue));
}
/*!
* \brief Native helper to copy TVMValue from golang slice to native array.
* this helper is need as underlying momory for golang slice is not continueous.
*
* \param to_ptr pointer to TVMValue in golang slice.
* \param from_ptr is the native pointer of TVMValue array.
* \param array index in native array.
*/
void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) {
TVMValue *from_p = reinterpret_cast<TVMValue*>(from_ptr);
TVMValue *to_p = reinterpret_cast<TVMValue*>(to_ptr);
memcpy(to_p, from_p+ind, sizeof(TVMValue));
}
extern int goTVMCallback(void*, void*, int, void*, void*);
/*!
* \brief _TVMCallback is the TVM runtime callback function for PackedFunction system.
*
* \param args is an array of TVMValue
* \param type_codes is an array of int
* \param num_args is int representing number of in arguments
* \param ret is the return value handle to set the packed function return.
* \param resource_handle is the golang private data pointer.
*
* \returns the error status as TVM_DLL
*/
int _TVMCallback(TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle) {
return goTVMCallback(args, type_codes, num_args, ret, resource_handle);
}
/*!
* _TVMPackedCFuncFinalizer is finalizer for packed function system.
*
*/
void _TVMPackedCFuncFinalizer(void* resource_handle) {
return;
}
/*!
* /brief _ConvertFunction creates a packed function for with given resource handle.
*
* /param fptr is the pointer to golang resource handle.
* /param *fhandle is the return argument holding packed function.
*
* /return is an int indicating the return status.
*/
int _ConvertFunction(void* fptr, TVMFunctionHandle *fhandle) {
int ret = TVMFuncCreateFromCFunc(_TVMCallback,
fptr,
_TVMPackedCFuncFinalizer,
fhandle);
return ret;
}
#ifdef __cplusplus
}
#endif
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file gotvm.go
*/
// Package gotvm is TVM runtime interface definition for golang.
//
// Application need to import this package to access the c_runtime_api exposed by TVM.
package gotvm
//#include "gotvm.h"
import "C"
// DLPackVersion is the dlpack version of tvm runtime.
var DLPackVersion = int(C.DLPACK_VERSION)
// TVMVersion is the TVM runtime version.
var TVMVersion = getTVMVersion()
func getTVMVersion() (retStr string) {
retStr = C.GoString(C._TVM_VERSION())
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm native interface declaration.
* \file gotvm.h
*
* These declarations are in cgo interface definition while calling API
* across golang and native C boundaries.
*/
#ifndef GOTVM_GOTVM_H_
#define GOTVM_GOTVM_H_
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h>
// Some type definitions for golang "C"
typedef void* native_voidp;
// Version
extern char* _TVM_VERSION(void);
// Wrappers : For incompatible cgo API.
// To handle array of strings wrapped into __gostring__
extern int _TVMFuncListGlobalNames(void*);
// To handle TVMValue slice to/from native sequential TVMValue array.
extern void _TVMValueNativeSet(void* to, void* from, int index);
extern void _TVMValueNativeGet(void* to, void* from, int index);
// Callbacks
extern int _ConvertFunction(void* fptr, void* funp);
#ifdef __cplusplus
}
#endif
#endif // GOTVM_GOTVM_H_
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file gotvm_test.go
*/
package gotvm
import (
"testing"
"reflect"
)
// Check TVMVersion API
func TestTVMVersion(t *testing.T) {
if len(TVMVersion) == 0 {
t.Error("TVMVersion not set\n")
}
if reflect.TypeOf(TVMVersion).Kind() != reflect.String {
t.Error("TVMVersion type mismatch\n")
}
}
// Check DLPackVersion API
func TestDLPackVersion(t *testing.T) {
if reflect.TypeOf(DLPackVersion).Kind() != reflect.Int {
t.Error("TVMVersion type mismatch\n")
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMModule interface.
* \file module.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"errors"
"runtime"
"unsafe"
)
// Module type in golang hold pointer for the TVMModule handle.
//
// Module initialization happen through TVMModLoadFromFile api in TVM runtime.
type Module uintptr
// nativeCPtr returns type freed uintptr for the Module.
func (tvmmodule *Module) nativeCPtr() (retVal uintptr) {
retVal = (uintptr)(*tvmmodule)
return
}
// LoadModuleFromFile loads the given module in TVM runtime.
//
// `modpath` is the path to tvm module.
//
// `args` is an optional arguments of ["dll", "dylib", "dso", "so"] with default value "so"
//
// returns pointer to Module and err or if any.
func LoadModuleFromFile(modpath string, args ...interface{}) (retVal *Module, err error) {
modtype := "so"
if len(args) > 0 {
modtype = args[0].(string)
}
var modp uintptr
cmodpath := C.CString(modpath)
cmodtype := C.CString(modtype)
ret := (int32)(C.TVMModLoadFromFile(cmodpath,
cmodtype,
(*_Ctype_TVMModuleHandle)(unsafe.Pointer(&modp))))
C.free(unsafe.Pointer(cmodpath))
C.free(unsafe.Pointer(cmodtype))
if ret != 0 {
err = errors.New(getTVMLastError())
return
}
handle := new(Module)
*handle = Module(modp)
finalizer := func(mhandle *Module) {
nativeTVMModFree(mhandle)
mhandle = nil
}
runtime.SetFinalizer(handle, finalizer)
retVal = handle
return
}
// nativeTVMModFree free the module handle allocated in TVM runtime.
//
// `modp` is the Module handle to be freed.
func nativeTVMModFree(modp *Module) (retVal int32) {
retVal = (int32) (C.TVMModFree(C.TVMModuleHandle(modp.nativeCPtr())))
return
}
// GetFunction returns the function pointer from the module for given function name.
//
// `tvmmodule` is handle for Module
//
// `funcname` function name in module.
//
// `args` variadic args of `queryImport`
//
// returns function closure with signature
// func (args ...interface{}) (interface{}, error) and error if any.
//
// The closure function can be used to call Function with arguments directly.
//
// Variadic arguments can be any type which can be embed into Value.
func (tvmmodule *Module) GetFunction (
funcname string, args ...interface{}) (
retVal *Function, err error){
queryImports := int32(1)
if len(args) > 0 {
queryImports = int32(args[1].(int))
}
var funp uintptr
cfuncname := C.CString(funcname)
ret := (int32)(C.TVMModGetFunction((_Ctype_TVMModuleHandle)(*tvmmodule),
cfuncname,
C.int(queryImports),
(*_Ctype_TVMFunctionHandle)(unsafe.Pointer(&funp))))
C.free(unsafe.Pointer(cfuncname))
if ret != 0 {
err = errors.New(getTVMLastError())
return
}
handle := new(Function)
*handle = Function(funp)
finalizer := func(fhandle *Function) {
nativeTVMFuncFree(fhandle)
fhandle = nil
}
runtime.SetFinalizer(handle, finalizer)
retVal = handle
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file module_test.go
*/
package gotvm
import (
"testing"
"reflect"
)
// Check module loading - dll
func TestModuleTestLoad1(t *testing.T) {
// dll
mod, err := LoadModuleFromFile("./deploy.so", "dll")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading - dylib
func TestModuleTestLoad2(t *testing.T) {
// dylib
mod, err := LoadModuleFromFile("./deploy.so", "dylib")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
func TestModuleTestLoad3(t *testing.T) {
// dso
mod, err := LoadModuleFromFile("./deploy.so", "dso")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading - so
func TestModuleTestLoad4(t *testing.T) {
// so
mod, err := LoadModuleFromFile("./deploy.so", "so")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading - default (so)
func TestModuleTestLoad5(t *testing.T) {
// default type as so
mod, err := LoadModuleFromFile("./deploy.so")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading err
func TestModuleTestLoadErr(t *testing.T) {
// Unknown file should return error
_, err := LoadModuleFromFile("xyzabc.so")
if err == nil {
t.Error("Expected an error, but not received\n")
return
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief This is an all in one TVM runtime file.
* \file tvm_runtime_pack.cc
*/
#include "src/runtime/c_runtime_api.cc"
#include "src/runtime/cpu_device_api.cc"
#include "src/runtime/workspace_pool.cc"
#include "src/runtime/module_util.cc"
#include "src/runtime/module.cc"
#include "src/runtime/registry.cc"
#include "src/runtime/file_util.cc"
#include "src/runtime/threading_backend.cc"
#include "src/runtime/thread_pool.cc"
#include "src/runtime/ndarray.cc"
// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
// Likely we only need to enable one of the following
// If you use Module::Load, use dso_module
// For system packed library, use system_lib_module
#include "src/runtime/dso_module.cc"
#include "src/runtime/system_lib_module.cc"
// Graph runtime
#include "src/runtime/graph/graph_runtime.cc"
// Uncomment the following lines to enable RPC
// #include "../../src/runtime/rpc/rpc_session.cc"
// #include "../../src/runtime/rpc/rpc_event_impl.cc"
// #include "../../src/runtime/rpc/rpc_server_env.cc"
// These macros enables the device API when uncommented.
#define TVM_CUDA_RUNTIME 1
#define TVM_METAL_RUNTIME 1
#define TVM_OPENCL_RUNTIME 1
// Uncomment the following lines to enable Metal
// #include "../../src/runtime/metal/metal_device_api.mm"
// #include "../../src/runtime/metal/metal_module.mm"
// Uncomment the following lines to enable CUDA
// #include "../../src/runtime/cuda/cuda_device_api.cc"
// #include "../../src/runtime/cuda/cuda_module.cc"
// Uncomment the following lines to enable OpenCL
// #include "../../src/runtime/opencl/opencl_device_api.cc"
// #include "../../src/runtime/opencl/opencl_module.cc"
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package for TVMType interface
* \file type.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"fmt"
)
// pTVMType corresponding to data types.
type pTVMType struct {
code uint8
bits uint8
lanes uint16
}
// data type to pTVMType mapping
var dtypeMap = map[string] pTVMType {
"int8": pTVMType{0, 8, 1},
"int16": pTVMType{0, 16, 1},
"int32": pTVMType{0, 32, 1},
"int64": pTVMType{0, 64, 1},
"uint8": pTVMType{1, 8, 1},
"uint16": pTVMType{1, 16, 1},
"uint32": pTVMType{1, 32, 1},
"uint64": pTVMType{1, 64, 1},
"float32": pTVMType{2, 32, 1},
"float64": pTVMType{2, 64, 1},
}
// dtypeFromTVMType return the pTVMType corresponding to given dtype
//
// `dtype` string for the given data type.
func dtypeFromTVMType(tvmtype pTVMType) (retVal string, err error) {
for k, v := range dtypeMap {
if v.code == tvmtype.code && v.bits == tvmtype.bits && v.lanes == tvmtype.lanes {
retVal = k
return
}
}
err = fmt.Errorf("Cannot map TVMType:%v to dtype", tvmtype)
return
}
// dtypeToTVMType return the pTVMType corresponding to given dtype
//
// `dtype` string for the given data type.
func dtypeToTVMType(args ...interface{}) (tvmtype pTVMType, err error) {
dtype := args[0].(string)
lanes := 1
if len(args) == 2 {
lanes = args[1].(int)
}
for k, v := range dtypeMap {
if k == dtype {
tvmtype = v
tvmtype.lanes = uint16(lanes)
return
}
}
err = fmt.Errorf("Cannot map dtype:%v to TVMType", dtype)
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for common utilities
* \file util.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
)
// Native string map for go string
type nativeGoString struct { p uintptr; n int32 }
func goStringFromNative (s string) (retStr string) {
p := *(*nativeGoString)(unsafe.Pointer(&s))
retStr = string((*[0x7fffffff]byte)(unsafe.Pointer(p.p))[:p.n])
C.free(unsafe.Pointer(p.p))
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file value_test.go
*/
package gotvm
import (
"testing"
"math/rand"
"strings"
)
// Check Int64 Value looping via packed function calling another packed function.
func TestValueLoopInt64(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Int63()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsInt64() != result {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
// Check Int32 Value looping via packed function calling another packed function.
func TestValueLoopInt32(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Int31()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsInt64() != int64(result) {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
// Check Float32 Value looping via packed function calling another packed function.
func TestValueLoopFloat32(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Float32()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsFloat64() != float64(result) {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
// Check Float64 Value looping via packed function calling another packed function.
func TestValueLoopFloat64(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Float64()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsFloat64() != result {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
func TestValueLoopString(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
argStr := args[1].AsStr()
// Call Packed Function by Value
return pfunc.Invoke(argStr)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0].AsStr()
return
}
retVal, err := fhandle.Invoke(funccall, "TestString")
if err != nil {
t.Error(err.Error())
return
}
vStr := retVal.AsStr()
if strings.Compare(vStr, string("TestString")) != 0 {
t.Errorf("Expected : %v got:%v\n", string("TestString"), vStr)
return
}
}
// Check []byte Value looping via packed function calling another packed function.
func TestValueLoopByteSlice(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
argBytes := args[1].AsBytes()
// Call Packed Function by Value
return pfunc.Invoke(argBytes)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0].AsBytes()
return
}
result := make([]byte, 1024)
rand.Read(result)
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
received := retVal.AsBytes()
if len(result) != len(received) {
t.Errorf("Data expected Len: %v Got :%v\n", len(result), len(received))
return
}
for i := range result {
if result[i] != received[i] {
t.Errorf("Data expected: %v Got :%v at index %v\n", result[i], received[i], i)
return
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment