/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
v0.0.4

// TODO(weberlo): should we add sugar for scalar types (e.g., `int32` => `Tensor[(), int32]`)?

/*
 * The identity function: returns its argument unchanged.
 */
def @id[A](%x: A) -> A {
  %x
}

/*
 * Function composition.
 *
 * Returns a closure that maps `x` to `f(g(x))`: `g` is applied first and `f`
 * is applied to its result. The return type is spelled out explicitly so the
 * signature reads like those of the other higher-order functions in this file.
 */
def @compose[A, B, C](%f: fn(B) -> C, %g: fn(A) -> B) -> fn(A) -> C {
  fn (%x: A) -> C {
    %f(%g(%x))
  }
}

/*
 * Swap the argument order of a two-argument function.
 *
 * The returned closure takes `(b, a)` and evaluates `f(a, b)`.
 */
def @flip[A, B, C](%f: fn(A, B) -> C) -> fn(B, A) -> C {
  fn(%second: B, %first: A) -> C {
    %f(%first, %second)
  }
}

/*
 * A LISP-style singly linked list ADT.
 *
 * `Nil` is the empty list. `Cons(x, l)` is the list whose first element is `x`
 * and whose remaining elements are the list `l`, i.e. `x` prepended to `l`.
 */
type List[A] {
  Cons(A, List[A]),
  Nil,
}

/*
 * Get the head (first element) of a list.
 *
 * The match is partial (`match?`): there is no clause for `Nil`, so the list
 * must have at least one element.
 */
def @hd[A](%xs: List[A]) -> A {
  match? (%xs) {
    Cons(%head, _) => %head,
  }
}

/*
 * Get the tail of a list, i.e. everything after its first element.
 *
 * The match is partial (`match?`): there is no clause for `Nil`, so the list
 * must have at least one element.
 */
def @tl[A](%xs: List[A]) -> List[A] {
  match? (%xs) {
    Cons(_, %tail) => %tail,
  }
}

/*
 * Get the `n`th element of a list, counting from zero.
 *
 * Drops one element per recursive step until `n` reaches 0, then returns the
 * head of what is left. `n` must satisfy `0 <= n < length(xs)`; for any other
 * `n` the recursion ends up applying `@hd` or `@tl` to `Nil`, for which they
 * have no clause.
 */
def @nth[A](%xs: List[A], %n: Tensor[(), int32]) -> A {
  if (%n == 0) {
    @hd(%xs)
  } else {
    let %tail = @tl(%xs);
    @nth(%tail, %n - 1)
  }
}

/*
 * Return the number of elements in a list.
 *
 * The empty list has length 0; every `Cons` cell adds 1 to the length of its
 * tail.
 */
def @length[A](%xs: List[A]) -> Tensor[(), int32] {
  match (%xs) {
    Nil => 0,
    Cons(_, %tail) => 1 + @length(%tail),
  }
}

/*
 * Replace the `n`th element (zero-based) of a list with `v` and return the
 * resulting list.
 *
 * The elements before position `n` are rebuilt and the tail after position `n`
 * is reused as is. `n` must satisfy `0 <= n < length(xs)`: the function relies
 * on `@hd` and `@tl`, which have no clause for `Nil`.
 */
def @update[A](%xs: List[A], %n: Tensor[(), int32], %v: A) -> List[A] {
  if (%n == 0) {
    let %tail = @tl(%xs);
    Cons(%v, %tail)
  } else {
    let %head = @hd(%xs);
    let %tail = @tl(%xs);
    Cons(%head, @update(%tail, %n - 1, %v))
  }
}

/*
 * Apply a function to every element of a list.
 *
 * `map(f, xs)` returns a new list of the same length as `xs` whose `i`th
 * member is `f` applied to the `i`th member of `xs`.
 */
def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] {
  match (%xs) {
    Nil => Nil,
    Cons(%head, %tail) => {
      let %mapped_head = %f(%head);
      Cons(%mapped_head, @map(%f, %tail))
    },
  }
}

/*
 * A left fold over a list.
 *
 * `foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil)))))`
 * evaluates to `f(...f(f(f(z, a1), a2), a3)...)`.
 *
 * The accumulator is the first argument of `f`. An empty list yields the
 * initial accumulator unchanged.
 */
def @foldl[A, B](%f: fn(A, B) -> A, %acc: A, %xs: List[B]) -> A {
  match (%xs) {
    Nil => %acc,
    Cons(%head, %tail) => {
      let %next_acc = %f(%acc, %head);
      @foldl(%f, %next_acc, %tail)
    },
  }
}

