/*! * Copyright (c) 2018 by Contributors * \brief gotvm package * \file function_test.go */ package gotvm import ( "testing" "reflect" "math/rand" "strings" "fmt" ) // Check global function list API func TestFunctionGlobals(t *testing.T) { funcNames, err := FuncListGlobalNames() if err != nil { t.Error(err.Error()) return } if len(funcNames) < 1 { t.Errorf("Global Function names received:%v\n", funcNames) } } // Check GetFunction API func TestFunctionGlobalGet(t *testing.T) { funp, err := GetGlobalFunction("tvm.graph_runtime.create") if err != nil { t.Error(err.Error()) return } if reflect.TypeOf(funp).Kind() != reflect.Ptr { t.Error("Function type mis matched\n") return } } func TestFunctionModuleGet(t *testing.T) { modp, err := LoadModuleFromFile("./deploy.so") if err != nil { t.Error(err.Error()) return } funp, err := modp.GetFunction("myadd") if err != nil { t.Error(err.Error()) return } if reflect.TypeOf(funp).Kind() != reflect.Ptr { t.Error("Function type mis matched\n") return } dlen := int64(1024) shape := []int64{dlen} inX, _ := Empty(shape) inY, _ := Empty(shape) out, _ := Empty(shape) dataX := make([]float32, (dlen)) dataY := make([]float32, (dlen)) outExpected := make([]float32, (dlen)) for i := range dataX { dataX[i] = rand.Float32() dataY[i] = rand.Float32() outExpected[i] = dataX[i] + dataY[i] } inX.CopyFrom(dataX) inY.CopyFrom(dataY) funp.Invoke(inX, inY, out) outi, _ := out.AsSlice() outSlice := outi.([]float32) if len(outSlice) != len(outExpected) { t.Errorf("Data expected Len: %v Got :%v\n", len(outExpected), len(outSlice)) return } for i := range outSlice { if outExpected[i] != outSlice[i] { t.Errorf("Data expected: %v Got :%v at index %v\n", outExpected[i], outSlice[i], i) return } } } // Check FunctionConvert API func TestFunctionConvert(t *testing.T) { sampleCb := func (args ...*Value) (retVal interface{}, err error) { val1 := args[0].AsInt64() val2 := args[1].AsInt64() retVal = int64(val1+val2) return } fhandle, err := ConvertFunction(sampleCb) if err != nil { t.Error(err.Error()) return } retVal, err := fhandle.Invoke(10, 20) if err != nil { t.Error(err.Error()) return } if retVal.AsInt64() != int64(30) { t.Errorf("Expected result :30 got:%v\n", retVal.AsInt64()) return } } func TestFunctionError(t *testing.T) { sampleCb := func (args ...*Value) (retVal interface{}, err error) { err = fmt.Errorf("Sample Error XYZABC"); return } fhandle, err := ConvertFunction(sampleCb) if err != nil { t.Error(err.Error()) return } _, err = fhandle.Invoke() if err == nil { t.Error("Expected error but didn't received\n") return } if !strings.Contains(err.Error(), string("Sample Error XYZABC")) { t.Errorf("Expected Error should contain :\"Sample Error XYZABC\" got :%v\n", err.Error()) } } // Check FunctionRegister func TestFunctionRegister(t *testing.T) { sampleCb := func (args ...*Value) (retVal interface{}, err error) { val1 := args[0].AsInt64() val2 := args[1].AsInt64() retVal = int64(val1+val2) return } RegisterFunction(sampleCb, "TestFunctionRegister.sampleCb"); // Query global functions available funcNames, err := FuncListGlobalNames() if err != nil { t.Error(err.Error()) return } found := 0 for ii := range (funcNames) { if strings.Compare(funcNames[ii], "TestFunctionRegister.sampleCb") == 0 { found = 1 } } if found == 0 { t.Error("Registered function not found in global function list.") return } // Get "sampleCb" and verify the call. funp, err := GetGlobalFunction("TestFunctionRegister.sampleCb") if err != nil { t.Error(err.Error()) return } // Call function result, err := funp.Invoke((int64)(10), (int64)(20)) if err != nil { t.Error(err.Error()) return } if result.AsInt64() != int64(30) { t.Errorf("Expected result :30 got:%v\n", result.AsInt64()) return } } // Check packed function receiving go-closure as argument. func TestFunctionClosureArg(t *testing.T) { // sampleFunctionArg receives a Packed Function handle and calls it. sampleFunctionArg := func (args ...*Value) (retVal interface{}, err error) { // Reveive Packed Function Handle pfunc := args[0].AsFunction() // Call Packed Function by Value ret, err := pfunc.Invoke(args[1], args[2]) if err != nil { return } // Call Packed Function with extracted values ret1, err := pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64()) if err != nil { return } if ret1.AsInt64() != ret.AsInt64() { err = fmt.Errorf("Invoke with int64 didn't match with Value") return } retVal = ret return } RegisterFunction(sampleFunctionArg, "TestFunctionClosureArg.sampleFunctionArg"); funp, err := GetGlobalFunction("TestFunctionClosureArg.sampleFunctionArg") if err != nil { t.Error(err.Error()) return } // funccall is a simple golang callback function like C = A + B. funccall := func (args ...*Value) (retVal interface{}, err error) { val1 := args[0].AsInt64() val2 := args[1].AsInt64() retVal = int64(val1+val2) return } // Call function result, err := funp.Invoke(funccall, 30, 50) if err != nil { t.Error(err.Error()) return } if result.AsInt64() != int64(80) { t.Errorf("Expected result :80 got:%v\n", result.AsInt64()) return } } // Check packed function returning a go-closure. func TestFunctionClosureReturn(t *testing.T) { // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { funccall := func (cargs ...*Value) (fret interface{}, ferr error) { val1 := cargs[0].AsInt64() val2 := cargs[1].AsInt64() fret = int64(val1+val2) return } retVal = funccall return } RegisterFunction(sampleFunctionCb, "TestFunctionClosureReturn.sampleFunctionCb"); funp, err := GetGlobalFunction("TestFunctionClosureReturn.sampleFunctionCb") if err != nil { t.Error(err.Error()) return } // Call function result, err := funp.Invoke() if err != nil { t.Error(err.Error()) return } pfunc := result.AsFunction() pfuncRet, err := pfunc.Invoke(30, 40) if err != nil { t.Error(err.Error()) return } if pfuncRet.AsInt64() != int64(70) { t.Errorf("Expected result :70 got:%v\n", pfuncRet.AsInt64()) return } } // Check packed function with no arguments and no return values. func TestFunctionNoArgsReturns(t *testing.T) { sampleFunction := func (args ...*Value) (retVal interface{}, err error) { return } fhandle, err := ConvertFunction(sampleFunction) if err != nil { t.Error(err.Error()) return } _, err = fhandle.Invoke() if err != nil { t.Error(err.Error()) return } } // Check packed function returning a go-closure with no arg and returns. func TestFunctionNoArgsReturns2(t *testing.T) { // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { funccall := func (cargs ...*Value) (fret interface{}, ferr error) { return } retVal = funccall return } funp, err := ConvertFunction(sampleFunctionCb) if err != nil { t.Error(err.Error()) return } // Call function result, err := funp.Invoke() if err != nil { t.Error(err.Error()) return } pfunc := result.AsFunction() _, err = pfunc.Invoke() if err != nil { t.Error(err.Error()) return } }