/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \brief Sample golang application deployment over tvm.
 * \file complex.go
 */

package main

import (
    "fmt"
    "io/ioutil"
    "math/rand"
    "./gotvm"
    "runtime"
)

// Paths of the NNVM compiled MobileNet artifacts consumed by main.
const (
    // modLib is the compiled shared library imported as the TVM module.
    modLib = "./mobilenet.so"
    // modJSON is the graph description passed to tvm.graph_runtime.create.
    modJSON = "./mobilenet.json"
    // modParams is the parameter blob handed to the load_params function.
    modParams = "./mobilenet.params"
)

// main runs a compiled model through the TVM graph runtime end to end:
// it imports the module library, creates the graph runtime from the graph
// JSON, loads the parameters, feeds a randomly filled input array, runs the
// graph and prints the first values of the output array.
//
// Every failing step prints its error and ends the program early.
func main() {
    // NOTE(review): presumably forces a final collection so gotvm handles are
    // released before exit — confirm against the gotvm package.
    defer runtime.GC()

    // Welcome
    fmt.Printf("TVM Version   : v%v\n", gotvm.TVMVersion)
    fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion)

    // Query global functions available
    funcNames, err := gotvm.FuncListGlobalNames()
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Global Functions:%v\n", funcNames)

    // Import tvm module (so)
    modp, err := gotvm.LoadModuleFromFile(modLib)
    if err != nil {
        fmt.Print(err)
        fmt.Printf("Please copy tvm compiled modules here and update the sample.go accordingly.\n")
        fmt.Printf("You may need to update modLib, modJSON, modParams, tshapeIn, tshapeOut\n")
        return
    }
    fmt.Printf("Module Imported:%p\n", modp)

    // Read the graph description
    bytes, err := ioutil.ReadFile(modJSON)
    if err != nil {
        fmt.Print(err)
        return
    }
    jsonStr := string(bytes)

    // Load module on tvm runtime - call tvm.graph_runtime.create
    funp, err := gotvm.GetGlobalFunction("tvm.graph_runtime.create")
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Calling tvm.graph_runtime.create\n")

    // Call function
    graphrt, err := funp.Invoke(jsonStr, modp, (int64)(gotvm.KDLCPU), (int64)(0))
    if err != nil {
        fmt.Print(err)
        return
    }
    graphmod := graphrt.AsModule()
    fmt.Printf("Graph runtime Created\n")

    // Array allocation attributes
    tshapeIn := []int64{1, 224, 224, 3}
    tshapeOut := []int64{1, 1001}

    // Allocate input Array
    inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0))
    if err != nil {
        fmt.Print(err)
        return
    }

    // Allocate output Array
    out, err := gotvm.Empty(tshapeOut)
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Input and Output Arrays allocated\n")

    // Read params
    bytes, err = ioutil.ReadFile(modParams)
    if err != nil {
        fmt.Print(err)
        // Do not go on to load_params with missing or partial parameter data.
        return
    }

    // Get module function from graph runtime : load_params
    funp, err = graphmod.GetFunction("load_params")
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Func load_params:%p\n", funp)

    // Call function
    _, err = funp.Invoke(bytes)
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Module params loaded\n")

    // Set some data in input Array.
    // The element count is derived from tshapeIn so the host slice always
    // matches the array it is copied into.
    inLen := 1
    for _, dim := range tshapeIn {
        inLen *= int(dim)
    }
    inSlice := make([]float32, inLen)
    rand.Seed(10)
    // Shuffle is used only as a driver here: the callback does not swap, it
    // overwrites both visited indices with fresh random values.
    rand.Shuffle(len(inSlice), func(i, j int) {
        inSlice[i], inSlice[j] = rand.Float32(), rand.Float32()
    })
    if err = inX.CopyFrom(inSlice); err != nil {
        fmt.Print(err)
        return
    }

    // Set Input
    funp, err = graphmod.GetFunction("set_input")
    if err != nil {
        fmt.Print(err)
        return
    }

    // Call function
    _, err = funp.Invoke("input", inX)
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Module input is set\n")

    // Run
    funp, err = graphmod.GetFunction("run")
    if err != nil {
        fmt.Print(err)
        return
    }

    // Call function
    _, err = funp.Invoke()
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Module Executed \n")

    // Call runtime function get_output
    funp, err = graphmod.GetFunction("get_output")
    if err != nil {
        fmt.Print(err)
        return
    }

    // Call function
    _, err = funp.Invoke(int64(0), out)
    if err != nil {
        fmt.Print(err)
        return
    }
    fmt.Printf("Got Module Output \n")

    // Print results
    outIntf, err := out.AsSlice()
    if err != nil {
        fmt.Print(err)
        return
    }
    outSlice, ok := outIntf.([]float32)
    if !ok {
        // Checked assertion: report an unexpected element type instead of panicking.
        fmt.Printf("Unexpected output type %T, expected []float32\n", outIntf)
        return
    }
    // Show at most the first 10 values without slicing past the end.
    shown := len(outSlice)
    if shown > 10 {
        shown = 10
    }
    fmt.Printf("Result:%v\n", outSlice[:shown])
}