/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief gotvm package source for the TVMFunction interface.
 * \file function.go
 */

package gotvm

//#include "gotvm.h"
import "C"

import (
    "unsafe"
    "encoding/binary"
    "errors"
    "runtime"
    "reflect"
    "fmt"
)

// Function is the Go-side representation of a TVM packed-function handle;
// the underlying uintptr is the raw native handle value.
type Function uintptr

// nativeCPtr exposes the raw native handle as a plain uintptr so it can be
// converted to the C handle type at the cgo boundary.
func (tvmfunction Function) nativeCPtr() uintptr {
	return uintptr(tvmfunction)
}

// Invoke calls the TVM packed function referred to by the handle with the
// given arguments.
//
// Each argument is converted into a Value via setValue. As a special case, a
// single []*Value argument is treated as an argument list forwarded from one
// packed function to another. Every element is moved (moveFrom) into a fresh
// Value, and those Values become the call arguments. A []*Value mixed with any
// other argument is rejected with an error.
//
// Returns the packed function's return value and an error if any.
func (tvmfunction *Function) Invoke(args ...interface{}) (retVal *Value, err error) {
	forwarded := false
	for _, arg := range args {
		if _, isValues := arg.([]*Value); isValues {
			forwarded = true
			break
		}
	}

	switch {
	case !forwarded:
		return callNativeFunction(tvmfunction, args)
	case len(args) != 1:
		return nil, fmt.Errorf("Not supported if packed function args are a mix of []Value and other types")
	}

	source := args[0].([]*Value)
	moved := make([]interface{}, len(source))
	for i, src := range source {
		dst := newTVMValue()
		dst.moveFrom(src)
		moved[i] = dst
	}
	return callNativeFunction(tvmfunction, moved)
}

// FuncListGlobalNames queries TVM for the names of all registered global
// packed functions.
//
// Returns the function names, or an error if the native query fails or its
// result cannot be decoded.
func FuncListGlobalNames() (retVal []string, err error) {
	var str string
	ret := int32(C._TVMFuncListGlobalNames(unsafe.Pointer(&str)))
	if ret != 0 {
		return nil, errors.New(getTVMLastError())
	}
	return parseGlobalFuncNames(goStringFromNative(str))
}

// parseGlobalFuncNames decodes the buffer filled by _TVMFuncListGlobalNames.
// The layout is a little-endian uint64 name count, then, for each name, a
// little-endian uint64 byte length followed by that many bytes. A truncated
// or inconsistent buffer yields an error instead of a slice-bounds panic.
func parseGlobalFuncNames(buf string) ([]string, error) {
	const prefixLen = 8
	errMalformed := errors.New("malformed global function name list")

	// readLen consumes one 8-byte little-endian length prefix from buf.
	readLen := func() (uint64, bool) {
		if len(buf) < prefixLen {
			return 0, false
		}
		v := binary.LittleEndian.Uint64([]byte(buf[:prefixLen]))
		buf = buf[prefixLen:]
		return v, true
	}

	count, ok := readLen()
	// Every name carries at least an 8-byte length prefix. That bounds count
	// and keeps a corrupt header from triggering a huge allocation.
	if !ok || count > uint64(len(buf)/prefixLen) {
		return nil, errMalformed
	}
	names := make([]string, count)
	for i := range names {
		n, ok := readLen()
		if !ok || n > uint64(len(buf)) {
			return nil, errMalformed
		}
		names[i] = buf[:n]
		buf = buf[n:]
	}
	return names, nil
}

// GetGlobalFunction is to get handle to the given global function name.
//
// `funcname` is the name of global packed function.
//
// returns a function closure with signature
//         func (args ...interface{}) (interface{}, error) and  error if any.
//
// The closure function can be used to call Function with arguments directly.
//
// Variadic arguments can be any type which can be embed into Value.
func GetGlobalFunction(funcname string) (retVal *Function, err error) {
    var funp uintptr

    cfuncname := C.CString(funcname)
    ret := (int32)(C.TVMFuncGetGlobal(cfuncname,
125
                                      (*C.TVMFunctionHandle)(unsafe.Pointer(&funp))))
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
    C.free(unsafe.Pointer(cfuncname))

    if ret != 0 {
        err = errors.New(getTVMLastError())
        return
    }

    handle := new(Function)
    *handle = Function(funp)
    finalizer := func(fhandle *Function) {
        nativeTVMFuncFree(fhandle)
        fhandle = nil
    }
    runtime.SetFinalizer(handle, finalizer)
    retVal = handle
    return
}