/*
 * A right fold over a list.
 *
 * `foldr(f, z, cons(a1, cons(a2, cons(..., cons(an, nil)))))`
 * evaluates to `f(a1, f(a2, f(..., f(an, z)))...)`.
 *
 * The accumulator is the second argument of `f`. An empty list yields the
 * initial accumulator unchanged.
 */
def @foldr[A, B](%f: fn(A, B) -> B, %acc: B, %xs: List[A]) -> B {
  match (%xs) {
    Nil => %acc,
    Cons(%head, %tail) => {
      let %folded_tail = @foldr(%f, %acc, %tail);
      %f(%head, %folded_tail)
    },
  }
}

/*
 * A right fold over a nonempty list, using the last element as the seed.
 *
 * `foldr1(f, cons(a1, cons(a2, cons(..., cons(an, nil)))))`
 * evaluates to `f(a1, f(a2, f(..., f(an-1, an)))...)`.
 *
 * The match is partial (`match?`): there is no clause for `Nil`, so the list
 * must have at least one element. The single-element clause is listed first
 * because the general `Cons` clause would also match a one-element list.
 */
def @foldr1[A](%f: fn(A, A) -> A, %xs: List[A]) -> A {
  match? (%xs) {
    Cons(%last, Nil) => %last,
    Cons(%head, %tail) => {
      let %folded_tail = @foldr1(%f, %tail);
      %f(%head, %folded_tail)
    },
  }
}

/*
 * Computes the sum of a list of int32 scalars.
 *
 * Implemented as a left fold with addition starting from 0, so the sum of an
 * empty list is 0. The return type is spelled out explicitly, matching the
 * other scalar-returning functions in this file.
 */
def @sum(%xs: List[Tensor[(), int32]]) -> Tensor[(), int32] {
  let %add_f = fn(%x: Tensor[(), int32], %y: Tensor[(), int32]) -> Tensor[(), int32] {
    %x + %y
  };
  @foldl(%add_f, 0, %xs)
}

/*
 * Concatenates two lists: the result holds the elements of `xs` followed by
 * the elements of `ys`.
 *
 * Implemented as a right fold over `xs` that prepends each element onto the
 * result, starting from `ys`.
 */
def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
  let %prepend = fn(%elem: A, %suffix: List[A]) -> List[A] {
    Cons(%elem, %suffix)
  };
  @foldr(%prepend, %ys, %xs)
  // TODO(weberlo): write it like below, once VM constructor compilation is fixed
  // @foldr(Cons, %ys, %xs)
}

/*
 * Filters a list, returning a sublist of only the values which satisfy the
 * given predicate. Kept elements stay in their original relative order.
 */
def @filter[A](%f: fn(A) -> Tensor[(), bool], %xs: List[A]) -> List[A] {
  match (%xs) {
    Nil => Nil,
    Cons(%head, %tail) => {
      // The predicate is evaluated before recursing into the tail.
      let %keep = %f(%head);
      if (%keep) {
        Cons(%head, @filter(%f, %tail))
      } else {
        @filter(%f, %tail)
      }
    },
  }
}

/*
 * Combines two lists into a list of tuples of their elements.
 *
 * The `i`th tuple pairs the `i`th element of `xs` with the `i`th element of
 * `ys`. Zipping stops as soon as either list runs out, so the result has the
 * length of the shorter list.
 */
def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
  match (%xs, %ys) {
    (Cons(%left, %left_tail), Cons(%right, %right_tail)) => {
      let %pair = (%left, %right);
      Cons(%pair, @zip(%left_tail, %right_tail))
    },
    _ => Nil,
  }
}

/*
 * Reverses a list.
 *
 * Implemented as a left fold that prepends each element onto an initially
 * empty result, so the first element of `xs` ends up last.
 */
def @rev[A](%xs: List[A]) -> List[A] {
  let %push_front = fn(%reversed: List[A], %elem: A) -> List[A] {
    Cons(%elem, %reversed)
  };
  @foldl(%push_front, Nil, %xs)
  // TODO(weberlo): write it like below, once VM constructor compilation is fixed
  // @foldl(@flip(Cons), Nil, %xs)
}

/*
 * An accumulative map, which is a fold that simultaneously updates an
 * accumulator value and a list of results.
 *
 * `f` takes the current accumulator and a list element and returns the new
 * accumulator together with an output value. This map proceeds through the
 * list from right to left (it is a `foldr`), prepending each output as it
 * goes, so the outputs appear in the same order as the elements of `xs`.
 *
 * Returns the final accumulator paired with the list of outputs.
 */
