Commit 1cb602f1 by Siva Committed by Tianqi Chen

[RUNTIME][GOLANG] TVM runtime for golang v0.1 (#1470)

parent 069aa381
.PHONY: clean all
TVM_BASE = $(CURDIR)/../
TARGET = gotvm
LIBS = -lm -ldl
NATIVE_SRC = tvm_runtime_pack.cc
GOPATH=$(CURDIR)/gopath
GOPATHDIR=${GOPATH}/src/${TARGET}/
CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/"
CGO_CXXFLAGS="-std=c++11"
CGO_CFLAGS="-I${TVM_BASE}"
CGO_LDFLAGS="-ldl -lm"
all:
@mkdir gopath 2>/dev/null || true
@mkdir gopath/src 2>/dev/null || true
@mkdir gopath/src/$(TARGET) 2>/dev/null || true
@cp src/$(TARGET).cc gopath/src/$(TARGET)
@cp src/$(TARGET).h gopath/src/$(TARGET)
@cp src/$(NATIVE_SRC) gopath/src/$(TARGET)
@cp src/*.go gopath/src/$(TARGET)
@export GOPATH=$(GOPATH); \
export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \
export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \
export CGO_CFLAGS=$(CGO_CFLAGS); \
export CGO_LDFLAGS=$(CGO_LDFLAGS); \
(cd $(GOPATHDIR) && go clean -cache \
&& golint && go build -o $(TARGET).a \
&& go install)
@find . -name gotvm.a
@#mkdir gopath/doc 2>/dev/null || true
@#godoc -html -goroot gopath/ gotvm | grep -v "for documentation on the gotvm command" > gopath/doc/gotvm.html
@#echo "Run 'godoc -http=:6060 -goroot=./gopath' for documentation"
samples: all
cp gopath/pkg/linux_amd64/gotvm.a sample/ -rfa
make -C sample
tests: all
@(cd sample; python3 deploy.py)
@export GOPATH=$(GOPATH); \
export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \
export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \
export CGO_CFLAGS=$(CGO_CFLAGS); \
export CGO_LDFLAGS=$(CGO_LDFLAGS); \
(cd $(GOPATHDIR) \
&& cp ../../../sample/deploy.so . \
&& go test -v)
clean:
@if [ -d $(GOPATHDIR) ] ; then \
export GOPATH=$(GOPATH); \
export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \
export CGO_CFLAGS=$(CGO_CFLAGS); \
export CGO_LDFLAGS=$(CGO_LDFLAGS); \
(cd $(GOPATHDIR) && go clean -cache); fi
@rm -rf gopath
@make -C sample clean
lint:
@(cd src; golint)
@python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.cc
@python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.h
# gotvm - Golang Frontend for TVM Runtime
This folder contain golang interface for TVM runtime. It brings TVM runtime to Golang.
- It enable c runtime api of tvm exposed to golang.
- It enables module loading (lib, graph and params) and inference operations.
## Installation
### Requirements
- go compiler (https://golang.org/) version 0.10 or above.
### Modules
- src
Module that generates golang package corresponding to the c runtime api exposed from tvm source tree.
This process build golang package _gotvm.a_
- samples
Sample golang reference application to inference through gotvm package.
### Build
Once the Requirements are installed
To build _gotvm_ package
```bash
make
```
To build and run internal tests
```bash
make tests
```
To build sample apps.
```bash
make samples
```
## Run
To Demonstrates sample TVM module compilation using python and deploy via golang.
```bash
./simple
```
To deploy a realtime module with lib, graph and param.
```bash
./complex
```
To demonstrate go function closure conversion to packed function handle.
```bash
./pack_func_convert
```
To demonstrate a packed function handle given as an argument.
```bash
pack_func_handle_arg
```
To register go function with runtime as a global function.
```bash
pack_func_register
```
To demonstrate function closure passed as argument to a function call.
```bash
./pack_func_closure_arg
```
To demonstrate function closure returned from a packed function.
```bash
./pack_func_closure_return
```
## Documentation
gotvm.go is documented with sufficient information about gotvm package.
A html version documentation can be accessed by running below command after building runtime.
```bash
godoc -http=:6060 -goroot=./gopath
```
After above command try http://127.0.0.1:6060 from any browser.
Also please refer to the sample applications under sample folder.
## Docker
Docker setup may need below additions for dependencies and environment preparation.
Please refer ```docker/install/ubuntu_install_golang.sh``` for the packages dependencies.
go compiler 1.10 on ubuntu doesn't install on standard path, hence an explicit export may be needed as shown below.
```bash
export PATH="/usr/lib/go-1.10/bin:$PATH"```
```
.PHONY: clean all
SOURCES=$(wildcard *.go)
EXECUTABLE=$(patsubst %.go, %, $(SOURCES))
all: $(EXECUTABLE)
@golint
@python3 deploy.py
%: %.o
@go tool link -linkmode external -extld "g++" -extldflags "-ldl" -o $@ $<
%.o: %.go
@go tool compile -pack -o $@ $<
clean:
@rm -f $(EXECUTABLE) *.so *.o *.a
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application deployment over tvm.
* \file complex.go
*/
package main
import (
"fmt"
"io/ioutil"
"math/rand"
"./gotvm"
"runtime"
)
// NNVM compiled model paths.
const (
modLib = "./mobilenet.so"
modJSON = "./mobilenet.json"
modParams = "./mobilenet.params"
)
// main
func main() {
defer runtime.GC()
// Welcome
fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion)
fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion)
// Query global functions available
funcNames, err := gotvm.FuncListGlobalNames()
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Global Functions:%v\n", funcNames)
// Import tvm module (so)
modp, err := gotvm.LoadModuleFromFile(modLib)
if err != nil {
fmt.Print(err)
fmt.Printf("Please copy tvm compiled modules here and update the sample.go accordingly.\n")
fmt.Printf("You may need to update modLib, modJSON, modParams, tshapeIn, tshapeOut\n")
return
}
fmt.Printf("Module Imported:%p\n", modp)
bytes, err := ioutil.ReadFile(modJSON)
if err != nil {
fmt.Print(err)
return
}
jsonStr := string(bytes)
// Load module on tvm runtime - call tvm.graph_runtime.create
funp, err := gotvm.GetGlobalFunction("tvm.graph_runtime.create")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Calling tvm.graph_runtime.create\n")
// Call function
graphrt, err := funp.Invoke(jsonStr, modp, (int64)(gotvm.KDLCPU), (int64)(0))
if err != nil {
fmt.Print(err)
return
}
graphmod := graphrt.AsModule()
fmt.Printf("Graph runtime Created\n")
// Array allocation attributes
tshapeIn := []int64{1, 224, 224, 3}
tshapeOut := []int64{1, 1001}
// Allocate input Array
inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0))
if err != nil {
fmt.Print(err)
return
}
// Allocate output Array
out, err := gotvm.Empty(tshapeOut)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Input and Output Arrays allocated\n")
// Get module function from graph runtime : load_params
// Read params
bytes, err = ioutil.ReadFile(modParams)
if err != nil {
fmt.Print(err)
}
// Load Params
funp, err = graphmod.GetFunction("load_params")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Func load_params:%p\n", funp)
// Call function
_, err = funp.Invoke(bytes)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Module params loaded\n")
// Set some data in input Array
inSlice := make([]float32, (244 * 244 * 3))
rand.Seed(10)
rand.Shuffle(len(inSlice), func(i, j int) {inSlice[i],
inSlice[j] = rand.Float32(),
rand.Float32() })
inX.CopyFrom(inSlice)
// Set Input
funp, err = graphmod.GetFunction("set_input")
if err != nil {
fmt.Print(err)
return
}
// Call function
_, err = funp.Invoke("input", inX)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Module input is set\n")
// Run
funp, err = graphmod.GetFunction("run")
if err != nil {
fmt.Print(err)
return
}
// Call function
_, err = funp.Invoke()
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Module Executed \n")
// Call runtime function get_output
funp, err = graphmod.GetFunction("get_output")
if err != nil {
fmt.Print(err)
return
}
// Call function
_, err = funp.Invoke(int64(0), out)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Got Module Output \n")
// Print results
outIntf, _ := out.AsSlice()
outSlice := outIntf.([]float32)
fmt.Printf("Result:%v\n", outSlice[:10])
}
"""
Get Started with TVM Go
=======================
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
# Global declarations of environment.
tgt_host="llvm"
tgt="llvm"
######################################################################
# Describe the Computation
# ------------------------
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
######################################################################
# Schedule the Computation
# ------------------------
s = tvm.create_schedule(C.op)
######################################################################
# Compilation
# -----------
fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
######################################################################
# Save Compiled Module
# --------------------
from tvm.contrib import cc
from tvm.contrib import util
fadd.save("deploy.o")
cc.create_shared("deploy.so", ["deploy.o"])
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate go-closure given to a packed function argument.
* \file pack_func_closure_arg.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleFunctionArg receives a Packed Function handle and calls it.
func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
// Call Packed Function
retVal, err = pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64())
return
}
// main
func main() {
// Not passing a function name implicitely
// picks the name from reflection as "main.sampleDunctionArg"
gotvm.RegisterFunction(sampleFunctionArg);
fmt.Printf("Registered: sampleFunctionArg\n")
// Get registered global function.
funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("GetGlobalFunction: main.sampleFunctionArg - Success\n")
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// Call function
result, err := funp.Invoke(funccall, 30, 50)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Invoked sampleFunctionArg with function closure arg : Result:%v\n", result.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate go-closure returned from a callback function.
* \file pack_func_closure_return.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue.
func sampleFunctionCb(args ...*gotvm.Value) (retVal interface{}, err error) {
funccall := func (cargs ...*gotvm.Value) (fret interface{}, ferr error) {
for _, v := range cargs {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := cargs[0].AsInt64()
val2 := cargs[1].AsInt64()
fret = int64(val1+val2)
return
}
retVal = funccall
return
}
// main
func main() {
// Not passing a function name implicitely
// picks the name from reflection as "main.sampleDunctionCb"
gotvm.RegisterFunction(sampleFunctionCb);
fmt.Printf("Registered: sampleFunctionCb\n")
// Get registered global function
funp, err := gotvm.GetGlobalFunction("main.sampleFunctionCb")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("GetGlobalFunction: main.sampleFunctionCb - Success\n")
// Call function
result, err := funp.Invoke()
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Invoked main.sampleFunctionCb via Function handle\n")
pfunc := result.AsFunction()
fmt.Printf("Function Handle received via Packed Function call:%T - %v \n", pfunc, pfunc)
pfuncRet, err := pfunc.Invoke(30, 40)
fmt.Printf("Invoked closure inside sampleFunctionCb result:%v\n", pfuncRet.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate function conversion to packed function.
* \file pack_func_convert.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleCb is a simple golang callback function like C = A + B.
func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// main
func main() {
// Welcome
// Simple convert to a packed function
fhandle, err := gotvm.ConvertFunction(sampleCb)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Converted function\n")
retVal, err := fhandle.Invoke(10, 20)
fmt.Printf("Invoke Completed\n")
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Result:%v\n", retVal.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate converted packed
* function handle passed to another packed function.
* \file pack_func_handle_arg.go
*/
package main
import (
"fmt"
"./gotvm"
)
// sampleCb is a simple golang callback function like C = A + B.
func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// sampleFunctionArg receives a Packed Function handle and calls it.
func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
// Call Packed Function
retVal, err = pfunc.Invoke(args[1], args[2])
return
}
// main
func main() {
// Simple convert to a packed function
fhandle, err := gotvm.ConvertFunction(sampleCb)
if err != nil {
fmt.Print(err)
return
}
gotvm.RegisterFunction(sampleFunctionArg);
fmt.Printf("Registered: sampleFunctionArg\n")
funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg")
if err != nil {
fmt.Print(err)
return
}
retVal, err := funp.Invoke(fhandle, 10, 20)
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("Result:%v\n", retVal.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application to demonstrate function register into TVM global functions.
* \file pack_func_register.go
*/
package main
import (
"fmt"
"./gotvm"
"strings"
)
// sampleCb is a simple golang callback function like C = A + B.
func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) {
for _, v := range args {
fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64())
}
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// main
func main() {
// Register sampleCb with TVM packed function system and call and check Global Function List.
gotvm.RegisterFunction(sampleCb, "sampleCb");
// Query global functions available
funcNames, err := gotvm.FuncListGlobalNames()
if err != nil {
fmt.Print(err)
return
}
found := 0
for ii := range (funcNames) {
if strings.Compare(funcNames[ii], "sampleCb") == 0 {
found = 1
}
}
if found == 0 {
fmt.Printf("Function registerd but, not listed\n")
return
}
// Get "sampleCb" and verify the call.
funp, err := gotvm.GetGlobalFunction("sampleCb")
if err != nil {
fmt.Print(err)
return
}
// Call function
result, err := funp.Invoke((int64)(10), (int64)(20))
if err != nil {
fmt.Print(err)
return
}
fmt.Printf("sampleCb result: %v\n", result.AsInt64())
}
/*!
* Copyright (c) 2018 by Contributors
* \brief Sample golang application deployment over tvm.
* \file simple.go
*/
package main
import (
"fmt"
"runtime"
"./gotvm"
"math/rand"
)
// NNVM compiled model paths.
const (
modLib = "./deploy.so"
)
// main
func main() {
// Welcome
defer runtime.GC()
fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion)
fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion)
// Import tvm module (so)
modp, _ := gotvm.LoadModuleFromFile(modLib)
fmt.Printf("Module Imported\n")
// Allocate Array for inputs and outputs.
// Allocation by explicit type and context.
tshapeIn := []int64{4}
inX, _ := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0))
// Default allocation on CPU
inY, _ := gotvm.Empty(tshapeIn, "float32")
// Default allocation to type "float32" and on CPU
out, _ := gotvm.Empty(tshapeIn)
fmt.Printf("Input and Output Arrays allocated\n")
// Fill Input Data : inX , inY
inXSlice := make([]float32, 4)
inYSlice := make([]float32, 4)
for i := range inXSlice {
inXSlice[i] = rand.Float32()
inYSlice[i] = rand.Float32()
}
// Copy the data on target memory through runtime CopyFrom api.
inX.CopyFrom(inXSlice)
inY.CopyFrom(inYSlice)
fmt.Printf("X: %v\n", inXSlice)
fmt.Printf("Y: %v\n", inYSlice)
// Get function "myadd"
funp, _ := modp.GetFunction("myadd")
// Call function
funp.Invoke(inX, inY, out)
fmt.Printf("Module function myadd executed\n")
// Get the output tensor as an interface holding a slice through runtime CopyTo api.
outSlice, _ := out.AsSlice()
// Print results
fmt.Printf("Result:%v\n", outSlice.([]float32))
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file array_test.go
*/
package gotvm
import (
"testing"
"unsafe"
"math/rand"
)
// Create an array and check size.
func TestArrayCreateSize(t *testing.T) {
_, err := Empty([]int64{4})
if err != nil {
t.Error(err.Error())
return
}
_, err = Empty([]int64{4, 5, 6})
if err != nil {
t.Error(err.Error())
return
}
_, err = Empty([]int64{})
if err == nil {
t.Error("Expected err for empty Array created, but didn't got !!")
return
}
}
// Check array creation via various different arguments.
func TestArrayCreateArgs(t *testing.T) {
_, err := Empty([]int64{4, 2}, "float32", CPU(0))
if err != nil {
t.Error(err.Error())
return
}
_, err = Empty([]int64{4, 2}, "float32")
if err != nil {
t.Error(err.Error())
return
}
_, err = Empty([]int64{4, 2}, CPU(0))
if err != nil {
t.Error(err.Error())
return
}
_, err = Empty([]int64{4, 2}, CPU(0), "float32")
if err != nil {
t.Error(err.Error())
return
}
}
// Create an array and check the NDim.
func TestArrayNDim(t *testing.T) {
arr, err := Empty([]int64{4, 5, 6})
if err != nil {
t.Error(err.Error())
return
}
if 3 != arr.GetNdim() {
t.Errorf("GetNdim failed Expected: 3 Got :%v\n", arr.GetNdim())
return
}
}
// Create an array and check Shape.
func TestArrayShape(t *testing.T) {
arr, err := Empty([]int64{4, 5, 6})
if err != nil {
t.Error(err.Error())
return
}
shape := arr.GetShape()
if len(shape) != 3 {
t.Errorf("Shape slice expected: 3 Got :%v\n", len(shape))
return
}
if shape[0] != 4 || shape[1] != 5 || shape[2] != 6 {
t.Errorf("Shape values expected {4, 5, 6} Got : %v\n", shape);
return
}
}
// Create an array and check created Context.
func TestArrayCtx(t *testing.T) {
// TODO: Could some test cases for other targets
arr, err := Empty([]int64{4}, CPU(0))
if err != nil {
t.Error(err.Error())
return
}
ctx := arr.GetCtx()
if ctx.DeviceType != KDLCPU {
t.Errorf("Ctx DeviceType expected: %v Got :%v\n", KDLCPU, ctx.DeviceType)
return
}
if ctx.DeviceID != 0 {
t.Errorf("Ctx DeviceID expected: %v Got :%v\n", KDLCPU, ctx.DeviceID)
return
}
arr, err = Empty([]int64{4}, CPU(2))
if err != nil {
t.Error(err.Error())
return
}
ctx = arr.GetCtx()
if ctx.DeviceType != KDLCPU {
t.Errorf("Ctx DeviceType expected: %v Got :%v\n", KDLCPU, ctx.DeviceType)
return
}
if ctx.DeviceID != 2 {
t.Errorf("Ctx DeviceID expected: %v Got :%v\n", KDLCPU, ctx.DeviceID)
return
}
}
// Create array of different dtypes and check dtypes.
func TestArrayDType(t *testing.T) {
for _, dtype := range []string{"int8", "int16", "int32", "int64",
"uint8", "uint16", "uint32", "uint64",
"float32", "float64"} {
arr, err := Empty([]int64{4}, dtype)
if err != nil {
t.Error(err.Error())
return
}
if dtype != arr.GetDType() {
t.Errorf("Dtype expected: %v Got :%v\n", dtype, arr.GetDType())
return
}
}
}
// Copy Int8 data to created Array and verify.
func TestArrayCopySliceInt8(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "int8")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen)
rand.Read(bdata)
data := (*[1<<31]int8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []int8:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]int8)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy Int16 data to created Array and verify.
func TestArrayCopySliceInt16(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "int16")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen*2)
rand.Read(bdata)
data := (*[1<<31]int16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []int16:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]int16)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy Int32 data to created Array and verify.
func TestArrayCopySliceInt32(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "int32")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen*4)
rand.Read(bdata)
data := (*[1<<31]int32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []int32:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]int32)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy Int64 data to created Array and verify.
func TestArrayCopySliceInt64(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "int64")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen*8)
rand.Read(bdata)
data := (*[1<<31]int64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []int64:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]int64)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy UInt8 data to created Array and verify.
func TestArrayCopySliceUInt8(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "uint8")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen)
rand.Read(bdata)
data := (*[1<<31]uint8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []uint8:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]uint8)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy UInt16 data to created Array and verify.
func TestArrayCopySliceUInt16(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "uint16")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen*2)
rand.Read(bdata)
data := (*[1<<31]uint16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []uint16:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]uint16)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy UInt32 data to created Array and verify.
func TestArrayCopySliceUInt32(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "uint32")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen*4)
rand.Read(bdata)
data := (*[1<<31]uint32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []uint32:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]uint32)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy UInt64 data to created Array and verify.
func TestArrayCopySliceUInt64(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "uint64")
if err != nil {
t.Error(err.Error())
return
}
bdata := make([]byte, dlen*8)
rand.Read(bdata)
data := (*[1<<31]uint64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen]
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []uint64:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]uint64)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
// Copy Float32 data to created Array and verify.
func TestArrayCopySliceFloat32(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "float32")
if err != nil {
t.Error(err.Error())
return
}
data := make([]float32, dlen)
for i := range data {
data[i] = rand.Float32()
}
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []float32:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]float32)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v \nGot :%v \n", data, dataRet)
return
}
}
}
// Copy Float64 data to created Array and verify.
func TestArrayCopySliceFloat64(t *testing.T) {
dlen := int64(32)
arr, err := Empty([]int64{4, dlen/4}, "float64")
if err != nil {
t.Error(err.Error())
return
}
data := make([]float64, dlen)
for i := range data {
data[i] = rand.Float64()
}
err = arr.CopyFrom(data)
if err != nil {
t.Error(err.Error())
return
}
ret, err := arr.AsSlice()
if err != nil {
t.Error(err.Error())
return
}
switch ret.(type) {
case []float64:
default:
t.Errorf("Expected : %T but got :%T\n", data, ret)
return
}
dataRet := ret.([]float64)
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v\n", data, dataRet)
return
}
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMByteArray interface.
* \file bytearray.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
)
// ByteArray type wraps the TVMByteArray of C runtime API.
//
// This can be used to hold raw data like params of a model.
type ByteArray uintptr
// nativeCPtr returns the type freed unitptr for ByteArray.
func (tbytearray ByteArray) nativeCPtr() (retVal uintptr) {
retVal = (uintptr)(tbytearray)
return
}
// SetData is used to intialize ByteArray from a golang string object.
//
// This method initialize both data and data size of the underlaying object.
// This function handles freeing old data object if any before allocating new.
//
// `val` is the golang string object from which the ByteArray is initialized.
func (tbytearray ByteArray) setData(val string) {
bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data
if bufPtr == (*_Ctype_char)(C.NULL) {
C.free(unsafe.Pointer(bufPtr))
}
((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data = C.CString(val)
((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size = C.ulong(len(val))
}
// getData returns the golang byte slice corresponding to the ByteArray.
func (tbytearray ByteArray) getData() (retVal []byte) {
val := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data
blen := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size
retVal = C.GoBytes(unsafe.Pointer(val), C.int(blen))
return
}
// newByteArray initilizes the native TVMByteArray object with given byte slice
//
//`val` is the golang byte array used to initialize.
//
// returns newly created ByteArray.
func newByteArray(val []byte) (retVal ByteArray) {
handle := ByteArray(C.malloc(C.sizeof_TVMByteArray))
((*C.TVMByteArray)(unsafe.Pointer(handle))).data = (*_Ctype_char)(C.NULL)
((*C.TVMByteArray)(unsafe.Pointer(handle))).size = 0
handle.setData(string(val))
retVal = handle
return
}
// deleteTVMByteArray releases the allocated native object of ByteArray.
//
// This delete handles freeing of underlaying native data object too.
func (tbytearray ByteArray) deleteTVMByteArray() {
bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data
C.free(unsafe.Pointer(bufPtr))
C.free(unsafe.Pointer(tbytearray.nativeCPtr()))
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file bytearray_test.go
*/
package gotvm
import (
"testing"
"math/rand"
)
// Check ByteArray creation from byte slice and verify the data.
func TestByteArrayGet(t *testing.T) {
data := make([]byte, 1024)
rand.Read(data)
barr := newByteArray(data)
dataRet := barr.getData()
if len(data) != len(dataRet) {
t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet))
return
}
for i := range data {
if data[i] != dataRet[i] {
t.Errorf("Data expected: %v Got :%v at : %v\n", data[i], dataRet[i], i)
return
}
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMContext interface
* \file context.go
*/
package gotvm
//#include "gotvm.h"
import "C"
// KDLCPU is golang enum correspond to TVM device type kDLCPU.
var KDLCPU = int32(C.kDLCPU)
// KDLGPU is golang enum correspond to TVM device type kDLGPU.
var KDLGPU = int32(C.kDLGPU)
// KDLCPUPinned is golang enum correspond to TVM device type kDLCPUPinned.
var KDLCPUPinned = int32(C.kDLCPUPinned)
// KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL.
var KDLOpenCL = int32(C.kDLOpenCL)
// KDLMetal is golang enum correspond to TVM device type kDLMetal.
var KDLMetal = int32(C.kDLMetal)
// KDLVPI is golang enum correspond to TVM device type kDLVPI.
var KDLVPI = int32(C.kDLVPI)
// KDLROCM is golang enum correspond to TVM device type kDLROCM.
var KDLROCM = int32(C.kDLROCM)
// KDLSDAccel is golang enum correspond to TVM device type kDLSDAccel.
var KDLSDAccel = int32(C.kDLSDAccel)
// KDLVulkan is golang enum correspond to TVM device type kDLVulkan.
var KDLVulkan = int32(C.kDLVulkan)
// KOpenGL is golang enum correspond to TVM device type kOpenGL.
var KOpenGL = int32(C.kOpenGL)
// KExtDev is golang enum correspond to TVM device type kDLExtDev.
var KExtDev = int32(C.kDLExtDev)
// Context dtype corresponding to TVMContext aka DLContext
type Context struct {
DeviceType int32
DeviceID int32
}
// CPU returns the Context object for CPU target on given index
func CPU(index int32) Context {
return Context{KDLCPU, index}
}
// GPU returns the Context object for GPU target on given index
func GPU(index int32) Context {
return Context{KDLGPU, index}
}
// CPUPinned returns the Context object for CPUPinned target on given index
func CPUPinned(index int32) Context {
return Context{KDLCPUPinned, index}
}
// OpenCL returns the Context object for OpenCL target on given index
func OpenCL(index int32) Context {
return Context{KDLOpenCL, index}
}
// Metal returns the Context object for Metal target on given index
func Metal(index int32) Context {
return Context{KDLMetal, index}
}
// VPI returns the Context object for VPI target on given index
func VPI(index int32) Context {
return Context{KDLVPI, index}
}
// ROCM returns the Context object for ROCM target on given index
func ROCM(index int32) Context {
return Context{KDLROCM, index}
}
// SDAccel returns the Context object for SDAccel target on given index
func SDAccel(index int32) Context {
return Context{KDLSDAccel, index}
}
// Vulkan returns the Context object for Vulkan target on given index
func Vulkan(index int32) Context {
return Context{KDLVulkan, index}
}
// OpenGL returns the Context object for OpenGL target on given index
func OpenGL(index int32) Context {
return Context{KOpenGL, index}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for error related API interface.
* \file error.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
)
// getTVMLastError returns the detailed error string for any api called in TVM runtime.
//
// This is useful when any api returns non zero value.
//
// Returns golang string for the corresponding native error message.
func getTVMLastError() (retVal string) {
errStr := C.TVMGetLastError()
retVal = C.GoString(errStr)
return
}
func setTVMLastError(errStr string) {
cstr := C.CString(errStr)
C.TVMAPISetLastError(cstr)
C.free(unsafe.Pointer(cstr))
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file error_test.go
*/
package gotvm
import (
"testing"
"strings"
)
// Check err receiving from TVM global function.
func TestErrorTest(t *testing.T) {
_, err := LoadModuleFromFile("dummy.so")
if err == nil {
t.Error("Expected an error, but not received\n")
return
}
errStr := err.Error()
if !(strings.Contains(errStr, string("cannot open shared object"))) {
t.Error("Ah! TVM didn't report an error\n")
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMFunction interface.
* \file function.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
"encoding/binary"
"errors"
"runtime"
"reflect"
"fmt"
)
// Function type in golang hold pointer for the TVMFunction handle.
type Function uintptr
// nativeCPtr returns type freed uintptr for the Function.
func (tvmfunction Function) nativeCPtr() (retVal uintptr) {
retVal = (uintptr)(tvmfunction)
return
}
// Invoke calls the TVM packed function referred by the handle with given arguments.
func (tvmfunction *Function) Invoke(args ...interface{}) (retVal *Value, err error) {
funccall := func (fargs ...interface{}) (*Value, error) {
return callNativeFunction(tvmfunction, fargs)
}
// Check is any args are contain any ValueArray
// Possible is it's a args forward from one packed function to another.
valueArrayFound := false
for ii := range args {
switch args[ii].(type) {
case []*Value:
valueArrayFound = true
}
}
if !valueArrayFound {
return funccall(args...)
}
if len(args) != 1 {
err = fmt.Errorf("Not supported if packed function args are a mix of []Value and other types")
return
}
valArray := args[0].([]*Value)
if len(valArray) > 0 {
newArgs := make([]interface{}, len(valArray))
for ii := range valArray {
newVal := newTVMValue()
newVal.moveFrom(valArray[ii])
newArgs[ii] = newVal
}
return funccall(newArgs...)
}
return funccall()
}
// FuncListGlobalNames is used to query global callable packed function names from TVM.
//
// returns slice of string holding function names and error if any.
func FuncListGlobalNames() (retVal []string, err error) {
var str string
ret := (int32)(C._TVMFuncListGlobalNames(unsafe.Pointer((&str))))
if ret != 0 {
err = errors.New(getTVMLastError())
return
}
str = goStringFromNative(*(*string)(unsafe.Pointer(&str)))
bin := binary.LittleEndian
size := bin.Uint64([]byte(str[:8]))
str = str[8:]
retVal = make([]string, size)
for i := range retVal {
len := bin.Uint64([]byte(str[:8]))
str = str[8:]
retVal[i] = str[:len]
str = str[len:]
}
return
}
// GetGlobalFunction is to get handle to the given global function name.
//
// `funcname` is the name of global packed function.
//
// returns a function closure with signature
// func (args ...interface{}) (interface{}, error) and error if any.
//
// The closure function can be used to call Function with arguments directly.
//
// Variadic arguments can be any type which can be embed into Value.
func GetGlobalFunction(funcname string) (retVal *Function, err error) {
var funp uintptr
cfuncname := C.CString(funcname)
ret := (int32)(C.TVMFuncGetGlobal(cfuncname,
(*_Ctype_TVMFunctionHandle)(unsafe.Pointer(&funp))))
C.free(unsafe.Pointer(cfuncname))
if ret != 0 {
err = errors.New(getTVMLastError())
return
}
handle := new(Function)
*handle = Function(funp)
finalizer := func(fhandle *Function) {
nativeTVMFuncFree(fhandle)
fhandle = nil
}
runtime.SetFinalizer(handle, finalizer)
retVal = handle
return
}
// callNativeFunction is routine which calls gotvm native wrapper with given arguments.
//
// `handle` is the handle for Function.
//
// `args` are the variadic arguments to the Function.
//
// returns the interface for the return value from TVM if any and error if any.
func callNativeFunction(handle *Function, args []interface{}) (retVal *Value, err error) {
argsIn := make([]*Value, len(args))
var typeCodes []int32
if len(args) != 0 {
typeCodes = make([]int32, len(args))
} else {
typeCodes = make([]int32, 1)
}
for ii := range args {
argsIn[ii] = newTVMValue()
if typeCodes[ii], err = argsIn[ii].setValue(args[ii]); err != nil {
return
}
}
retVal = newTVMValue()
argsOut := []*Value{retVal}
retTypeCode := KNull
err = nativeTVMFuncCall(handle, argsIn, typeCodes, argsOut, &retTypeCode)
if err != nil {
retVal = nil
return
}
retVal.isLocal = false
retVal.dtype = retTypeCode
return
}
// nativeTVMFuncFree free the function handle allocated in TVM runtime.
//
// `funp` is the Function handle to be freed.
func nativeTVMFuncFree(funp *Function) (retVal int32) {
retVal = (int32) (C.TVMFuncFree(C.TVMFunctionHandle(funp.nativeCPtr())))
return
}
// nativeToGoSlice converts native TVMValue array to Golang slice of TVMValue
//
//
func nativeToGoSlice(nargValues (*C.void), argValues []*Value, typeCodes []int32) {
for ii := range argValues {
C._TVMValueNativeGet(unsafe.Pointer(argValues[ii].nativeCPtr()),
unsafe.Pointer(nargValues),
C.int(int32(ii)))
argValues[ii].dtype = typeCodes[ii]
}
}
// nativeFromGoSlice converts golang slice of TVMValue to native TVMValue array.
//
//
func nativeFromGoSlice(argValues []*Value) (nptr (*C.void)) {
nargValues := ((uintptr)(C.malloc(C.ulong(C.sizeof_TVMValue * len(argValues)))))
for ii := range argValues {
C._TVMValueNativeSet(unsafe.Pointer(nargValues),
unsafe.Pointer(argValues[ii].nativeCPtr()),
C.int(int32(ii)))
}
nptr = (*C.void)(unsafe.Pointer(nargValues))
return
}
// nativeTVMFuncCall executes the function with given arguments
//
// `funp` Function handle to the packed function.
//
// `argValues` is the slice of Value which are arguments to the packed function.
//
// `typeCodes` is the alice of argument type codes corresponding to argValues.
//
// `retValues` is return argument which is slice of return values from the packed function.
//
// `retTypeCode` is int32 holding type codes for retValue
//
// Returns err indicating native error if any.
func nativeTVMFuncCall(funp *Function, argValues []*Value, typeCodes []int32,
retValues []*Value, retTypeCode *int32) (err error) {
nargValues := nativeFromGoSlice(argValues)
nretValues := nativeFromGoSlice(retValues)
result := (int32)(C.TVMFuncCall(_Ctype_TVMFunctionHandle(*funp),
(*_Ctype_TVMValue)(unsafe.Pointer(nargValues)),
(*_Ctype_int)(unsafe.Pointer(&(typeCodes[0]))),
C.int(len(argValues)),
(*_Ctype_TVMValue)(unsafe.Pointer(nretValues)),
(*_Ctype_int)(unsafe.Pointer(retTypeCode))))
nativeToGoSlice(nargValues, argValues, typeCodes)
nativeToGoSlice(nretValues, retValues, (*[1<<31] int32)(unsafe.Pointer(retTypeCode))[:1:1])
C.free(unsafe.Pointer(nargValues))
C.free(unsafe.Pointer(nretValues))
if result != 0 {
err = errors.New(getTVMLastError())
}
return
}
// goCallBack is a structure holding the go callback function pointer.
// This wrapping is necessary as cgo doesn't support
// passing golang functions type conversion to native.
type goCallBack struct {
cb func (args ...*Value) (interface{}, error)
}
//export goTVMCallback
func goTVMCallback(args C.native_voidp, typeCodes C.native_voidp, numArgs int32,
retArg C.native_voidp, resourceHandle C.native_voidp) (ret int32){
fcb := (*goCallBack)(resourceHandle)
// Make Value Sice from native TVMValue pointer.
argValues := make([]*Value, numArgs)
for ii := range argValues {
argValues[ii] = newTVMValue()
argValues[ii].isLocal = false
}
// Prepare arguments for golang callback function
nativeToGoSlice((*C.void)(unsafe.Pointer(args)), argValues,
(*[1<<31] int32)(unsafe.Pointer(typeCodes))[:numArgs:numArgs])
cbargs := argValues
// Execute the callback
retVal, err := fcb.cb(cbargs...)
if err != nil {
errStr := err.Error()
setTVMLastError(errStr)
return -1
}
// It's possible a packed function directly return
// the return value of another packed function.
//
// Inside a packed func :
// ```return pfunc.Invoke(args)```
//
// In this case pfunc returns nil which is
// returned as an interface holding nil *Value.
// Which becomes a valid retVal holding nil *Value.
isRetNull := false
switch retVal.(type) {
case *Value:
pRet := retVal.(*Value)
if pRet == nil {
isRetNull = true
}
}
// Handle return value from callback function
if retVal != nil && !isRetNull {
var retTypeCode int32
retValues := []*Value{newTVMValue()}
retTypeCode, err = retValues[0].setValue(retVal)
if err != nil {
errStr := err.Error()
setTVMLastError(errStr)
return -1
}
nretValues := nativeFromGoSlice(retValues)
// Handle KStr, KBytes: Local finalizers shouldn't try freeing them.
retValues[0].isLocal = false
apiRet := (int32) (C.TVMCFuncSetReturn(_Ctype_TVMRetValueHandle(retArg),
(*_Ctype_TVMValue)(unsafe.Pointer(nretValues)),
(*_Ctype_int)(unsafe.Pointer(&retTypeCode)), 1))
C.free(unsafe.Pointer(nretValues))
if apiRet != 0 {
errStr := string("TVMCFuncSetReturn failed ")
setTVMLastError(errStr)
}
}
return
}
// ConvertFunction converts given golang function to TVM packed function.
//
// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})```
//
// Returns Function handle and err if any.
func ConvertFunction(args ...interface{}) (retVal *Function, err error) {
function := args[0].(func (args ...*Value) (interface{}, error))
fcb := &goCallBack{cb:function}
var funp uintptr
result := (int32) (C._ConvertFunction(unsafe.Pointer(fcb),
unsafe.Pointer(&funp)))
if result != 0 {
err = errors.New(getTVMLastError())
}
handle := new(Function)
*handle = Function(funp)
finalizer := func(fhandle *Function) {
nativeTVMFuncFree(fhandle)
fhandle = nil
}
runtime.SetFinalizer(handle, finalizer)
retVal = handle
return
}
// RegisterFunction registers the golang func in TVM runtime global space.
//
// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})```
//
// `args[1]` Optional argument of function name with which it will be registered.
// If not passed we use function name from reflection.
//
// Returns err indicating native error if any.
func RegisterFunction(args ...interface{}) (err error) {
fhandle, err := ConvertFunction(args...)
if err != nil {
return
}
funcname := runtime.FuncForPC(reflect.ValueOf(args[0]).Pointer()).Name()
if len(args) > 1 {
funcname = args[1].(string)
}
cfuncname := C.CString(funcname)
result := (int32) (C.TVMFuncRegisterGlobal(cfuncname,
_Ctype_TVMFunctionHandle(*fhandle),
0)); // Override = False
C.free(unsafe.Pointer(cfuncname))
if result != 0 {
err = errors.New(getTVMLastError())
}
// Clear the finalizer as we don't need to control it anymore.
runtime.SetFinalizer(fhandle, nil)
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file function_test.go
*/
package gotvm
import (
"testing"
"reflect"
"math/rand"
"strings"
"fmt"
)
// Check global function list API
func TestFunctionGlobals(t *testing.T) {
funcNames, err := FuncListGlobalNames()
if err != nil {
t.Error(err.Error())
return
}
if len(funcNames) < 1 {
t.Errorf("Global Function names received:%v\n", funcNames)
}
}
// Check GetFunction API
func TestFunctionGlobalGet(t *testing.T) {
funp, err := GetGlobalFunction("tvm.graph_runtime.create")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(funp).Kind() != reflect.Ptr {
t.Error("Function type mis matched\n")
return
}
}
func TestFunctionModuleGet(t *testing.T) {
modp, err := LoadModuleFromFile("./deploy.so")
if err != nil {
t.Error(err.Error())
return
}
funp, err := modp.GetFunction("myadd")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(funp).Kind() != reflect.Ptr {
t.Error("Function type mis matched\n")
return
}
dlen := int64(1024)
shape := []int64{dlen}
inX, _ := Empty(shape)
inY, _ := Empty(shape)
out, _ := Empty(shape)
dataX := make([]float32, (dlen))
dataY := make([]float32, (dlen))
outExpected := make([]float32, (dlen))
for i := range dataX {
dataX[i] = rand.Float32()
dataY[i] = rand.Float32()
outExpected[i] = dataX[i] + dataY[i]
}
inX.CopyFrom(dataX)
inY.CopyFrom(dataY)
funp.Invoke(inX, inY, out)
outi, _ := out.AsSlice()
outSlice := outi.([]float32)
if len(outSlice) != len(outExpected) {
t.Errorf("Data expected Len: %v Got :%v\n", len(outExpected), len(outSlice))
return
}
for i := range outSlice {
if outExpected[i] != outSlice[i] {
t.Errorf("Data expected: %v Got :%v at index %v\n", outExpected[i], outSlice[i], i)
return
}
}
}
// Check FunctionConvert API
func TestFunctionConvert(t *testing.T) {
sampleCb := func (args ...*Value) (retVal interface{}, err error) {
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
fhandle, err := ConvertFunction(sampleCb)
if err != nil {
t.Error(err.Error())
return
}
retVal, err := fhandle.Invoke(10, 20)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsInt64() != int64(30) {
t.Errorf("Expected result :30 got:%v\n", retVal.AsInt64())
return
}
}
func TestFunctionError(t *testing.T) {
sampleCb := func (args ...*Value) (retVal interface{}, err error) {
err = fmt.Errorf("Sample Error XYZABC");
return
}
fhandle, err := ConvertFunction(sampleCb)
if err != nil {
t.Error(err.Error())
return
}
_, err = fhandle.Invoke()
if err == nil {
t.Error("Expected error but didn't received\n")
return
}
if !strings.Contains(err.Error(), string("Sample Error XYZABC")) {
t.Errorf("Expected Error should contain :\"Sample Error XYZABC\" got :%v\n", err.Error())
}
}
// Check FunctionRegister
func TestFunctionRegister(t *testing.T) {
sampleCb := func (args ...*Value) (retVal interface{}, err error) {
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
RegisterFunction(sampleCb, "TestFunctionRegister.sampleCb");
// Query global functions available
funcNames, err := FuncListGlobalNames()
if err != nil {
t.Error(err.Error())
return
}
found := 0
for ii := range (funcNames) {
if strings.Compare(funcNames[ii], "TestFunctionRegister.sampleCb") == 0 {
found = 1
}
}
if found == 0 {
t.Error("Registered function not found in global function list.")
return
}
// Get "sampleCb" and verify the call.
funp, err := GetGlobalFunction("TestFunctionRegister.sampleCb")
if err != nil {
t.Error(err.Error())
return
}
// Call function
result, err := funp.Invoke((int64)(10), (int64)(20))
if err != nil {
t.Error(err.Error())
return
}
if result.AsInt64() != int64(30) {
t.Errorf("Expected result :30 got:%v\n", result.AsInt64())
return
}
}
// Check packed function receiving go-closure as argument.
func TestFunctionClosureArg(t *testing.T) {
// sampleFunctionArg receives a Packed Function handle and calls it.
sampleFunctionArg := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
// Call Packed Function by Value
ret, err := pfunc.Invoke(args[1], args[2])
if err != nil {
return
}
// Call Packed Function with extracted values
ret1, err := pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64())
if err != nil {
return
}
if ret1.AsInt64() != ret.AsInt64() {
err = fmt.Errorf("Invoke with int64 didn't match with Value\n")
return
}
retVal = ret
return
}
RegisterFunction(sampleFunctionArg, "TestFunctionClosureArg.sampleFunctionArg");
funp, err := GetGlobalFunction("TestFunctionClosureArg.sampleFunctionArg")
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
val1 := args[0].AsInt64()
val2 := args[1].AsInt64()
retVal = int64(val1+val2)
return
}
// Call function
result, err := funp.Invoke(funccall, 30, 50)
if err != nil {
t.Error(err.Error())
return
}
if result.AsInt64() != int64(80) {
t.Errorf("Expected result :80 got:%v\n", result.AsInt64())
return
}
}
// Check packed function returning a go-closure.
func TestFunctionClosureReturn(t *testing.T) {
// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue.
sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) {
funccall := func (cargs ...*Value) (fret interface{}, ferr error) {
val1 := cargs[0].AsInt64()
val2 := cargs[1].AsInt64()
fret = int64(val1+val2)
return
}
retVal = funccall
return
}
RegisterFunction(sampleFunctionCb, "TestFunctionClosureReturn.sampleFunctionCb");
funp, err := GetGlobalFunction("TestFunctionClosureReturn.sampleFunctionCb")
if err != nil {
t.Error(err.Error())
return
}
// Call function
result, err := funp.Invoke()
if err != nil {
t.Error(err.Error())
return
}
pfunc := result.AsFunction()
pfuncRet, err := pfunc.Invoke(30, 40)
if err != nil {
t.Error(err.Error())
return
}
if pfuncRet.AsInt64() != int64(70) {
t.Errorf("Expected result :70 got:%v\n", pfuncRet.AsInt64())
return
}
}
// Check packed function with no arguments and no return values.
func TestFunctionNoArgsReturns(t *testing.T) {
sampleFunction := func (args ...*Value) (retVal interface{}, err error) {
return
}
fhandle, err := ConvertFunction(sampleFunction)
if err != nil {
t.Error(err.Error())
return
}
_, err = fhandle.Invoke()
if err != nil {
t.Error(err.Error())
return
}
}
// Check packed function returning a go-closure with no arg and returns.
func TestFunctionNoArgsReturns2(t *testing.T) {
// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue.
sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) {
funccall := func (cargs ...*Value) (fret interface{}, ferr error) {
return
}
retVal = funccall
return
}
funp, err := ConvertFunction(sampleFunctionCb)
if err != nil {
t.Error(err.Error())
return
}
// Call function
result, err := funp.Invoke()
if err != nil {
t.Error(err.Error())
return
}
pfunc := result.AsFunction()
_, err = pfunc.Invoke()
if err != nil {
t.Error(err.Error())
return
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm native interface definition
* \file gotvm.cxx
*/
// Standard includes
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <stdint.h>
// golang string compatible definition
typedef struct { char *p; int n; } _gostring_;
#include <string>
#ifdef __cplusplus
extern "C" {
#endif
// TVM runtime C interface
#include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h>
/*!
* \brief Convert native char array to _gostring_ structure.
* _gostring_ structure represents the same memory footprint as golang string object.
*
* \param p is char pointer to a char array.
* \param l is the size of the char array. this method exclusively need length as
* its possible to have a bytearray in a string.
*
* \return _gostring_ object corresponding to native char array.
* Caller is responsible to free the memory block allocated here.
*/
static _gostring_ _native_to_gostring(const char *p, size_t l) {
_gostring_ ret;
ret.p = reinterpret_cast<char*>(malloc(l));
if (NULL == ret.p) {
ret.n = 0;
return ret;
}
memcpy(ret.p, p, l);
ret.n = l;
return ret;
}
/*!
* \brief embeds a 64bit uint value inside a string to serialize the data.
*
* \param s is string object.
* \param off is the offset in the string object.
* \param v is the uint64_t value which need to embed into given string.
*/
static void putuint64(std::string *s, size_t off, uint64_t v) {
for (int i = 0; i < 8; i++) {
(*s)[off + i] = (v >> (i * 8)) & 0xff;
}
}
// TVM runtime C interface wrappers
/*!
* \brief Native interface to query TVM_VERSION in golang string format.
*
* \return char pointer to TVM-VERSION
*/
const char* _TVM_VERSION(void) {
const char *version = TVM_VERSION;
return version;
}
/*!
* \brief Native interface for getting TVMGlobal function list.
*
* \param names return by argument to return the function names.
* We wrap all strings into single string joined by (len+string)
* which is unpacked and processed in golang.
*
* \return c_runtime_api return status.
*/
int _TVMFuncListGlobalNames(_gostring_* names) {
int names_size;
char **names_array;
int result;
result = TVMFuncListGlobalNames(&names_size, (char const ***)&names_array);
if (result) {
return result;
}
size_t tot = 8;
for (int ii = 0; ii < names_size ; ++ii) {
tot += 8 + strlen(names_array[ii]);
}
std::string str;
str.resize(tot);
putuint64(&str, 0, names_size);
size_t off = 8;
for (int64_t ii = 0; ii < names_size ; ++ii) {
putuint64(&str, off, strlen(names_array[ii]));
off += 8;
str.replace(off, strlen(names_array[ii]), names_array[ii]);
off += strlen(names_array[ii]);
}
*names = _native_to_gostring(str.data(), str.size());
if (str.size() != names->n) {
TVMAPISetLastError("malloc failed during _native_to_gostring");
result = 1;
}
return result;
}
// Helpers for TVMValue
/*!
* \brief Native helper to copy TVMValue from golang slice to native array.
* this helper is need as underlying momory for golang slice is not continueous.
*
* \param to_ptr is the native pointer of TVMValue array.
* \param from_ptr pointer to TVMValue in golang slice.
* \param array index in native array.
*/
void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) {
TVMValue *from_p = reinterpret_cast<TVMValue*>(from_ptr);
TVMValue *to_p = reinterpret_cast<TVMValue*>(to_ptr);
memcpy(to_p+ind, from_p, sizeof(TVMValue));
}
/*!
* \brief Native helper to copy TVMValue from golang slice to native array.
* this helper is need as underlying momory for golang slice is not continueous.
*
* \param to_ptr pointer to TVMValue in golang slice.
* \param from_ptr is the native pointer of TVMValue array.
* \param array index in native array.
*/
void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) {
TVMValue *from_p = reinterpret_cast<TVMValue*>(from_ptr);
TVMValue *to_p = reinterpret_cast<TVMValue*>(to_ptr);
memcpy(to_p, from_p+ind, sizeof(TVMValue));
}
extern int goTVMCallback(void*, void*, int, void*, void*);
/*!
* \brief _TVMCallback is the TVM runtime callback function for PackedFunction system.
*
* \param args is an array of TVMValue
* \param type_codes is an array of int
* \param num_args is int representing number of in arguments
* \param ret is the return value handle to set the packed function return.
* \param resource_handle is the golang private data pointer.
*
* \returns the error status as TVM_DLL
*/
int _TVMCallback(TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle) {
return goTVMCallback(args, type_codes, num_args, ret, resource_handle);
}
/*!
* _TVMPackedCFuncFinalizer is finalizer for packed function system.
*
*/
void _TVMPackedCFuncFinalizer(void* resource_handle) {
return;
}
/*!
* /brief _ConvertFunction creates a packed function for with given resource handle.
*
* /param fptr is the pointer to golang resource handle.
* /param *fhandle is the return argument holding packed function.
*
* /return is an int indicating the return status.
*/
int _ConvertFunction(void* fptr, TVMFunctionHandle *fhandle) {
int ret = TVMFuncCreateFromCFunc(_TVMCallback,
fptr,
_TVMPackedCFuncFinalizer,
fhandle);
return ret;
}
#ifdef __cplusplus
}
#endif
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file gotvm.go
*/
// Package gotvm is TVM runtime interface definition for golang.
//
// Application need to import this package to access the c_runtime_api exposed by TVM.
package gotvm
//#include "gotvm.h"
import "C"
// DLPackVersion is the dlpack version of tvm runtime.
var DLPackVersion = int(C.DLPACK_VERSION)
// TVMVersion is the TVM runtime version.
var TVMVersion = getTVMVersion()
func getTVMVersion() (retStr string) {
retStr = C.GoString(C._TVM_VERSION())
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm native interface declaration.
* \file gotvm.h
*
* These declarations are in cgo interface definition while calling API
* across golang and native C boundaries.
*/
#ifndef GOTVM_GOTVM_H_
#define GOTVM_GOTVM_H_
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
#include <dlpack/dlpack.h>
// Some type definitions for golang "C"
typedef void* native_voidp;
// Version
extern char* _TVM_VERSION(void);
// Wrappers : For incompatible cgo API.
// To handle array of strings wrapped into __gostring__
extern int _TVMFuncListGlobalNames(void*);
// To handle TVMValue slice to/from native sequential TVMValue array.
extern void _TVMValueNativeSet(void* to, void* from, int index);
extern void _TVMValueNativeGet(void* to, void* from, int index);
// Callbacks
extern int _ConvertFunction(void* fptr, void* funp);
#ifdef __cplusplus
}
#endif
#endif // GOTVM_GOTVM_H_
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file gotvm_test.go
*/
package gotvm
import (
"testing"
"reflect"
)
// Check TVMVersion API
func TestTVMVersion(t *testing.T) {
if len(TVMVersion) == 0 {
t.Error("TVMVersion not set\n")
}
if reflect.TypeOf(TVMVersion).Kind() != reflect.String {
t.Error("TVMVersion type mismatch\n")
}
}
// Check DLPackVersion API
func TestDLPackVersion(t *testing.T) {
if reflect.TypeOf(DLPackVersion).Kind() != reflect.Int {
t.Error("TVMVersion type mismatch\n")
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMModule interface.
* \file module.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"errors"
"runtime"
"unsafe"
)
// Module type in golang hold pointer for the TVMModule handle.
//
// Module initialization happen through TVMModLoadFromFile api in TVM runtime.
type Module uintptr
// nativeCPtr returns type freed uintptr for the Module.
func (tvmmodule *Module) nativeCPtr() (retVal uintptr) {
retVal = (uintptr)(*tvmmodule)
return
}
// LoadModuleFromFile loads the given module in TVM runtime.
//
// `modpath` is the path to tvm module.
//
// `args` is an optional arguments of ["dll", "dylib", "dso", "so"] with default value "so"
//
// returns pointer to Module and err or if any.
func LoadModuleFromFile(modpath string, args ...interface{}) (retVal *Module, err error) {
modtype := "so"
if len(args) > 0 {
modtype = args[0].(string)
}
var modp uintptr
cmodpath := C.CString(modpath)
cmodtype := C.CString(modtype)
ret := (int32)(C.TVMModLoadFromFile(cmodpath,
cmodtype,
(*_Ctype_TVMModuleHandle)(unsafe.Pointer(&modp))))
C.free(unsafe.Pointer(cmodpath))
C.free(unsafe.Pointer(cmodtype))
if ret != 0 {
err = errors.New(getTVMLastError())
return
}
handle := new(Module)
*handle = Module(modp)
finalizer := func(mhandle *Module) {
nativeTVMModFree(mhandle)
mhandle = nil
}
runtime.SetFinalizer(handle, finalizer)
retVal = handle
return
}
// nativeTVMModFree free the module handle allocated in TVM runtime.
//
// `modp` is the Module handle to be freed.
func nativeTVMModFree(modp *Module) (retVal int32) {
retVal = (int32) (C.TVMModFree(C.TVMModuleHandle(modp.nativeCPtr())))
return
}
// GetFunction returns the function pointer from the module for given function name.
//
// `tvmmodule` is handle for Module
//
// `funcname` function name in module.
//
// `args` variadic args of `queryImport`
//
// returns function closure with signature
// func (args ...interface{}) (interface{}, error) and error if any.
//
// The closure function can be used to call Function with arguments directly.
//
// Variadic arguments can be any type which can be embed into Value.
func (tvmmodule *Module) GetFunction (
funcname string, args ...interface{}) (
retVal *Function, err error){
queryImports := int32(1)
if len(args) > 0 {
queryImports = int32(args[1].(int))
}
var funp uintptr
cfuncname := C.CString(funcname)
ret := (int32)(C.TVMModGetFunction((_Ctype_TVMModuleHandle)(*tvmmodule),
cfuncname,
C.int(queryImports),
(*_Ctype_TVMFunctionHandle)(unsafe.Pointer(&funp))))
C.free(unsafe.Pointer(cfuncname))
if ret != 0 {
err = errors.New(getTVMLastError())
return
}
handle := new(Function)
*handle = Function(funp)
finalizer := func(fhandle *Function) {
nativeTVMFuncFree(fhandle)
fhandle = nil
}
runtime.SetFinalizer(handle, finalizer)
retVal = handle
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file module_test.go
*/
package gotvm
import (
"testing"
"reflect"
)
// Check module loading - dll
func TestModuleTestLoad1(t *testing.T) {
// dll
mod, err := LoadModuleFromFile("./deploy.so", "dll")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading - dylib
func TestModuleTestLoad2(t *testing.T) {
// dylib
mod, err := LoadModuleFromFile("./deploy.so", "dylib")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
func TestModuleTestLoad3(t *testing.T) {
// dso
mod, err := LoadModuleFromFile("./deploy.so", "dso")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading - so
func TestModuleTestLoad4(t *testing.T) {
// so
mod, err := LoadModuleFromFile("./deploy.so", "so")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading - default (so)
func TestModuleTestLoad5(t *testing.T) {
// default type as so
mod, err := LoadModuleFromFile("./deploy.so")
if err != nil {
t.Error(err.Error())
return
}
if reflect.TypeOf(mod).Kind() != reflect.Ptr {
t.Error("Module type mis matched\n")
return
}
}
// Check module loading err
func TestModuleTestLoadErr(t *testing.T) {
// Unknown file should return error
_, err := LoadModuleFromFile("xyzabc.so")
if err == nil {
t.Error("Expected an error, but not received\n")
return
}
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMArray aka DLTensor
* \file ndarray.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
"fmt"
"errors"
"runtime"
"reflect"
)
// Array type in golang hold pointer for the TVMArray object from dlpack.
//
// Array initialization happen through Empty api
type Array uintptr
// nativeCPtr returns type freed uintptr for the Array.
func (parray Array) nativeCPtr() (retVal uintptr) {
retVal = (uintptr)(parray)
return
}
func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) {
ret := C.TVMArrayCopyFromBytes((*_Ctype_TVMArray)(unsafe.Pointer(parray.nativeCPtr())),
data,
C.ulong(datalen))
if ret != 0 {
err = errors.New(getTVMLastError())
}
return
}
// CopyFrom copies given golang data slice into Array.
//
// `val` is interface homding a slice of Array data type.
//
// returns err is any.
// TOD: Use reflections for better handling
func (parray Array) CopyFrom(val interface{}) (err error) {
var data unsafe.Pointer
var datalen int
dtype := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).dtype
switch val.(type) {
case []int8:
sliceVal := val.([]int8)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []int16:
sliceVal := val.([]int16)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []int32:
sliceVal := val.([]int32)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []int64:
sliceVal := val.([]int64)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []uint8:
sliceVal := val.([]uint8)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []uint16:
sliceVal := val.([]uint16)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []uint32:
sliceVal := val.([]uint32)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []uint64:
sliceVal := val.([]uint64)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []float32:
sliceVal := val.([]float32)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
case []float64:
sliceVal := val.([]float64)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
return parray.nativeCopyFrom(data, datalen)
default:
err = fmt.Errorf("Given type not supported : %v\n", reflect.TypeOf(val))
return
}
return
}
func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){
ret := C.TVMArrayCopyToBytes((*_Ctype_TVMArray)(unsafe.Pointer(parray.nativeCPtr())),
unsafe.Pointer(data),
C.ulong(datalen))
if ret != 0 {
err = errors.New(getTVMLastError())
}
return
}
// AsSlice returns the unitptr of for the data inside Array.
//
// returns the slice of array inside Array and err of any.
// TOD: Use reflections for better handling
func (parray Array) AsSlice() (retVal interface{}, err error) {
shape := parray.GetShape()
size := int64(1)
var data unsafe.Pointer
var datalen int
for ii := range shape {
size *= shape[ii]
}
dtype := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).dtype
switch parray.GetDType() {
case "int8":
sliceVal := make([]int8, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "int16":
sliceVal := make([]int16, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "int32":
sliceVal := make([]int32, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "int64":
sliceVal := make([]int64, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "uint8":
sliceVal := make([]uint8, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "uint16":
sliceVal := make([]uint16, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "uint32":
sliceVal := make([]uint32, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "uint64":
sliceVal := make([]uint64, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "float32":
sliceVal := make([]float32, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
case "float64":
sliceVal := make([]float64, size)
data = unsafe.Pointer(&sliceVal[0])
datalen = len(sliceVal) * int(dtype.bits / 8)
err = parray.nativeCopyTo(data, datalen)
retVal = sliceVal
default:
err = fmt.Errorf("Given type not supported : %v\n", parray.GetDType())
return
}
return
}
// GetNdim returns the number of dimentions in Array
func (parray Array) GetNdim() (retVal int32) {
retVal = int32(((*_Ctype_TVMArray)(unsafe.Pointer(parray))).ndim)
return
}
// GetShape returns the number of dimentions in Array
func (parray Array) GetShape() (retVal []int64) {
shapePtr := (*C.int64_t)(((*_Ctype_TVMArray)(unsafe.Pointer(parray))).shape)
ndim := parray.GetNdim()
shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim]
retVal = make([]int64, ndim)
copy(retVal, shapeSlice)
return
}
// GetDType returns the number of dimentions in Array
func (parray Array) GetDType() (retVal string) {
ret := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).dtype
retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret)))
return
}
// GetCtx returns the number of dimentions in Array
func (parray Array) GetCtx() (retVal Context) {
ret := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).ctx
retVal = *(*Context)(unsafe.Pointer(&ret))
return
}
// nativeTVMArrayAlloc is used to allocate TVMArray from given attributes.
//
// `shape` is int64 slice holding shape of the Array to be created.
//
// `ndim` is the rank of the Array to be created.
//
// `dtypeCode`, `dtypeBits` and `dtypeLanes` describe the data type in Array.
//
// `deviceType` indicates the device on whose memory the Array to allocated.
//
// `deviceID` indicates device index if multiple devices of same type present.
//
// return argument holding native pointer to newly created Array and error is any.
func nativeTVMArrayAlloc(shape []int64, ndim int32,
dtypeCode int32, dtypeBits int32, dtypeLanes int32,
deviceType int32, deviceID int32) (retVal uintptr, err error) {
ret := (int32)(C.TVMArrayAlloc((*_Ctype_long)(&(shape[0])),
C.int(ndim),
C.int(dtypeCode),
C.int(dtypeBits),
C.int(dtypeLanes),
C.int(deviceType),
C.int(deviceID),
(*_Ctype_TVMArrayHandle)(unsafe.Pointer(&retVal))))
if ret != 0 {
err = errors.New(getTVMLastError())
return
}
return
}
// Empty is used to allocate TVM empty array of given epecification.
//
// `shape` is int64 slice holding shape of the Array
//
// `args` is variadic args for
//
// `args[0]` is string for data type. Default value is 'float32'
//
// `args[1]` is Context. Default value is '{KDLCPU, 0}'
//
// returns pointer to Array on successful execution and error if any.
func Empty(shape []int64, args ...interface{}) (parray *Array, err error) {
typeName := "float32"
ctx := Context{KDLCPU, 0}
if len(shape) < 1 {
err = fmt.Errorf("Invalid shape for Array creation: %v\n", len(shape))
return
}
for i, val := range args {
switch val.(type) {
case string:
typeName = args[i].(string)
case Context:
ctx = args[i].(Context)
default:
err = fmt.Errorf("Invalid Optional Argument Type: %T\n", val)
return
}
}
tvmType, err := dtypeToTVMType(typeName)
if err != nil {
return
}
ndim := int32(len(shape))
newArray, err := nativeTVMArrayAlloc(shape, ndim, int32(tvmType.code),
int32(tvmType.bits), int32(tvmType.lanes),
ctx.DeviceType, ctx.DeviceID)
if err != nil {
return
}
handle := new(Array)
*handle = Array(newArray)
finalizer := func (ahandle *Array) {
nativeTVMArrayFree(*ahandle)
ahandle = nil
}
runtime.SetFinalizer(handle, finalizer)
parray = handle
return
}
// nativeTVMArrayFree is used to release the Array.
//
// `parray` is the Array handle.
//
// `ret` indicates the status of this api execution.
func nativeTVMArrayFree(parray Array) (retVal int32) {
retVal = (int32)(C.TVMArrayFree((*_Ctype_TVMArray)(unsafe.Pointer(parray.nativeCPtr()))))
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief This is an all in one TVM runtime file.
* \file tvm_runtime_pack.cc
*/
#include "src/runtime/c_runtime_api.cc"
#include "src/runtime/cpu_device_api.cc"
#include "src/runtime/workspace_pool.cc"
#include "src/runtime/module_util.cc"
#include "src/runtime/module.cc"
#include "src/runtime/registry.cc"
#include "src/runtime/file_util.cc"
#include "src/runtime/threading_backend.cc"
#include "src/runtime/thread_pool.cc"
#include "src/runtime/ndarray.cc"
// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
// Likely we only need to enable one of the following
// If you use Module::Load, use dso_module
// For system packed library, use system_lib_module
#include "src/runtime/dso_module.cc"
#include "src/runtime/system_lib_module.cc"
// Graph runtime
#include "src/runtime/graph/graph_runtime.cc"
// Uncomment the following lines to enable RPC
// #include "../../src/runtime/rpc/rpc_session.cc"
// #include "../../src/runtime/rpc/rpc_event_impl.cc"
// #include "../../src/runtime/rpc/rpc_server_env.cc"
// These macros enables the device API when uncommented.
#define TVM_CUDA_RUNTIME 1
#define TVM_METAL_RUNTIME 1
#define TVM_OPENCL_RUNTIME 1
// Uncomment the following lines to enable Metal
// #include "../../src/runtime/metal/metal_device_api.mm"
// #include "../../src/runtime/metal/metal_module.mm"
// Uncomment the following lines to enable CUDA
// #include "../../src/runtime/cuda/cuda_device_api.cc"
// #include "../../src/runtime/cuda/cuda_module.cc"
// Uncomment the following lines to enable OpenCL
// #include "../../src/runtime/opencl/opencl_device_api.cc"
// #include "../../src/runtime/opencl/opencl_module.cc"
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package for TVMType interface
* \file type.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"fmt"
)
// pTVMType corresponding to data types.
type pTVMType struct {
code uint8
bits uint8
lanes uint16
}
// data type to pTVMType mapping
var dtypeMap = map[string] pTVMType {
"int8": pTVMType{0, 8, 1},
"int16": pTVMType{0, 16, 1},
"int32": pTVMType{0, 32, 1},
"int64": pTVMType{0, 64, 1},
"uint8": pTVMType{1, 8, 1},
"uint16": pTVMType{1, 16, 1},
"uint32": pTVMType{1, 32, 1},
"uint64": pTVMType{1, 64, 1},
"float32": pTVMType{2, 32, 1},
"float64": pTVMType{2, 64, 1},
}
// dtypeFromTVMType return the pTVMType corresponding to given dtype
//
// `dtype` string for the given data type.
func dtypeFromTVMType(tvmtype pTVMType) (retVal string, err error) {
for k, v := range dtypeMap {
if v.code == tvmtype.code && v.bits == tvmtype.bits && v.lanes == tvmtype.lanes {
retVal = k
return
}
}
err = fmt.Errorf("Cannot map TVMType:%v to dtype", tvmtype)
return
}
// dtypeToTVMType return the pTVMType corresponding to given dtype
//
// `dtype` string for the given data type.
func dtypeToTVMType(args ...interface{}) (tvmtype pTVMType, err error) {
dtype := args[0].(string)
lanes := 1
if len(args) == 2 {
lanes = args[1].(int)
}
for k, v := range dtypeMap {
if k == dtype {
tvmtype = v
tvmtype.lanes = uint16(lanes)
return
}
}
err = fmt.Errorf("Cannot map dtype:%v to TVMType", dtype)
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for common utilities
* \file util.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"unsafe"
)
// Native string map for go string
type nativeGoString struct { p uintptr; n int32 }
func goStringFromNative (s string) (retStr string) {
p := *(*nativeGoString)(unsafe.Pointer(&s))
retStr = string((*[0x7fffffff]byte)(unsafe.Pointer(p.p))[:p.n])
C.free(unsafe.Pointer(p.p))
return
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package source for TVMValue interface
* \file value.go
*/
package gotvm
//#include "gotvm.h"
import "C"
import (
"fmt"
"runtime"
"unsafe"
)
// KHandle is golang type code for TVM enum kHandle.
var KHandle = int32(C.kHandle)
// KNull is golang type code for TVM kNull.
var KNull = int32(C.kNull)
// KTVMType is golang type code for TVM kTVMType.
var KTVMType = int32(C.kTVMType)
// KTVMContext is golang type code for TVM kTVMContext.
var KTVMContext = int32(C.kTVMContext)
// KArrayHandle is golang type code for TVM kArrayHandle.
var KArrayHandle = int32(C.kArrayHandle)
// KNodeHandle is golang type code for TVM kNodeHandle.
var KNodeHandle = int32(C.kNodeHandle)
// KModuleHandle is gonag type code for TVM kModuleHandle.
var KModuleHandle = int32(C.kModuleHandle)
// KFuncHandle is gonalg type code for TVM kFuncHandle.
var KFuncHandle = int32(C.kFuncHandle)
// KStr is golang type code for TVM kStr.
var KStr = int32(C.kStr)
// KBytes is golang type code for TVM kBytes.
var KBytes = int32(C.kBytes)
// KNDArrayContainer is golang typecode for kNDArrayContainer.
var KNDArrayContainer = int32(C.kNDArrayContainer)
// KExtBegin is golang enum corresponding to TVM kExtBegin.
var KExtBegin = int32(C.kExtBegin)
// KNNVMFirst is golang enum corresponding to TVM kNNVMFirst.
var KNNVMFirst = int32(C.kNNVMFirst)
// KNNVMLast is golang enum corresponding to TVM kNNVMLast.
var KNNVMLast = int32(C.kNNVMLast)
// KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd.
var KExtReserveEnd = int32(C.kExtReserveEnd)
// KExtEnd is golang enum corresponding to TVM kExtEnd.
var KExtEnd = int32(C.kExtEnd)
// KDLInt is golang type code for TVM kDLInt.
var KDLInt = int32(C.kDLInt)
// KDLUInt is golang type code for TVM kDLUInt.
var KDLUInt = int32(C.kDLUInt)
// KDLFloat is golang type code for TVM kDLFloat.
var KDLFloat = int32(C.kDLFloat)
// Value Typemap for union exposed by TVM runtime API.
//
// gotvm maps it to a uintptr and then dynamically allocates memory by newTVMValue method.
type Value struct {
nptr uintptr
dtype int32
isLocal bool
}
// AsInt64 returns the int64 value inside the Value.
func (tvmval *Value) AsInt64() (retVal int64) {
retVal = tvmval.getVInt64()
return
}
// AsFloat64 returns the Float64 value inside the Value.
func (tvmval *Value) AsFloat64() (retVal float64) {
retVal = tvmval.getVFloat64()
return
}
// AsModule returns the Module inside the Value.
func (tvmval *Value) AsModule() (retVal *Module) {
mhandle := tvmval.getVMHandle()
retVal = &mhandle
return
}
// AsFunction returns the Function inside the Value.
func (tvmval *Value) AsFunction() (retVal *Function) {
fhandle := tvmval.getVFHandle()
retVal = &fhandle
return
}
// AsBytes returns the byte slice value inside the Value.
func (tvmval *Value) AsBytes() (retVal []byte) {
retVal = tvmval.getVBHandle().getData()
return
}
// AsStr returns the golang string in the Value.
func (tvmval *Value) AsStr() (retVal string) {
str := tvmval.getVStr()
retVal = str
return
}
// nativeCPtr return the unitptr corresponding to Value type.
func (tvmval *Value) nativeCPtr() (ret uintptr) {
ret = (uintptr)(tvmval.nptr)
return
}
// moveFrom copies the tvmval from other Value object.
func (tvmval *Value) moveFrom(fromval *Value) () {
C.memcpy(unsafe.Pointer(tvmval.nativeCPtr()),
unsafe.Pointer(fromval.nativeCPtr()),
C.sizeof_TVMValue)
// Move the dtype too.
tvmval.dtype = fromval.dtype
fromval.dtype = KNull
return
}
// setVInt64 initializes the Value object with given int64 value.
//
// `val` is the int64 value to initialize the Value
func (tvmval *Value) setVInt64(val int64) {
valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr()))
*valp = C.int64_t(val)
tvmval.dtype = KDLInt
return
}
// getVInt64 returns the int64 value inside the Value.
func (tvmval *Value) getVInt64() (retVal int64) {
valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr()))
retVal = int64(*valp)
return
}
// setVFloat64 initializes the Value object with given float64 value.
//
// `val` is the float64 value to initialize the Value.
func (tvmval *Value) setVFloat64(val float64) {
valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr()))
*valp = C.double(val)
tvmval.dtype = KDLFloat
return
}
// getVFloat64 returns the float64 value inside Value.
func (tvmval *Value) getVFloat64() (retVal float64) {
valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr()))
retVal = float64(*valp)
return
}
// setVHandle initializes the handle inside the Value.
//
// Can be used to store any uintptr type object like
// module handle, function handle and any object's nativeCPtr.
//
// `val` is the uintptr type of given handle.
func (tvmval *Value) setVHandle(val uintptr) {
valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr()))
*valp = (*C.void)(unsafe.Pointer(val))
}
// getVHandle returns the uintptr handle
func (tvmval *Value) getVHandle() (retVal uintptr) {
valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr()))
retVal = uintptr(unsafe.Pointer(*valp))
return
}
// setVStr intializes the Value with given golang string object.
//
// `val` is the golang string object used to initialize the Value.
func (tvmval *Value) setVStr(val string) {
valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr()))
*valp = C.CString(val)
tvmval.dtype = KStr
return
}
// getVStr returns the golang string for the native string inside Value.
func (tvmval *Value) getVStr() (retVal string) {
valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr()))
retVal = C.GoString(*valp)
return
}
// unSetVStr release the memory allocated in setVStr
func (tvmval *Value) unSetVStr() {
valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr()))
C.free(unsafe.Pointer(*valp))
tvmval.dtype = KNull
}
// setVAHandle is used to set Array handle in Value.
//
// Application can call the setVHandle with nativeCPtr instead too.
// This is a wrapper to accept Array directly.
func (tvmval *Value) setVAHandle(ptvmarray Array) {
tvmval.setVHandle(ptvmarray.nativeCPtr())
tvmval.dtype = KArrayHandle
return
}
// getVAHandle is used to get Array handle in Value.
func (tvmval *Value) getVAHandle() (retVal Array) {
retVal = (Array)(tvmval.getVHandle())
return
}
// setVMHandle is used to set Module handle in Value.
//
// Application can call the setVHandle with nativeCPtr instead too.
// This is a wrapper to accept Module directly.
func (tvmval *Value) setVMHandle(tvmmodule Module) {
tvmval.setVHandle(tvmmodule.nativeCPtr())
tvmval.dtype = KModuleHandle
return
}
// getVMHandle is used to get Module handle in Value.
func (tvmval *Value) getVMHandle() (retVal Module) {
retVal = (Module)(tvmval.getVHandle())
return
}
// setVFHandle is used to set Function handle in Value.
//
// Application can call the setVHandle with nativeCPtr instead.
// This is a wrapper to accept Function directly.
func (tvmval *Value) setVFHandle(tvmfunction Function) {
tvmval.setVHandle(tvmfunction.nativeCPtr())
tvmval.dtype = KFuncHandle
return
}
// getVFHandle is used to get Function handle in Value.
func (tvmval *Value) getVFHandle() (retVal Function) {
retVal = (Function)(tvmval.getVHandle())
return
}
// setVBHandle is used to set ByteArray handle in Value.
//
// Application can call the setVHandle with nativeCPtr instead.
// This is a wrapper to accept ByteArray directly.
func (tvmval *Value) setVBHandle(tbytearray ByteArray) {
tvmval.setVHandle(tbytearray.nativeCPtr())
tvmval.dtype = KBytes
return
}
// getVBHandle is used to get ByteArray handle in Value.
func (tvmval *Value) getVBHandle() (retVal ByteArray) {
retVal = (ByteArray)(tvmval.getVHandle())
return
}
// setValue is used to set the given value in Value.
//
// `val` is value of types accepted by Value container or native union.
func (tvmval *Value) setValue(val interface{}) (retVal int32, err error) {
retVal = KNull
switch val.(type) {
case string:
tvmval.setVStr(val.(string))
case uint8:
tvmval.setVInt64(int64(val.(uint8)))
case uint16:
tvmval.setVInt64(int64(val.(uint16)))
case uint32:
tvmval.setVInt64(int64(val.(uint32)))
case uint64:
tvmval.setVInt64(int64(val.(uint64)))
case int:
tvmval.setVInt64(int64(val.(int)))
case int8:
tvmval.setVInt64(int64(val.(int8)))
case int16:
tvmval.setVInt64(int64(val.(int16)))
case int32:
tvmval.setVInt64(int64(val.(int32)))
case int64:
tvmval.setVInt64(val.(int64))
case float32:
tvmval.setVFloat64(float64(val.(float32)))
case float64:
tvmval.setVFloat64(val.(float64))
case *Module:
tvmval.setVMHandle(*(val.(*Module)))
case *Function:
tvmval.setVFHandle(*(val.(*Function)))
case *ByteArray:
tvmval.setVBHandle(*(val.(*ByteArray)))
case []byte:
barray := newByteArray(val.([]byte))
tvmval.setVBHandle(barray)
case *Array:
tvmval.setVAHandle(*(val.(*Array)))
case func (args ...*Value) (interface{}, error):
fhandle, apierr := ConvertFunction(val)
if apierr != nil {
err = fmt.Errorf("Given value Type not defined for Value: %v : %T\n", val, val);
return
}
tvmval.setVFHandle(*fhandle)
// Clear the finalizer as we don't need to control it anymore.
runtime.SetFinalizer(fhandle, nil)
case *Value:
tvmval.moveFrom(val.(*Value))
case Value:
fromval := val.(Value)
tvmval.moveFrom(&fromval)
default:
err = fmt.Errorf("Given value Type not defined for Value: %v : %T\n", val, val);
}
retVal = tvmval.dtype
return
}
// newTVMValue initialize the TVMValue native object.
//
// This is intended to use as intermediate type between native and golang types.
// Allocated from FuncCall or Callback to handle conversions.
func newTVMValue() (retVal *Value) {
handle := new(Value)
handle.nptr = (uintptr(C.malloc(C.sizeof_TVMValue)))
handle.dtype = KNull
handle.isLocal = true
finalizer := func(vhandle *Value) {
vhandle.deleteTVMValue()
vhandle = nil
}
runtime.SetFinalizer(handle, finalizer)
retVal = handle
return
}
// deleteTVMValue free the native Value object which is allocated in newTVMValue.
func (tvmval Value) deleteTVMValue() {
if tvmval.isLocal == true {
if tvmval.dtype == KStr {
tvmval.unSetVStr()
}
if tvmval.dtype == KBytes {
tvmval.getVBHandle().deleteTVMByteArray()
}
}
C.free(unsafe.Pointer(tvmval.nativeCPtr()))
}
/*!
* Copyright (c) 2018 by Contributors
* \brief gotvm package
* \file value_test.go
*/
package gotvm
import (
"testing"
"math/rand"
"strings"
)
// Check Int64 Value looping via packed function calling another packed function.
func TestValueLoopInt64(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Int63()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsInt64() != result {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
// Check Int32 Value looping via packed function calling another packed function.
func TestValueLoopInt32(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Int31()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsInt64() != int64(result) {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
// Check Float32 Value looping via packed function calling another packed function.
func TestValueLoopFloat32(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Float32()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsFloat64() != float64(result) {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
// Check Float64 Value looping via packed function calling another packed function.
func TestValueLoopFloat64(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
newArgs := args[1:]
// Call Packed Function by Value
return pfunc.Invoke(newArgs)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0]
return
}
result := rand.Float64()
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
if retVal.AsFloat64() != result {
t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64())
return
}
}
func TestValueLoopString(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
argStr := args[1].AsStr()
// Call Packed Function by Value
return pfunc.Invoke(argStr)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0].AsStr()
return
}
retVal, err := fhandle.Invoke(funccall, "TestString")
if err != nil {
t.Error(err.Error())
return
}
vStr := retVal.AsStr()
if strings.Compare(vStr, string("TestString")) != 0 {
t.Errorf("Expected : %v got:%v\n", string("TestString"), vStr)
return
}
}
// Check []byte Value looping via packed function calling another packed function.
func TestValueLoopByteSlice(t *testing.T) {
// Receive a function Handle and argument and echo the Value on the handle.
sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) {
// Reveive Packed Function Handle
pfunc := args[0].AsFunction()
argBytes := args[1].AsBytes()
// Call Packed Function by Value
return pfunc.Invoke(argBytes)
}
fhandle, err := ConvertFunction(sampleFunctionLoop)
if err != nil {
t.Error(err.Error())
return
}
// funccall is a simple golang callback function like C = A + B.
funccall := func (args ...*Value) (retVal interface{}, err error) {
retVal = args[0].AsBytes()
return
}
result := make([]byte, 1024)
rand.Read(result)
retVal, err := fhandle.Invoke(funccall, result)
if err != nil {
t.Error(err.Error())
return
}
received := retVal.AsBytes()
if len(result) != len(received) {
t.Errorf("Data expected Len: %v Got :%v\n", len(result), len(received))
return
}
for i := range result {
if result[i] != received[i] {
t.Errorf("Data expected: %v Got :%v at index %v\n", result[i], received[i], i)
return
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment