import Base

import UInt

import Option

import Set

import MultiSet

import Pair

import Maps

// Singly-linked list. `empty` is the list with no elements (this file writes
// it as `[]` in patterns and terms); `node(x, xs)` puts `x` in front of `xs`.
union List<T> {
  empty
  node(T, List<T>)
}

// Number of elements in a list: zero for `[]`, one more than the rest for a node.
recursive length<E>(List<E>) -> UInt{
  length([]) = 0
  length(node(hd, rest)) = 1 + length(rest)
}

// List concatenation. Recurses on the left operand only, copying its nodes
// in front of the right operand, which is shared unchanged.
recursive operator ++<E>(List<E>,List<E>) -> List<E>{
  operator ++([], right) = right
  operator ++(node(hd, rest), right) = node(hd, rest ++ right)
}

// Reverses a list. Specification-style definition: reverse the tail, then
// append the head as a singleton list (uses `++`, not an accumulator).
recursive reverse<E>(List<E>) -> List<E>{
  reverse([]) = []
  reverse(node(hd, rest)) = reverse(rest) ++ node(hd, [])
}

// The set of elements occurring in a list (duplicates and order are lost).
// The node case unions the singleton set of the head with the set of the
// tail; the union operator `∪` (from the Set library) had been lost here,
// leaving `single(x)  set_of(xs)`, which is not a well-formed term.
recursive set_of<T>(List<T>) -> Set<T>{
  set_of([]) = empty_set()
  set_of(node(x, xs)) = single(x) ∪ set_of(xs)
}

// The multiset of elements in a list: order is forgotten, multiplicities kept.
// The node case combines the one-element multiset of the head with the
// multiset of the tail via multiset sum `⨄` (from the MultiSet library);
// the operator had been dropped, leaving an ill-formed juxtaposition.
recursive mset_of<T>(List<T>) -> MultiSet<T>{
  mset_of([]) = m_empty()
  mset_of(node(x, xs)) = m_one(x) ⨄ mset_of(xs)
}

// Applies `fn_` to every element, preserving order and length.
recursive map<T,U>(List<T>,(fn T -> U)) -> List<U>{
  map([], fn_) = []
  map(node(hd, rest), fn_) = node(fn_(hd), map(rest, fn_))
}

// Right fold: foldr([a1, ..., ak], base, step) = step(a1, step(a2, ... step(ak, base))).
// The empty list yields `base` unchanged.
recursive foldr<T,U>(List<T>,U,(fn T, U -> U)) -> U{
  foldr([], base, step) = base
  foldr(node(hd, rest), base, step) = step(hd, foldr(rest, base, step))
}

// Left fold: foldl([a1, ..., ak], acc, step) = step(... step(step(acc, a1), a2) ..., ak).
// The accumulator is threaded through the recursion; `[]` returns it as-is.
recursive foldl<T,U>(List<T>,U,(fn U, T -> U)) -> U{
  foldl([], acc, step) = acc
  foldl(node(hd, rest), acc, step) = foldl(rest, step(acc, hd), step)
}

// Pairs up elements position by position. The result is as long as the
// shorter input: recursion is on the left list, and the right list is
// inspected with a `switch`, stopping as soon as either side runs out.
recursive zip<T,U>(List<T>,List<U>) -> List<Pair<T,U>>{
  zip([], other) = []
  zip(node(a, rest_a), other) =
    switch other {
      case [] { [] }
      case node(b, rest_b) { node(pair(a, b), zip(rest_a, rest_b)) }
    }
}

// Keeps exactly the elements satisfying the predicate `keep`, in their original order.
recursive filter<E>(List<E>,(fn E -> bool)) -> List<E>{
  filter([], keep) = []
  filter(node(hd, rest), keep) =
    if keep(hd) then node(hd, filter(rest, keep))
    else filter(rest, keep)
}

// Deletes only the FIRST occurrence of `target`; later occurrences are kept.
// Returns the list unchanged if `target` does not occur (contrast `remove_all`).
recursive remove<T>(List<T>,T) -> List<T>{
  remove([], target) = []
  remove(node(hd, rest), target) =
    if hd = target then rest
    else node(hd, remove(rest, target))
}

// Deletes EVERY occurrence of `target`, keeping the other elements in order.
recursive remove_all<T>(List<T>,T) -> List<T>{
  remove_all([], target) = []
  remove_all(node(hd, rest), target) =
    if hd = target then remove_all(rest, target)
    else node(hd, remove_all(rest, target))
}