def @map_accumr[A, B, C](%f: fn(A, B) -> (A, C), %init: A, %xs: List[B]) -> (A, List[C]) {
  let %step = fn(%elem: B, %state: (A, List[C])) -> (A, List[C]) {
    let %result = %f(%state.0, %elem);
    let %outputs = Cons(%result.1, %state.1);
    (%result.0, %outputs)
  };
  @foldr(%step, (%init, Nil), %xs)
}

/*
 * An accumulative map, which is a fold that simultaneously updates an
 * accumulator value and a list of results.
 *
 * `f` takes the current accumulator and a list element and returns the new
 * accumulator together with an output value. This map proceeds through the
 * list from left to right (it is a `foldl`).
 *
 * Returns the final accumulator paired with the list of outputs. Each output
 * is prepended to the results as it is produced, so the output list is in
 * reverse order relative to `xs`: the output for the first element comes last.
 */
def @map_accuml[A, B, C](%f: fn(A, B) -> (A, C), %init: A, %xs: List[B]) -> (A, List[C]) {
  let %step = fn(%state: (A, List[C]), %elem: B) -> (A, List[C]) {
    let %result = %f(%state.0, %elem);
    let %outputs = Cons(%result.1, %state.1);
    (%result.0, %outputs)
  };
  @foldl(%step, (%init, Nil), %xs)
}

/*
 * An optional ADT: `Some(x)` wraps a value `x` of the parameter type, and
 * `None` represents the absence of a value.
 */
type Option[A] {
  Some(A),
  None,
}

/*
 * Builds up a list starting from a seed value.
 *
 * `f` returns an option containing a tuple `(new_seed, output)`. `f` will
 * continue to be called on the new seeds until it returns `None`. The outputs
 * are collected in the order they are produced: the output of the first call
 * is the head of the resulting list.
 */
def @unfoldr[A, B](%f: fn(A) -> Option[(A, B)], %seed: A) -> List[B] {
  match (%f(%seed)) {
    None => Nil,
    Some(%step) => {
      let %output = %step.1;
      Cons(%output, @unfoldr(%f, %step.0))
    },
  }
}

/*
 * Builds up a list starting from a seed value.
 *
 * `f` returns an option containing a tuple `(new_seed, output)`. `f` will
 * continue to be called on the new seeds until it returns `None`. The result
 * is the reverse of what `@unfoldr` builds: the output of the first call is
 * the last element of the resulting list.
 */
def @unfoldl[A, B](%f: fn(A) -> Option[(A, B)], %seed: A) -> List[B] {
  let %in_call_order = @unfoldr(%f, %seed);
  @rev(%in_call_order)
}

/*
 * A rose tree ADT. A tree can contain any type. It has only one constructor,
 * `Rose(x, l)`, where `x` is the content of that point of the tree and `l` is
 * a list of more trees of the same type (its subtrees). A leaf is thus
 * `Rose(x, Nil)`.
 */
type Tree[A] {
  Rose(A, List[Tree[A]]),
}

/*
 * Maps over a tree. The function is applied to the content of every node,
 * and the shape of the tree is preserved.
 */
def @tmap[A, B](%f: fn(A) -> B, %t: Tree[A]) -> Tree[B] {
  match (%t) {
    Rose(%content, %children) => {
      // Closure that maps `f` over one subtree; used to recurse via `@map`.
      let %map_subtree = fn(%child: Tree[A]) -> Tree[B] {
        @tmap(%f, %child)
      };
      let %new_content = %f(%content);
      Rose(%new_content, @map(%map_subtree, %children))
    },
  }
}

/*
 * Computes the size of a tree, i.e. its number of nodes: one for the root
 * plus the sizes of all of its subtrees.
 */
def @size[A](%t: Tree[A]) -> Tensor[(), int32] {
  match (%t) {
    Rose(_, %children) => {
      let %child_sizes = @map(@size, %children);
      1 + @sum(%child_sizes)
    },
  }
}

/*
 * Takes a function `f` and a number `n`; returns a closure that takes an
 * argument and applies `f` to it `n` times.
 *
 * For `n == 0` the result is the identity function `@id`; otherwise it is `f`
 * composed with the closure for `n - 1`. The only base case is `n == 0`, so
 * `n` is expected to be non-negative.
 */
def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> (fn(A) -> A) {
  if (%n == 0) {
    @id
  } else {
    let %apply_rest = @iterate(%f, %n - 1);
    @compose(%f, %apply_rest)
  }
}