// callNativeFunction invokes the packed function behind handle through the
// gotvm native wrapper.
//
// `handle` is the Function to call.
//
// `args` are the call arguments. Each one is converted into a freshly
// allocated Value, and its type code is recorded for the native call.
//
// Returns a Value holding TVM's return value, with dtype set to the returned
// type code. Returns nil and the error if argument conversion or the call
// fails.
func callNativeFunction(handle *Function, args []interface{}) (retVal *Value, err error) {
	argCount := len(args)

	// Always allocate at least one type-code slot so typeCodes[0] is
	// addressable even for a zero-argument call.
	codeSlots := argCount
	if codeSlots == 0 {
		codeSlots = 1
	}
	typeCodes := make([]int32, codeSlots)
	argsIn := make([]*Value, argCount)

	for i, arg := range args {
		in := newTVMValue()
		argsIn[i] = in
		code, convErr := in.setValue(arg)
		typeCodes[i] = code
		if convErr != nil {
			return nil, convErr
		}
	}

	out := newTVMValue()
	retTypeCode := KNull
	if callErr := nativeTVMFuncCall(handle, argsIn, typeCodes, []*Value{out}, &retTypeCode); callErr != nil {
		return nil, callErr
	}

	// NOTE(review): clearing isLocal presumably stops the Value's finalizer
	// from freeing a payload it does not own. Confirm against the Value code.
	out.isLocal = false
	out.dtype = retTypeCode
	return out, nil
}

// nativeTVMFuncFree releases the native function handle wrapped by funp by
// calling TVMFuncFree.
//
// `funp` is the Function whose native handle is released.
//
// Returns the status code reported by TVMFuncFree.
func nativeTVMFuncFree(funp *Function) (retVal int32) {
	nativeHandle := C.TVMFunctionHandle(funp.nativeCPtr())
	status := C.TVMFuncFree(nativeHandle)
	return int32(status)
}

// nativeToGoSlice fills Go Values from a native TVMValue array.
//
// `nargValues` points to the native array. For each index i,
// _TVMValueNativeGet moves element i into argValues[i]'s native storage, and
// argValues[i].dtype is set to typeCodes[i].
//
// typeCodes must be at least as long as argValues.
func nativeToGoSlice(nargValues *C.void, argValues []*Value, typeCodes []int32) {
	nativeArray := unsafe.Pointer(nargValues)
	for i, dst := range argValues {
		C._TVMValueNativeGet(unsafe.Pointer(dst.nativeCPtr()), nativeArray, C.int(i))
		dst.dtype = typeCodes[i]
	}
}

// nativeFromGoSlice copies Go Values into a newly allocated native TVMValue
// array.
//
// Element i of the array is filled from argValues[i] via _TVMValueNativeSet.
// The array is allocated with C.malloc. The caller owns it and must release
// it with C.free.
func nativeFromGoSlice(argValues []*Value) (nptr *C.void) {
	// malloc takes a size_t. C.ulong only matches it where unsigned long is
	// pointer-sized (it is 32 bits on 64-bit Windows), so use C.size_t.
	// cgo's C.malloc never returns nil, even for a zero-sized request.
	buf := C.malloc(C.size_t(C.sizeof_TVMValue * len(argValues)))
	for i, src := range argValues {
		C._TVMValueNativeSet(buf, unsafe.Pointer(src.nativeCPtr()), C.int(i))
	}
	// Keep the C pointer as unsafe.Pointer; no uintptr round trip is needed.
	return (*C.void)(buf)
}

// nativeTVMFuncCall executes the function with given arguments
//
// `funp` Function handle to the packed function.
//
// `argValues` is the slice of Value which are arguments to the packed function.
//
// `typeCodes` is the alice of argument type codes corresponding to argValues.
//
// `retValues` is return argument which is slice of return values from the packed function.
//
// `retTypeCode` is int32 holding type codes for retValue
//
// Returns err indicating native error if any.
func nativeTVMFuncCall(funp *Function, argValues []*Value, typeCodes []int32,
                 retValues []*Value, retTypeCode *int32) (err error) {
    nargValues := nativeFromGoSlice(argValues)
    nretValues := nativeFromGoSlice(retValues)
231 232 233
	result := (int32)(C.TVMFuncCall(C.TVMFunctionHandle(*funp),
                                    (*C.TVMValue)(unsafe.Pointer(nargValues)),
                                    (*C.int)(unsafe.Pointer(&(typeCodes[0]))),
234
                                    C.int(len(argValues)),
235 236
                                    (*C.TVMValue)(unsafe.Pointer(nretValues)),
                                    (*C.int)(unsafe.Pointer(retTypeCode))))
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
    nativeToGoSlice(nargValues, argValues, typeCodes)
    nativeToGoSlice(nretValues, retValues, (*[1<<31] int32)(unsafe.Pointer(retTypeCode))[:1:1])
    C.free(unsafe.Pointer(nargValues))
    C.free(unsafe.Pointer(nretValues))

    if result != 0 {
	    err = errors.New(getTVMLastError())
    }
    return
}

// goCallBack boxes a Go callback so it can be handed to native code as an
// opaque resource handle; cgo cannot convert a Go func value for native use
// directly.
type goCallBack struct {
	// cb receives the packed-function arguments. Its first result becomes the
	// packed function's return value, and a non-nil error is reported to TVM.
	cb func(...*Value) (interface{}, error)
}