// Zero-based indexing: `just` the element at position `i`, or `none` when
// `i` is at or past the end of the list.
// The index is decremented with UInt monus `∸` (truncated subtraction);
// the operator had been lost, leaving the ill-formed `i  1`. It is only
// reached when i ≠ 0, so truncation never actually occurs here.
recursive get<T>(List<T>,UInt) -> Option<T>{
  get([], i) = none
  get(node(x, xs), i) =
    if i = 0 then
      just(x)
    else
      get(xs, i ∸ 1)
}

// The first `n` elements of the list, or the whole list if it has fewer
// than `n` elements.
// The count is decremented with UInt monus `∸`; the operator had been
// dropped, leaving `n  1`. Only reached when n ≠ 0.
recursive take<T>(List<T>,UInt) -> List<T>{
  take([], n) = []
  take(node(x, xs), n) =
    if n = 0 then
      []
    else
      node(x, take(xs, n ∸ 1))
}

// The list with its first `n` elements removed; `[]` if the list has at
// most `n` elements.
// The count is decremented with UInt monus `∸`; the operator had been
// dropped, leaving `n  1`. Only reached when n ≠ 0.
recursive drop<T>(List<T>,UInt) -> List<T>{
  drop([], n) = []
  drop(node(x, xs'), n) =
    if n = 0 then
      node(x, xs')
    else
      drop(xs', n ∸ 1)
}

// First element of `ls` wrapped in `just`, or `none` when `ls` is empty.
// The explicit `@none<T>` instantiation fixes the result type in the empty case.
fun head<T>(ls:List<T>) {
  switch ls {
    case [] { @none<T> }
    case node(first, rest) { just(first) }
  }
}

// Everything after the first element of `ls`; the tail of `[]` is `[]`.
// The explicit `@[]<T>` instantiation fixes the result type in the empty case.
fun tail<T>(ls:List<T>) {
  switch ls {
    case [] { @[]<T> }
    case node(first, rest) { rest }
  }
}

// Length of a concatenation is the sum of the lengths.
length_append: (all U:type, xs:List<U>, ys:List<U>. length(xs ++ ys) = length(xs) + length(ys))

// A list of length zero is the empty list.
length_zero_empty: (all T:type, xs:List<T>. (if length(xs) = 0 then xs = []))

// Concatenation is associative.
append_assoc: (all U:type, xs:List<U>, ys:List<U>, zs:List<U>. (xs ++ ys) ++ zs = xs ++ (ys ++ zs))

// `[]` is a right identity for `++` (left identity holds by definition of `++`).
append_empty: (all U:type, xs:List<U>. xs ++ [] = xs)

// Reversal preserves length.
length_reverse: (all U:type, xs:List<U>. length(reverse(xs)) = length(xs))

// Mapping preserves length. NOTE(review): stated only for f : T -> T, not
// for a general f : T -> U.
length_map: (all T:type, f:(fn T -> T), xs:List<T>. length(map(xs, f)) = length(xs))

// Mapping a function that is pointwise the identity leaves every list unchanged.
map_id: (all T:type, f:(fn T -> T). (if (all x:T. f(x) = x) then (all xs:List<T>. map(xs, f) = xs)))

// `map` distributes over `++` (stated for f : T -> T).
map_append: (all T:type, f:(fn T -> T), ys:List<T>, xs:List<T>. map(xs ++ ys, f) = map(xs, f) ++ map(ys, f))

// Mapping f then g equals mapping their composition once. The composition
// operator `∘` had been lost in extraction (`g  f`).
map_compose: (all T:type, U:type, V:type, f:(fn T -> U), g:(fn U -> V), ls:List<T>. map(map(ls, f), g) = map(ls, g ∘ f))

// Zipping with an empty right list yields `[]` (the left list is irrelevant).
zip_empty_right: (all T:type, U:type, xs:List<T>. zip(xs, []:List<U>) = [])

// Mapping each side before zipping equals zipping first and then mapping
// `pairfun(f, g)` over the pairs. NOTE(review): presumes `pairfun(f, g)`
// (from Pair) applies f to the first and g to the second component — confirm.
zip_map: (all T1:type, T2:type, U1:type, U2:type, f:(fn T1 -> T2), g:(fn U1 -> U2), xs:List<T1>, ys:List<U1>. zip(map(xs, f), map(ys, g)) = map(zip(xs, ys), pairfun(f, g)))

// Only the empty list has an empty element set.
set_of_empty: (all T:type, xs:List<T>. (if set_of(xs) = empty_set() then xs = []))

// The elements of a concatenation are the union of the two element sets.
// The union operator `∪` had been lost in extraction.
set_of_append: (all T:type, xs:List<T>, ys:List<T>. set_of(xs ++ ys) = set_of(xs) ∪ set_of(ys))

// Only the empty list has the empty multiset of elements.
mset_of_empty: (all T:type, xs:List<T>. (if mset_of(xs) = @m_empty<T>() then xs = []))

// Forgetting multiplicities of `mset_of(xs)` gives `set_of(xs)`.
// NOTE(review): assumes `set_of_mset` (from MultiSet) is the multiset-to-set
// projection — confirm.
som_mset_eq_set: (all T:type, xs:List<T>. set_of_mset(mset_of(xs)) = set_of(xs))

// After `remove_all(xs, y)`, `y` is no longer an element of the list.
// The membership operator `∈` had been lost in extraction.
not_set_of_remove_all: (all W:type, xs:List<W>, y:W. not (y ∈ set_of(remove_all(xs, y))))

// Taking zero elements gives the empty list.
take_zero: (all E:type, xs:List<E>. take(xs, 0) = [])

// Taking exactly `length(xs)` elements from `xs ++ ys` recovers `xs`.
take_append: (all E:type, xs:List<E>, ys:List<E>. take(xs ++ ys, length(xs)) = xs)

// When n ≤ length(xs), dropping n elements shortens the list by exactly n.
// The `≤` in the hypothesis had been lost in extraction.
length_drop: (all E:type, xs:List<E>, n:UInt. (if n ≤ length(xs) then length(drop(xs, n)) + n = length(xs)))

// NOTE(review): identical statement to `length_drop`; kept under this name
// so existing references keep working.
len_drop: (all E:type, xs:List<E>, n:UInt. (if n ≤ length(xs) then length(drop(xs, n)) + n = length(xs)))

// Dropping more elements than the list has leaves nothing.
length_drop_zero: (all E:type, xs:List<E>, n:UInt. (if length(xs) < n then length(drop(xs, n)) = 0))

// Dropping zero elements is the identity.
drop_zero_identity: (all E:type, xs:List<E>. drop(xs, 0) = xs)

// Dropping exactly `length(xs)` elements from `xs ++ ys` leaves `ys`.
drop_append: (all E:type, xs:List<E>, ys:List<E>. drop(xs ++ ys, length(xs)) = ys)

// Indexing after a drop is indexing with the offset added.
// NOTE(review): the binder `d:T` does not appear in the body; it is left
// in place because removing it would change the theorem's instantiation arity.
get_drop: (all T:type, xs:List<T>, n:UInt, i:UInt, d:T. get(drop(xs, n), i) = get(xs, n + i))

// Indices inside the left operand of `++` see only the left list.
get_append_front: (all T:type, xs:List<T>, ys:List<T>, i:UInt. (if i < length(xs) then get(xs ++ ys, i) = get(xs, i)))

// Indices past the left operand are shifted into the right list.
get_append_back: (all T:type, xs:List<T>, ys:List<T>, i:UInt. get(xs ++ ys, length(xs) + i) = get(ys, i))

// Any index at or beyond the length is out of range, so `get` returns `none`.
// The `≤` in the hypothesis had been lost in extraction.
get_none: (all T:type, xs:List<T>, i:UInt. (if length(xs) ≤ i then get(xs, i) = none))

// Lists with equal element multisets also have equal element sets.
mset_equal_implies_set_equal: (all T:type, xs:List<T>, ys:List<T>. (if mset_of(xs) = mset_of(ys) then set_of(xs) = set_of(ys)))

// For a non-empty left list, appending does not change the head.
head_append: (all E:type, L:List<E>, R:List<E>. (if 0 < length(L) then head(L ++ R) = head(L)))

// For a non-empty left list, the tail of `L ++ R` is `tail(L) ++ R`.
tail_append: (all E:type, L:List<E>, R:List<E>. (if 0 < length(L) then tail(L ++ R) = tail(L) ++ R))