//export goTVMCallback
func goTVMCallback(args C.native_voidp, typeCodes C.native_voidp, numArgs int32,
                   retArg C.native_voidp, resourceHandle C.native_voidp) (ret int32){
    fcb := (*goCallBack)(resourceHandle)
    // Make Value Sice from native TVMValue pointer.
    argValues := make([]*Value, numArgs)

    for ii := range argValues {
        argValues[ii] = newTVMValue()
        argValues[ii].isLocal = false
    }

    // Prepare arguments for golang callback function
    nativeToGoSlice((*C.void)(unsafe.Pointer(args)), argValues,
                    (*[1<<31] int32)(unsafe.Pointer(typeCodes))[:numArgs:numArgs])
    cbargs := argValues

    // Execute the callback
    retVal, err := fcb.cb(cbargs...)
    if err != nil {
        errStr := err.Error()
        setTVMLastError(errStr)
        return -1
    }

280
    // It's possible a packed function directly return
281 282 283 284 285
    // the return value of another packed function.
    //
    // Inside a packed func :
    //      ```return pfunc.Invoke(args)```
    //
286
    // In this case pfunc returns nil which is
287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313
    // returned as an interface holding nil *Value.
    // Which becomes a valid retVal holding nil *Value.
    isRetNull := false
    switch retVal.(type) {
        case *Value:
            pRet := retVal.(*Value)
            if pRet == nil {
                isRetNull = true
            }
    }

    // Handle return value from callback function
    if retVal != nil && !isRetNull {
        var retTypeCode int32
        retValues := []*Value{newTVMValue()}

        retTypeCode, err = retValues[0].setValue(retVal)
        if err != nil {
            errStr := err.Error()
            setTVMLastError(errStr)
            return -1
        }
        nretValues := nativeFromGoSlice(retValues)

        // Handle KStr, KBytes: Local finalizers shouldn't try freeing them.
        retValues[0].isLocal = false

314 315 316
        apiRet := (int32) (C.TVMCFuncSetReturn(C.TVMRetValueHandle(retArg),
                                               (*C.TVMValue)(unsafe.Pointer(nretValues)),
                                               (*C.int)(unsafe.Pointer(&retTypeCode)), 1))
317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
        C.free(unsafe.Pointer(nretValues))
        if apiRet != 0 {
            errStr := string("TVMCFuncSetReturn failed ")
            setTVMLastError(errStr)
        }
    }
    return
}

// ConvertFunction converts the given Go function into a TVM packed function.
//
// `args[0]` must be a function of type
// ```func (args ...*Value) (interface{}, error)```.
// Any further arguments are ignored.
//
// On success the returned handle is freed by a finalizer once it becomes
// unreachable.
//
// Returns an error, and no handle, if args[0] is missing or has the wrong
// type, or if the native conversion fails.
func ConvertFunction(args ...interface{}) (retVal *Function, err error) {
	if len(args) == 0 {
		return nil, errors.New("missing function argument")
	}
	function, ok := args[0].(func(args ...*Value) (interface{}, error))
	if !ok {
		return nil, fmt.Errorf("unsupported function type %T; want func(...*Value) (interface{}, error)", args[0])
	}

	// NOTE(review): fcb is handed to native code as the callback's resource
	// handle, but nothing on the Go side keeps it reachable afterwards. The cgo
	// rules forbid C retaining Go pointers after a call returns, and the GC
	// may reclaim fcb. Consider a handle registry keyed by an integer.
	fcb := &goCallBack{cb: function}
	var funp uintptr

	result := int32(C._ConvertFunction(unsafe.Pointer(fcb),
		unsafe.Pointer(&funp)))
	if result != 0 {
		// Do not wrap a handle the native side never produced; a finalizer on
		// it would free an invalid pointer.
		return nil, errors.New(getTVMLastError())
	}

	handle := new(Function)
	*handle = Function(funp)
	runtime.SetFinalizer(handle, func(fhandle *Function) {
		nativeTVMFuncFree(fhandle)
	})
	return handle, nil
}

// RegisterFunction registers the golang func in TVM runtime global space.
//
// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})```
//
// `args[1]` Optional argument of function name with which it will be registered.
//           If not passed we use function name from reflection.
//
// Returns err indicating native error if any.
func RegisterFunction(args ...interface{}) (err error) {
    fhandle, err := ConvertFunction(args...)
    if err != nil {
        return
    }

    funcname := runtime.FuncForPC(reflect.ValueOf(args[0]).Pointer()).Name()
    if len(args) > 1 {
        funcname = args[1].(string)
    }

    cfuncname := C.CString(funcname)
    result := (int32) (C.TVMFuncRegisterGlobal(cfuncname,
374
                                               C.TVMFunctionHandle(*fhandle),
375 376 377 378 379 380 381 382 383
                                               0)); // Override = False
    C.free(unsafe.Pointer(cfuncname))
    if result != 0 {
	    err = errors.New(getTVMLastError())
    }
    // Clear the finalizer as we don't need to control it anymore.
    runtime.SetFinalizer(fhandle, nil)
    return
}