/*
  Singly-linked lists over an element type T, together with the
  standard operations (length, append, reverse, map, fold, zip,
  filter, get, take, drop, ...) and theorems relating them.

  The two constructors are `empty` (also written `[]`) and
  `node(x, xs)`. The literal syntax `[1, 2, 3]` desugars to
  `node(1, node(2, node(3, empty)))`.
*/

import Base
import UInt
import Option
import Set
import MultiSet
import Pair
import Maps

// A list is either empty or a node carrying a head element and a tail list.
// `empty` is the list with no elements; `node(x, xs)` places the element
// `x` in front of the list `xs`. The operations below are defined by case
// analysis on these two constructors.
union List<T> {
  empty
  node(T, List<T>)
}

// Count the elements of a list: the empty list has length 0 and each
// node contributes one more than the length of its tail.
recursive length<E>(List<E>) -> UInt {
  length(empty) = 0
  length(node(x, xs)) = 1 + length(xs)
}

// Membership test: true when some element of the list equals `target`.
// The empty list contains nothing; otherwise compare against the head
// and, failing that, keep searching the tail.
recursive contains<E>(List<E>, E) -> bool {
  contains(empty, target) = false
  contains(node(x, xs), target) = (x = target) or contains(xs, target)
}

// Append: `lhs ++ rhs` lists the elements of `lhs` and then those of
// `rhs`. Defined by recursion on the left operand, so `empty ++ rhs`
// reduces directly to `rhs`.
recursive operator ++ <E>(List<E>, List<E>) -> List<E> {
  operator ++(empty, rhs) = rhs
  operator ++(node(x, lhs), rhs) = node(x, lhs ++ rhs)
}

// Produce the list with its elements in the opposite order: reverse the
// tail, then append the old head as a one-element list at the end.
recursive reverse<E>(List<E>) -> List<E> {
  reverse(empty) = empty
  reverse(node(x, xs)) = reverse(xs) ++ node(x, empty)
}

// Convert the list to a set: the set of all elements that appear,
// discarding both order and how many times each element occurs.
// The empty list maps to the empty set; a node contributes the
// singleton set of its head, unioned with the set of its tail.
recursive set_of<T>(List<T>) -> Set<T> {
  set_of(empty) = ∅
  set_of(node(x, xs)) = single(x) ∪ set_of(xs)
}

// Convert the list to a multiset: discards order but preserves the
// number of occurrences of each element.
// The empty list maps to `m_empty()`; a node contributes the
// one-element multiset `m_one(x)`, summed (`⨄`) with the multiset of
// its tail.
recursive mset_of<T>(List<T>) -> MultiSet<T> {
  mset_of(empty) = m_empty()
  mset_of(node(x, xs)) = m_one(x) ⨄ mset_of(xs)
}

// Transform a list element-wise: the result holds `g(a)` for each
// element `a` of the input, in the same order.
recursive map<T,U>(List<T>, fn T->U) -> List<U> {
  map(empty, g) = empty
  map(node(a, rest), g) = node(g(a), map(rest, g))
}

// Right fold. Starting from the seed `init`, combine elements from the
// right end of the list towards the front with `step`:
//   foldr([x0, x1, ..., xk], init, step)
//     = step(x0, step(x1, ... step(xk, init) ...)).
recursive foldr<T,U>(List<T>, U, fn T,U->U) -> U {
  foldr(empty, init, step) = init
  foldr(node(x, xs), init, step) = step(x, foldr(xs, init, step))
}

// Left fold. Thread an accumulator through the list from front to back,
// updating it with `step` at each element:
//   foldl([x0, x1, ..., xk], acc, step)
//     = step(... step(step(acc, x0), x1) ..., xk).
recursive foldl<T,U>(List<T>, U, fn U,T->U) -> U {
  foldl(empty, acc, step) = acc
  foldl(node(x, xs), acc, step) = foldl(xs, step(acc, x), step)
}

/*
recursive down_from(Nat) -> List {
  down_from(zero) = empty
  down_from(suc(n)) = node(n, down_from(n))
}

define up_to : fn Nat -> List = λ n { reverse(down_from(n)) }
*/
/* interval(num, begin) returns the list of numbers in the half-open
   interval: [begin, begin+num).
 */
 /*
recursive interval(Nat, Nat) -> List {
  interval(zero, begin) = empty
  interval(suc(num), begin) = node(begin, interval(num, suc(begin)))
}
*/

/* range(begin, end) returns the list of numbers in the half-open
   interval: [begin, end).
 */
 /*
define range : fn Nat,Nat -> List = λ b, e { interval(e ∸ b, b) }
*/

// Combine two lists position by position into a list of pairs. Zipping
// stops as soon as either list runs out, so the result is as long as
// the shorter input and any surplus elements are dropped.
recursive zip<T,U>(List<T>, List<U>) -> List< Pair<T, U> > {
  zip(empty, rs) = empty
  zip(node(l, ls), rs) =
    switch rs {
      case empty { empty }
      case node(r, rs') { node(pair(l, r), zip(ls, rs')) }
    }
}

// Split a list of pairs into a pair of lists: the first list collects
// the first components, the second list the second components, both in
// the original order.
recursive unzip<T,U>(List< Pair<T, U> >) -> Pair< List<T>, List<U> > {
  unzip(empty) = pair(@empty<T>, @empty<U>)
  unzip(node(pr, rest)) =
    // Unzip the tail once and reuse both halves of the result.
    define halves = unzip(rest);
    pair(node(first(pr), first(halves)), node(second(pr), second(halves)))
}

// Select the elements for which `keep` holds, preserving their order;
// elements that fail the predicate are skipped.
recursive filter<E>(List<E>, fn (E)->bool) -> List<E> {
  filter(empty, keep) = empty
  filter(node(e, rest), keep) =
    if keep(e) then node(e, filter(rest, keep))
    else filter(rest, keep)
}

// Delete only the earliest element equal to `target`; later occurrences
// are kept. A list that does not contain `target` comes back unchanged.
recursive remove<T>(List<T>, T) -> List<T> {
  remove(empty, target) = empty
  remove(node(e, rest), target) =
    if e = target then
      rest
    else
      node(e, remove(rest, target))
}

// Delete every element equal to `target`, keeping all other elements in
// their original order.
recursive remove_all<T>(List<T>, T) -> List<T> {
  remove_all(empty, target) = empty
  remove_all(node(e, rest), target) =
    if e = target then
      remove_all(rest, target)
    else
      node(e, remove_all(rest, target))
}

// Lookup the element at position `i` (0-indexed). Returns
// `just(x)` if the index is in range, or `none` if it is past
// the end of the list.
// Index 0 selects the head; any other index is decremented with the
// truncated subtraction `∸` and looked up in the tail.
recursive get<T>(List<T>, UInt) -> Option<T> {
  get(empty, i) = none
  get(node(x, xs), i) =
    if i = 0 then
       just(x)
    else
       get(xs, i ∸ 1)
}

// The prefix of length `n`. If the list has at most `n` elements,
// the whole list is returned.
// Each step keeps the head and asks for `n ∸ 1` more elements from the
// tail, stopping when `n` reaches 0 or the list runs out.
recursive take<T>(List<T>, UInt) -> List<T> {
  take(empty, n) = empty
  take(node(x, xs), n) =
    if n = 0 then empty
    else node(x, take(xs, n ∸ 1))
}

// Drop the first `n` elements and return the rest. If the list
// has at most `n` elements, the result is empty.
// Each step discards the head and drops `n ∸ 1` more from the tail;
// when `n` is 0 the remaining list is returned as is.
recursive drop<T>(List<T>, UInt) -> List<T> {
  drop(empty, n) = empty
  drop(node(x, xs'), n) =
    if n = 0 then node(x, xs')
    else drop(xs', n ∸ 1)
}

// Optional first element: `none` for the empty list, otherwise `just`
// of the element stored in the outermost node.
fun head<T>(ls : List<T>) {
  switch ls {
    case empty { @none<T> }
    case node(first_elt, rest) { just(first_elt) }
  }
}

// Everything but the first element. The empty list has no first element
// to discard, so its tail is defined to be the empty list again.
fun tail<T>(ls : List<T>) {
  switch ls {
    case empty { @empty<T> }
    case node(first_elt, rest) { rest }
  }
}

// Optional final element: `none` for the empty list, otherwise `just`
// of the element in the innermost node.
//   last([1,2,3]) = just(3),    last([]) = none.
recursive last<T>(List<T>) -> Option<T> {
  last(empty) = @none<T>
  last(node(e, rest)) =
    // A node with an empty tail is the final node; otherwise recurse.
    switch rest {
      case empty { just(e) }
      case node(e2, more) { last(rest) }
    }
}

// The list with its final element removed. Both the empty list and a
// one-element list yield the empty list.
//   init([1,2,3]) = [1,2],     init([x]) = [],     init([]) = [].
recursive init<T>(List<T>) -> List<T> {
  init(empty) = @empty<T>
  init(node(e, rest)) =
    // Drop `e` when it is the final element; otherwise keep it and recurse.
    switch rest {
      case empty { @empty<T> }
      case node(e2, more) { node(e, init(rest)) }
    }
}

// Every suffix of the list, longest first: the list itself, then the
// result of dropping one element, and so on down to `empty`.
//   tails([1,2,3]) = [[1,2,3], [2,3], [3], []].
recursive tails<T>(List<T>) -> List< List<T> > {
  tails(empty) = node(empty, empty)
  tails(node(e, rest)) = node(node(e, rest), tails(rest))
}

// Every prefix of the list, shortest first: `empty`, then the first
// element alone, and so on up to the whole list. The prefixes of a node
// are `empty` followed by each prefix of the tail with the head put
// back in front.
//   inits([1,2,3]) = [[], [1], [1,2], [1,2,3]].
recursive inits<T>(List<T>) -> List< List<T> > {
  inits(empty) = node(empty, empty)
  inits(node(e, rest)) =
    node(empty, map(inits(rest), λys:List<T>{node(e, ys)}))
}

// Place `sep` between each pair of neighbouring elements. No separator
// is added before the first element or after the last one.
//   intersperse([1,2,3,4], 0) = [1, 0, 2, 0, 3, 0, 4]
//   intersperse([x], sep)     = [x]
//   intersperse([], sep)      = []
recursive intersperse<T>(List<T>, T) -> List<T> {
  intersperse(empty, sep) = @empty<T>
  intersperse(node(e, rest), sep) =
    // Emit a separator after `e` only when more elements follow.
    switch rest {
      case empty { node(e, empty) }
      case node(e2, more) { node(e, node(sep, intersperse(rest, sep))) }
    }
}

/********************** Theorems *********************************/

// Length distributes over append: |xs ++ ys| = |xs| + |ys|.
// Proof: induction on `xs`, with `ys` arbitrary in each case.
theorem length_append: all U :type, xs :List<U>, ys :List<U>.
  length(xs ++ ys) = length(xs) + length(ys)
proof
  arbitrary U :type
  induction List<U>
  // Base case: `[] ++ ys` unfolds to `ys` and `length([])` to 0.
  case [] {
    arbitrary ys:List<U>
    conclude length(@[]<U> ++ ys) = length(@[]<U>) + length(ys)  by {
      expand operator++ | length.
    }
  }
  // Inductive case: unfold one step of `++` and `length`, rewrite with
  // the induction hypothesis at `ys`, then fold `length` back up.
  case node(n, xs') suppose IH {
    arbitrary ys :List<U>
    equations
      length(node(n,xs') ++ ys)
          = 1 + length(xs' ++ ys)              by expand operator++ | length.
      ... = 1 + (length(xs') + length(ys))     by replace IH[ys].
      ... = #length(node(n,xs'))# + length(ys) by expand length.
  }
end

// A list whose length is zero must be the empty list.
// Proof: case analysis on `xs`; no induction is needed.
theorem length_zero_empty: all T:type, xs:List<T>.
  if length(xs) = 0
  then xs = empty
proof
  arbitrary T:type
  arbitrary xs:List<T>
  switch xs {
    // `xs` is `empty`, so the conclusion holds immediately.
    case empty { . }
    // Expanding `length` on a node exposes `1 + length(xs')` in the
    // premise; presumably the checker recognises that this cannot be 0.
    case node(x, xs') { expand length. }
  }
end

// Append is associative: the order of grouping does not matter.
// Proof: induction on the leftmost list `xs`.
theorem append_assoc: all U :type, xs :List<U>, ys :List<U>, zs:List<U>.
  (xs ++ ys) ++ zs = xs ++ (ys ++ zs)
proof
  arbitrary U :type
  induction List<U>
  // Base case: both sides reduce once `empty ++ _` is unfolded.
  case empty {
    expand operator++.
  }
  // Inductive case: push `node(n, _)` outward through both appends,
  // apply the induction hypothesis underneath it, then fold back.
  case node(n, xs') suppose IH {
    arbitrary ys :List<U>, zs :List<U>
    equations
      (node(n,xs') ++ ys) ++ zs
          = node(n, (xs' ++ ys) ++ zs)     by expand operator++.
      ... = node(n, xs' ++ (ys ++ zs))     by replace IH[ys, zs].
      ... = # node(n,xs') ++ (ys ++ zs) #  by expand operator++.
  }
end

// Register `++` on `List<T>` as an associative operator. Presumably this
// relies on `append_assoc` and lets the checker regroup chains of
// appends automatically — confirm against the Deduce reference.
associative operator++ <T> in List<T>

// The empty list is a right identity for append. (The matching
// left identity, `[] ++ xs = xs`, follows directly from the
// definition of `++`.)
// Proof: induction on `xs`; `++` recurses on its left argument, so the
// right-identity law needs the induction hypothesis for the tail.
theorem append_empty: all U :type, xs :List<U>.
  xs ++ [] = xs
proof
  arbitrary U:type
  induction List<U>
  case empty {
    conclude (empty : List<U>) ++ empty = empty  by expand operator++.
  }
  case node(n, xs') suppose IH: xs' ++ empty = xs' {
    equations
      node(n,xs') ++ empty
          = node(n, xs' ++ empty)    by expand operator++.
      ... = node(n,xs')              by replace IH.
  }
end

// Reversing a list preserves its length.
// Proof: induction on `xs`, using `length_append` to split the length
// of `reverse(xs') ++ [n]`.
theorem length_reverse: all U :type, xs :List<U>.
  length(reverse(xs)) = length(xs)
proof
  arbitrary U : type
  induction List<U>
  case empty {
    conclude length(reverse(@empty<U>)) = length(@empty<U>)
        by expand reverse.
  }
  case node(n, xs') suppose IH {
    // `uint_add_commute` is defined outside this file; it is used here
    // to swap `length(xs') + 1` into `1 + length(xs')`.
    equations
      length(reverse(node(n,xs')))
          = length(reverse(xs') ++ node(n,empty)) by expand reverse.
      ... = length(reverse(xs')) + length(node(n,empty))
                    by replace length_append<U>[reverse(xs')][node(n,empty)].
      ... = length(xs') + length(node(n,empty))   by replace IH.
      ... = length(xs') + 1                       by expand 2*length.
      ... = 1 + length(xs')                       by uint_add_commute
      ... = # length(node(n,xs')) #               by expand length.
  }
end

// Reversing an append swaps and reverses the two parts.
// Proof: induction on `xs`. The base case needs `append_empty` because
// `reverse([])` ends up on the right of `++`; the inductive case needs
// `append_assoc` to regroup the trailing `[n]`.
theorem reverse_append: all U :type, xs :List<U>, ys :List<U>.
  reverse(xs ++ ys) = reverse(ys) ++ reverse(xs)
proof
  arbitrary U :type
  induction List<U>
  case [] {
    arbitrary ys :List<U>
    equations
          reverse(@[]<U> ++ ys)
        = reverse(ys)                     by expand operator++.
    ... = # reverse(ys) ++ [] #           by replace append_empty<U>[reverse(ys)].
    ... = # reverse(ys) ++ reverse([]) #  by expand reverse.
  }
  case node(n, xs') suppose IH {
    arbitrary ys :List<U>
    equations
          reverse(node(n,xs') ++ ys)
        = reverse(node(n, xs' ++ ys))                by expand operator++.
    ... = reverse(xs' ++ ys) ++ [n]                  by expand reverse.
    ... = (reverse(ys) ++ reverse(xs')) ++ [n]       by replace IH[ys].
    ... = reverse(ys) ++ (reverse(xs') ++ [n])       by append_assoc<U>[reverse(ys)][reverse(xs'), [n]]
    ... = # reverse(ys) ++ reverse(node(n,xs')) #    by expand reverse.
  }
end

// `map` preserves length: the output has one entry per input.
// NOTE(review): the statement is restricted to `f : fn T->T` although
// `map` itself accepts `fn T->U`. Other proofs in this file instantiate
// this theorem with a single type argument, so widening the statement
// would require updating those call sites.
// Proof: induction on `xs`.
theorem length_map: all T:type, f:fn T->T, xs:List<T>.
  length(map(xs, f)) = length(xs)
proof
  arbitrary T:type
  arbitrary f:fn T->T
  induction List<T>
  case empty {
    expand map.
  }
  case node(x, ls') suppose IH {
    equations
      length(map(node(x,ls'),f))
          = 1 + length(map(ls',f))  by expand map | length.
      ... = 1 + length(ls')         by replace IH.
      ... =# length(node(x,ls')) #  by expand length.
  }
end

// Mapping a function that is pointwise the identity leaves the
// list unchanged.
// Proof: assume the pointwise hypothesis `fxx`, then induct on `xs`,
// using `fxx` on the head and the induction hypothesis on the tail.
theorem map_id: all T:type, f:fn T->T. if (all x:T. f(x) = x) then
   all xs:List<T>. map(xs, f) = xs
proof
  arbitrary T:type
  arbitrary f:fn T->T
  suppose fxx: (all x:T. f(x) = x)
  induction List<T>
  case empty {
    conclude map(@empty<T>, f) = empty      by expand map.
  }
  case node(x, ls) suppose IH {
    equations
      map(node(x,ls),f)
          = node(f(x), map(ls, f)) by expand map.
      ... = node(x, map(ls, f))    by replace fxx[x].
      ... = node(x,ls)             by replace IH.
  }
end

// `map` distributes over append.
// Note the quantifier order: `ys` is bound before `xs`, so `ys` is fixed
// first and the induction is on `xs` (the left operand of `++`).
theorem map_append: all T:type, U:type, f: fn T->U, ys:List<T>, xs:List<T>.
  map(xs ++ ys, f) = map(xs,f) ++ map(ys, f)
proof
  arbitrary T:type, U:type
  arbitrary f:fn T->U, ys:List<T>
  induction List<T>
  // Base case: remove `empty ++ _` on the left, then reintroduce it on
  // the right as `map(empty, f) ++ _`.
  case empty {
    equations
      map(@empty<T> ++ ys, f)
          = map(ys, f)                          by expand operator++.
      ... =# @empty<U> ++ map(ys, f) #          by expand operator++.
      ... =# map(@empty<T>, f) ++ map(ys, f) #  by expand map.
  }
  case node(x, xs')
    suppose IH: map(xs' ++ ys, f) = map(xs',f) ++ map(ys, f)
  {
    equations
      map((node(x,xs') ++ ys), f)
          = node(f(x), map(xs' ++ ys, f))         by expand operator++ | map.
      ... = node(f(x), map(xs',f) ++ map(ys,f))   by replace IH.
      ... = # map(node(x,xs'),f) ++ map(ys,f) #   by expand map | operator++.
  }
end

// Map fusion: mapping `f` and then `g` is the same as mapping
// the composition `g ∘ f` once.
// Proof: induction on `ls`. In the inductive case the head `g(f(x))` is
// rewritten to `(g ∘ f)(x)` by unfolding `∘`, and the tail follows from
// the induction hypothesis.
theorem map_compose: all T:type, U:type, V:type, f:fn T->U, g:fn U->V, ls :List<T>.
  map(map(ls, f), g) = map(ls, g ∘ f)
proof
  arbitrary T:type, U:type, V:type
  arbitrary f:fn T->U, g:fn U->V
  induction List<T>
  case empty { expand map. }
  case node(x, ls) suppose IH {
    equations
      map(map(node(x,ls),f),g)
        = node(g(f(x)), map(map(ls, f), g))     by expand map.
    ... = node(#g ∘ f#(x), map(map(ls, f), g))  by expand operator ∘.
    ... = node((g ∘ f)(x), map(ls, g ∘ f))      by replace IH.
    ... = #map(node(x,ls), g .o. f)#            by expand map.
  }
end

// Zipping any list with an empty list on the right gives an empty
// list. (The matching empty-on-the-left case follows directly from
// the definition of `zip`.)
// Proof: case split on `xs` via `induction` (the induction hypothesis
// is not used); each case is closed by `evaluate`.
theorem zip_empty_right: all T:type, U:type, xs:List<T>.
    zip(xs, empty : List<U>) = empty
proof
  arbitrary T:type, U:type
  induction List<T>
  case empty { evaluate }
  case node(x, xs') { evaluate }
end

// `zip` commutes with `map`: zipping after mapping each side equals
// mapping `pairfun(f, g)` over the zipped list.
// Proof: induction on `xs`, then a case split on `ys` in the inductive
// case because `zip` inspects its second argument with a `switch`.
theorem zip_map: all T1:type, T2:type, U1:type, U2:type,
  f : fn T1 -> T2, g : fn U1 -> U2, xs:List<T1>, ys:List<U1>.
  zip(map(xs, f), map(ys, g)) = map(zip(xs,ys), pairfun(f,g))
proof
  arbitrary T1:type, T2:type, U1:type, U2:type
  arbitrary f:fn T1 -> T2, g:fn U1 -> U2
  induction List<T1>
  case empty {
    arbitrary ys:List<U1>
    conclude zip(map(@empty<T1>, f), map(ys,g))
           = map(zip(@empty<T1>,ys), pairfun(f,g))
        by expand map | zip.
  }
  case node(x, xs') suppose IH {
    arbitrary ys:List<U1>
    switch ys {
      // Right list exhausted: both sides evaluate to the empty list.
      case empty suppose EQ { evaluate }
      // Both lists non-empty: peel one pair off, use the induction
      // hypothesis at `ys'`, then unfold `pairfun` so that the final
      // step can be checked by evaluation.
      case node(y, ys') {
        equations
            zip(map(node(x,xs'),f),map(node(y,ys'),g))
              = node(pair(f(x), g(y)), zip(map(xs',f), map(ys',g)))       by expand map | zip.
          ... = node(pair(f(x), g(y)), map(zip(xs',ys'), pairfun(f,g)))   by replace IH[ys'].
          ... = node(pair(f(x), g(y)), map(zip(xs', ys'), λp{pair(f(first(p)), g(second(p)))}))
                      by expand pairfun.
          ... = map(zip(node(x,xs'),node(y,ys')), pairfun(f,g))   by evaluate
      }
    }
  }
end

/* update to remove use of all_elements
theorem filter_all: all T:type. all P:fn (T)->bool. all xs:List. 
  if all_elements(xs, P) then filter(xs, P) = xs
proof
  arbitrary T:type
  arbitrary P:fn (T)->bool
  induction List
  case empty {
    suppose cond: all_elements(@empty, P)
    conclude filter(@empty, P) = empty by expand filter
  }
  case node(x, xs') suppose IH: if all_elements(xs',P) then filter(xs',P) = xs' {
    suppose Pxs: all_elements(node(x,xs'),P)
    have Px: P(x) by expand all_elements in Pxs
    suffices node(x,filter(xs',P)) = node(x,xs')
        by expand filter and replace (apply eq_true to Px)
    have Pxs': all_elements(xs',P) by expand all_elements in Pxs
    have IH': filter(xs',P) = xs' by apply IH to Pxs'
    replace IH'
  }
end
*/

// A list whose `set_of` is empty must itself be empty.
// Proof: case analysis on `xs`. For a node, expanding `set_of` in the
// premise gives `single(x) ∪ set_of(xs') = ∅`, from which the proof
// concludes `false`.
theorem set_of_empty: all T:type, xs:List<T>.
  if set_of(xs) = ∅
  then xs = empty
proof
  arbitrary T:type
  arbitrary xs:List<T>
  switch xs {
    case empty {
      .
    }
    case node(x, xs') {
      suppose prem
      conclude false by expand set_of in prem
    }
  }
end

// `set_of` distributes over append, with union as the corresponding
// operation on sets.
// Proof: induction on `xs`; each case unfolds `++` and `set_of`, and the
// inductive case then rewrites with the induction hypothesis at `ys`.
theorem set_of_append: all T:type, xs:List<T>, ys:List<T>.
  set_of(xs ++ ys) = set_of(xs) ∪ set_of(ys)
proof
  arbitrary T:type
  induction List<T>
  case empty {
    arbitrary ys:List<T>
    expand operator++ | set_of.
  }
  case node(x, xs') suppose IH {
    arbitrary ys:List<T>
    expand operator++ | set_of
    replace IH[ys].
  }
end

// A list whose `mset_of` is the empty multiset must itself be empty.
// Proof: case analysis on `xs`. For a node, compare the count of the
// head `x` on both sides of the premise: the left side is
// `1 + cnt(mset_of(xs'))(x)` (via `cnt_sum`) and the right side is 0,
// which `uint_not_one_add_zero` turns into a contradiction.
theorem mset_of_empty: all T:type, xs:List<T>.
  if mset_of(xs) = @m_empty<T>()
  then xs = empty
proof
  arbitrary T:type
  arbitrary xs:List<T>
  switch xs {
    case empty {
      .
    }
    case node(x, xs') {
      suppose prem
      have X1: cnt(m_one(x) ⨄ mset_of(xs'))(x) = cnt(@m_empty<T>())(x)
        by replace (expand mset_of in prem).
      have X2: 1 + cnt(mset_of(xs'))(x) = 0
        by replace cnt_sum in X1
      conclude false by apply uint_not_one_add_zero to X2
    }
  }
end

// Forgetting multiplicity after building the multiset gives the same
// result as building the set directly: `set_of_mset ∘ mset_of = set_of`.
// Proof: induction on `xs`. The inductive case pushes `set_of_mset`
// through the multiset sum with `som_union` and `som_one_single`
// (both defined outside this file), then applies the induction
// hypothesis.
theorem som_mset_eq_set: all T:type, xs:List<T>.
  set_of_mset(mset_of(xs)) = set_of(xs)
proof
  arbitrary T:type
  induction List<T>
  case empty {
    expand mset_of | set_of
    replace som_empty<T>.
  }
  case node(x, xs') suppose IH {
    suffices set_of_mset(m_one(x) ⨄ mset_of(xs')) = single(x) ∪ set_of(xs')
        by expand mset_of | set_of.
    suffices single(x) ∪ set_of_mset(mset_of(xs')) = single(x) ∪ set_of(xs')
        by replace som_union<T>[m_one(x), mset_of(xs')]
                 | som_one_single<T>[x].
    replace IH.
  }
end

// After `remove_all(xs, y)`, `y` is no longer in the resulting set.
// Proof: induction on `xs`, then a case split on whether the head `x`
// equals `y`.
//   * `x = y`: `remove_all` skips the head, so the goal is exactly the
//     induction hypothesis.
//   * `x ≠ y`: the head is kept, so membership of `y` would mean either
//     `y ∈ single(x)` (contradicting `x ≠ y`) or `y` is in the set of
//     the filtered tail (contradicting the induction hypothesis).
theorem not_set_of_remove_all: all W:type, xs:List<W>, y:W.
  not (y ∈ set_of(remove_all(xs, y)))
proof
  arbitrary U:type
  induction List<U>
  case empty {
    arbitrary y:U
    expand remove_all | set_of.
  }
  case node(x, xs')
      suppose IH: all y:U. not (y ∈ set_of(remove_all(xs',y)))
  {
    arbitrary y:U
    switch x = y {
      case true suppose xy_true {
        have x_eq_y: x = y by replace xy_true.
        suffices not (y ∈ set_of(remove_all(node(y,xs'),y)))
            by replace x_eq_y.
        suffices not (y ∈ set_of(remove_all(xs',y)))
            by expand remove_all.
        IH[y]
      }
      case false suppose not_xy {
        suffices not (y ∈ set_of(if x = y then remove_all(xs',y)
                                 else node(x,remove_all(xs',y))))
            by expand remove_all.
         suppose y_in_sx_xsy
         have y_in_sx_or_xsy: y ∈ single(x) or y ∈ set_of(remove_all(xs',y))
           by expand set_of in simplify with not_xy in y_in_sx_xsy
         cases y_in_sx_or_xsy
         case y_in_sx {
           have xy: x = y by y_in_sx
           conclude false by simplify with not_xy in xy
         }
         case y_in_xsy {
           conclude false by apply IH[y] to y_in_xsy
         }
      }
    }
  }
end

// Taking zero elements always yields the empty list.
// Proof: case analysis on `xs`; both clauses of `take` return `empty`
// when the count is 0.
theorem take_zero: all E:type, xs:List<E>.
  take(xs, 0) = []
proof
  arbitrary E:type, xs:List<E>
  switch xs {
    case empty {
      expand take.
    }
    case node(x, xs') {
      expand take.
    }
  }
end

// Taking the first `length(xs)` elements of `xs ++ ys` recovers `xs`.
// Proof: induction on `xs`. The base case reduces to `take(ys, 0)` and
// is closed with `take_zero`; the inductive case unfolds one step and
// rewrites with the induction hypothesis at `ys`.
theorem take_append: all E:type, xs:List<E>, ys:List<E>.
  take(xs ++ ys, length(xs)) = xs
proof
  arbitrary E:type
  induction List<E>
  case empty {
    arbitrary ys:List<E>
    expand operator++ | length
    replace take_zero.
  }
  case node(x, xs') assume IH {
    arbitrary ys:List<E>
    expand operator++ | take | length
    replace IH[ys].
  }  
end
  
// Length of `drop` plus the number dropped equals the original length,
// provided we did not try to drop more than was there.
// Proof: induction on `xs`, with `n` arbitrary in each case.
//   * empty list: the premise `n ≤ length(empty)` forces `n = 0`.
//   * node: split `n` into 0 or `1 + n'` with `uint_zero_or_add_one`.
//     For `1 + n'`, cancel the 1 on both sides of the premise to get
//     `n' ≤ length(xs')`, apply the induction hypothesis, and rearrange.
// The `uint_*` lemmas are defined outside this file.
theorem length_drop: all E:type, xs:List<E>, n:UInt.
  if n ≤ length(xs)
  then length(drop(xs, n)) + n = length(xs)
proof
  arbitrary E:type
  induction List<E>
  case empty {
    arbitrary n:UInt
    assume prem
    suffices n = 0 by expand drop | length.
    apply uint_zero_le_zero[n] to expand length in prem
  }
  case node(x, xs') assume IH: all n:UInt. (if n ≤ length(xs') then length(drop(xs', n)) + n = length(xs')) {
    arbitrary n:UInt
    assume prem
    cases uint_zero_or_add_one[n]
    case nz {
      replace nz
      evaluate
    }
    case n1 {
      obtain n' where n_sn from n1
      replace n_sn
      suffices length(drop(xs', n')) + 1 + n' = 1 + length(xs') by {
        expand length | drop.
      }
      have n_xs: n' ≤ length(xs') by {
        apply uint_add_both_sides_of_less_equal[1, n', length(xs')] to
        expand length in replace n_sn in prem
      }
      have eq: length(drop(xs', n')) + n' = length(xs')
        by apply IH[n'] to n_xs
      conclude length(drop(xs', n')) + 1 + n' = 1 + length(xs') by {
        replace (symmetric eq)
        replace uint_add_commute[length(drop(xs', n')), 1].
      }
    }
  }
end

// Alias for `length_drop`: the same statement under a shorter name,
// proved by citing `length_drop` directly.
theorem len_drop: all E:type, xs:List<E>, n:UInt.
  if n ≤ length(xs)
  then length(drop(xs, n)) + n = length(xs)
proof
  length_drop
end
  
// Dropping more elements than the list contains leaves a list of
// length zero.
// Proof: induction on `xs`, with `n` arbitrary in each case. In the
// node case `n` is split into 0 or `1 + n'` with `uint_zero_or_add_one`.
// The `uint_*` lemmas are defined outside this file.
theorem length_drop_zero: all E:type, xs:List<E>, n:UInt.
  if length(xs) < n
  then length(drop(xs, n)) = 0
proof
  arbitrary E:type
  induction List<E>
  case empty {
    arbitrary n:UInt
    assume prem
    evaluate
  }
  case node(x, xs') assume IH: all n:UInt. (if length(xs') < n then length(drop(xs', n)) = 0) {
    arbitrary n:UInt
    assume prem
    cases uint_zero_or_add_one[n]
    // n = 0 is impossible: the premise would put a length below 0.
    case nz {
      conclude false by {
        apply uint_not_less_zero
        to expand length in replace nz in prem
      }
    }
    // n = 1 + n': one step of `drop` discards the head; derive
    // `length(xs') < n'` from the premise and apply the induction
    // hypothesis to the tail.
    case n1 {
      obtain n' where n_sn from n1
      replace n_sn
      expand drop
      have xs_n: length(xs') < n' by {
        apply uint_add_both_sides_of_less[1, length(xs'), n'] to
        expand length in replace n_sn in prem
      }
      conclude length(drop(xs', n')) = 0 by apply IH[n'] to xs_n
    }
  }
end

// Dropping zero elements leaves the list unchanged.
// Proof: case split on `xs` via `induction` (the induction hypothesis
// is not used); both clauses of `drop` return the list itself when the
// count is 0.
theorem drop_zero_identity: all E:type, xs:List<E>.
  drop(xs, 0) = xs
proof
  arbitrary E:type
  induction List<E>
  case empty {
    expand drop.
  }
  case node(x, xs') suppose IH {
    expand drop.
  }
end

// Dropping `length(xs)` elements from `xs ++ ys` leaves just `ys`.
// Proof: induction on `xs`. The base case reduces to `drop(ys, 0)` and
// is closed with `drop_zero_identity`; the inductive case unfolds one
// step and finishes with the induction hypothesis at `ys`.
theorem drop_append: all E:type, xs:List<E>, ys:List<E>.
  drop(xs ++ ys, length(xs)) = ys
proof
  arbitrary E:type
  induction List<E>
  case empty {
    arbitrary ys:List<E>
    expand operator++ | length
    drop_zero_identity
  }
  case node(x, xs') suppose IH {
    arbitrary ys:List<E>
    expand length | operator++ | drop
    IH[ys]
  }
end

// Looking up index `i` after dropping the first `n` elements is the
// same as looking up index `n + i` in the original list.
// NOTE(review): the quantified variable `d:T` does not occur in the
// statement. It is kept here because removing it would change how
// callers instantiate the theorem.
// Proof: induction on `xs`; in the node case `n` is split into 0 or
// `1 + n'` with `uint_zero_or_add_one` (defined outside this file).
theorem get_drop: all T:type, xs:List<T>, n:UInt, i:UInt, d:T.
  get(drop(xs, n), i) = get(xs, n + i)
proof
  arbitrary T:type
  induction List<T>
  case empty {
    arbitrary n:UInt, i:UInt, d:T
    evaluate
  }
  case node(x, xs') suppose IH {
    arbitrary n:UInt, i:UInt, d:T
    expand drop
    cases uint_zero_or_add_one[n]
    // n = 0: nothing is dropped, so both sides are the same lookup.
    case nz {
      replace nz.
    }
    // n = 1 + n': after unfolding `get` the goal is an instance of the
    // induction hypothesis for the tail.
    case n1 {
      obtain n' where n_sn from n1
      replace n_sn
      expand get
      IH
    }
  }
end

// If the index falls within `xs`, then `get` on `xs ++ ys` returns
// the same element as `get` on `xs` alone.
// Proof: induction on `xs`, with `ys` and `i` arbitrary in each case.
// The `uint_*` lemmas are defined outside this file.
theorem get_append_front: all T:type, xs:List<T>, ys:List<T>, i:UInt.
  if i < length(xs)
  then get(xs ++ ys, i) = get(xs, i)
proof
  arbitrary T:type
  induction List<T>
  // Empty list: the premise `i < length(empty)` is contradictory.
  case empty {
    arbitrary ys:List<T>, i:UInt
    suppose i_z: i < length(@empty<T>)
    conclude false by {
      apply uint_not_less_zero to
      expand length in i_z
    }
  }
  case node(x, xs) suppose IH {
    arbitrary ys:List<T>, i:UInt
    suppose i_xxs: i < length(node(x,xs))
    cases uint_zero_or_add_one[i]
    // i = 0: both lookups return the head.
    case iz {
      replace iz
      expand operator++ | get.
    }
    // i = 1 + i': both lookups move to the tail; derive
    // `i' < length(xs)` from the premise and apply the induction
    // hypothesis.
    case i1 {
      obtain i' where i_si from i1
      replace i_si
      expand operator++ | get
      have i_xs: i' < length(xs) by {
        apply uint_add_both_sides_of_less[1, i', length(xs)] to
        replace i_si in expand length in i_xxs
      }
      apply IH[ys, i'] to i_xs
    }
  }
end

// Indexing past `xs` into `xs ++ ys` is the same as indexing into
// `ys` by the offset.
// Proof: induction on `xs`. The base case is closed by unfolding `++`
// and `length`; the inductive case unfolds one step of `++`, `length`
// and `get` and finishes with the induction hypothesis.
theorem get_append_back: all T:type, xs:List<T>, ys:List<T>, i:UInt.
  get(xs ++ ys, length(xs) + i) = get(ys, i)
proof
  arbitrary T:type
  induction List<T>
  case empty {
    arbitrary ys:List<T>, i:UInt
    expand operator++ | length.
  }
  case node(x, xs) suppose IH {
    arbitrary ys:List<T>, i:UInt
    expand operator++ | length | get
    IH[ys, i]
  }
end

// An out-of-range index returns `none`.
// Proof: induction on `xs`, with `i` arbitrary in each case.
//   * empty list: `get` on `empty` is `none` by definition.
//   * node: the premise gives `1 + length(xs) ≤ i`. Split `i` into 0 or
//     `1 + i'`: the first is contradictory, the second reduces to the
//     induction hypothesis on the tail after cancelling the 1.
// The `uint_*` lemmas are defined outside this file.
theorem get_none: all T:type, xs:List<T>, i:UInt.
  if length(xs) <= i
  then get(xs, i) = none
proof
  arbitrary T:type
  induction List<T>
  case [] {
    arbitrary i:UInt
    assume xs_i
    conclude get(@[]<T>, i) = none by expand get.
  }
  case node(x, xs) assume IH {
    arbitrary i:UInt
    assume xxs_i
    have i_pos: 1 + length(xs) ≤ i
      by expand length in xxs_i
    cases uint_zero_or_add_one[i]
    case iz {
      replace iz
      conclude false by {
        apply uint_not_one_add_le_zero to
        expand length in replace iz in xxs_i
      }
    }
    case i1 {
      obtain i' where i_si from i1
      replace i_si
      have xs_i: length(xs) ≤ i' by {
        apply uint_add_both_sides_of_less_equal[1, length(xs), i'] to
        replace i_si in i_pos
      }
      expand get
      conclude get(xs, i') = none  by apply IH to xs_i
    }
  }
end

// If two lists have the same multiset, they have the same underlying
// set. (The converse fails: equal sets can arise from different
// multiplicities.)
// Proof: no induction. Rewrite `set_of(xs)` as
// `set_of_mset(mset_of(xs))` using `som_mset_eq_set`, swap `mset_of(xs)`
// for `mset_of(ys)` with the premise, and convert back.
theorem mset_equal_implies_set_equal: all T:type, xs:List<T>, ys:List<T>.
  if mset_of(xs) = mset_of(ys)
  then set_of(xs) = set_of(ys)
proof
  arbitrary T:type
  arbitrary xs:List<T>, ys:List<T>
  suppose xs_ys
  equations
    set_of(xs) = #set_of_mset(mset_of(xs))#   by replace som_mset_eq_set<T>[xs].
           ... = set_of_mset(mset_of(ys))     by replace xs_ys.
           ... = set_of(ys)                   by som_mset_eq_set<T>[ys]
end

// If the left list is non-empty, the head of `L ++ R` is the head of `L`.
// Proof: case split on `L` via `induction`; the induction hypothesis is
// not used.
theorem head_append: all E:type, L:List<E>, R:List<E>.
  if 0 < length(L)
  then head(L ++ R) = head(L)
proof
  arbitrary E:type
  induction List<E>
  // Empty list: the premise `0 < length(empty)` is contradictory once
  // `length` is expanded.
  case empty {
    arbitrary R:List<E>
    suppose pos_len
    conclude false  by {
      expand length in pos_len
    }
  }
  // Node: both sides reduce to `just(x)`; the premise is not needed.
  case node(x, xs) suppose IH {
    arbitrary R:List<E>
    suppose pos_len
    equations
          head(node(x,xs) ++ R)
        = just(x)                        by expand operator++ | head.
    ... =# head(node(x,xs)) #            by expand head.
  }
end

// If the left list is non-empty, the tail of `L ++ R` is `tail(L) ++ R`.
// Proof: case split on `L` via `induction`; the induction hypothesis is
// not used.
theorem tail_append: all E:type, L:List<E>, R:List<E>.
  if 0 < length(L)
  then tail(L ++ R) = tail(L) ++ R
proof
  arbitrary E:type
  induction List<E>
  // Empty list: the premise `0 < length(empty)` is contradictory once
  // `length` is expanded.
  case empty {
    arbitrary R:List<E>
    suppose pos_len
    conclude false by {
      expand length in pos_len
    }
  }
  // Node: unfolding `++` and `tail` makes both sides `xs' ++ R`.
  case node(x, xs') suppose IH {
    arbitrary R:List<E>
    suppose pos_len
    expand operator++ | tail.
  }
end

// `length(tails(xs)) = 1 + length(xs)`: one suffix per "cut point",
// including the empty suffix at the end.
// Proof: induction on `xs`. `tails` adds exactly one entry per node, on
// top of the single entry it produces for the empty list.
theorem length_tails: all T:type. all xs:List<T>.
  length(tails(xs)) = 1 + length(xs)
proof
  arbitrary T:type
  induction List<T>
  case empty {
    expand tails | 2*length.
  }
  case node(x, xs') assume IH: length(tails(xs')) = 1 + length(xs') {
    equations
      length(tails(node(x, xs')))
          = length(node(node(x, xs'), tails(xs')))   by expand tails.
      ... = 1 + length(tails(xs'))                   by expand length.
      ... = 1 + (1 + length(xs'))                    by replace IH.
      ... = # 1 + length(node(x, xs')) #             by expand length.
  }
end

// The first suffix is the whole list.
// Proof: case analysis on `xs`; in both cases the first entry built by
// `tails` is the input list itself.
theorem head_tails: all T:type. all xs:List<T>.
  head(tails(xs)) = just(xs)
proof
  arbitrary T:type
  arbitrary xs:List<T>
  switch xs {
    case empty { expand tails | head. }
    case node(x, xs') { expand tails | head. }
  }
end

// Every `tails(xs)` contains the empty list — it is the last suffix.
// Proof: induction on `xs`. For the empty list, `empty` is the only
// entry; for a node, membership in the tail's suffix list is given by
// the induction hypothesis.
theorem tails_contains_empty: all T:type. all xs:List<T>.
  contains(tails(xs), @empty<T>) = true
proof
  arbitrary T:type
  induction List<T>
  case empty { expand tails | contains. }
  case node(x, xs') assume IH: contains(tails(xs'), @empty<T>) = true {
    expand tails | contains
    replace IH.
  }
end

// `length(inits(xs)) = 1 + length(xs)`: one prefix per "cut point",
// including the empty prefix at the front.
// Proof: induction on `xs`. `inits` of a node is `empty` followed by a
// `map` over `inits` of the tail, so `length_map` (instantiated at
// element type `List<T>`) removes the `map` before the induction
// hypothesis is applied.
theorem length_inits: all T:type. all xs:List<T>.
  length(inits(xs)) = 1 + length(xs)
proof
  arbitrary T:type
  induction List<T>
  case empty {
    expand inits | 2*length.
  }
  case node(x, xs') assume IH: length(inits(xs')) = 1 + length(xs') {
    equations
      length(inits(node(x, xs')))
          = length(node(@empty<T>, map(inits(xs'), λys:List<T>{node(x, ys)})))
              by expand inits.
      ... = 1 + length(map(inits(xs'), λys:List<T>{node(x, ys)}))
              by expand length.
      ... = 1 + length(inits(xs'))
              by replace length_map<List<T>>[λys:List<T>{node(x, ys)}, inits(xs')].
      ... = 1 + (1 + length(xs'))
              by replace IH.
      ... = # 1 + length(node(x, xs')) #
              by expand length.
  }
end

// The first prefix is always the empty list.
// Proof: case analysis on `xs`; both clauses of `inits` build a list
// whose first entry is `empty`.
theorem head_inits: all T:type. all xs:List<T>.
  head(inits(xs)) = just(@empty<T>)
proof
  arbitrary T:type
  arbitrary xs:List<T>
  switch xs {
    case empty { expand inits | head. }
    case node(x, xs') { expand inits | head. }
  }
end

// Snoc-ing `y` makes `y` the last element: `last(xs ++ [y]) = just(y)`.
// Proof: induction on `xs`, followed by a case split on the tail `xs'`
// in the inductive case, because `last` inspects the tail of a node
// with a `switch`.
theorem last_append_node: all T:type. all xs:List<T>, y:T.
  last(xs ++ node(y, @empty<T>)) = just(y)
proof
  arbitrary T:type
  induction List<T>
  case empty {
    arbitrary y:T
    expand operator++ | last.
  }
  case node(x, xs') assume IH: all y:T. last(xs' ++ node(y, @empty<T>)) = just(y) {
    arbitrary y:T
    switch xs' {
      // Tail is empty: the list is `[x, y]`, whose last element is `y`.
      case empty {
        expand 2*operator++ | 2*last.
      }
      // Tail is a node: `last` skips `x`, and the remaining goal is the
      // induction hypothesis specialised to `xs' = node(z, zs')`.
      case node(z, zs') assume eq {
        have rec: last(node(z, zs') ++ node(y, @empty<T>)) = just(y)
          by replace eq in IH[y]
        equations
          last(node(x, node(z, zs')) ++ node(y, @empty<T>))
              = last(node(x, node(z, zs' ++ node(y, @empty<T>))))
                                                          by expand 2*operator++.
          ... = last(node(z, zs' ++ node(y, @empty<T>)))  by expand last.
          ... = # last(node(z, zs') ++ node(y, @empty<T>)) #
                                                          by expand operator++.
          ... = just(y)                                   by rec
      }
    }
  }
end

// Snoc-ing `y` makes `init` forget it: `init(xs ++ [y]) = xs`.
// Proof: induction on `xs`, followed by a case split on the tail `xs'`
// in the inductive case, because `init` inspects the tail of a node
// with a `switch`.
theorem init_append_node: all T:type. all xs:List<T>, y:T.
  init(xs ++ node(y, @empty<T>)) = xs
proof
  arbitrary T:type
  induction List<T>
  case empty {
    arbitrary y:T
    expand operator++ | init.
  }
  case node(x, xs') assume IH: all y:T. init(xs' ++ node(y, @empty<T>)) = xs' {
    arbitrary y:T
    switch xs' {
      // Tail is empty: the list is `[x, y]`, whose `init` is `[x]`.
      case empty {
        expand 2*operator++ | 2*init.
      }
      // Tail is a node: `init` keeps `x`, and the remaining goal is the
      // induction hypothesis specialised to `xs' = node(z, zs')`.
      case node(z, zs') assume eq {
        have rec: init(node(z, zs') ++ node(y, @empty<T>)) = node(z, zs')
          by replace eq in IH[y]
        equations
          init(node(x, node(z, zs')) ++ node(y, @empty<T>))
              = init(node(x, node(z, zs' ++ node(y, @empty<T>))))
                                                          by expand 2*operator++.
          ... = node(x, init(node(z, zs' ++ node(y, @empty<T>))))
                                                          by expand init.
          ... = # node(x, init(node(z, zs') ++ node(y, @empty<T>))) #
                                                          by expand operator++.
          ... = node(x, node(z, zs'))                     by replace rec.
      }
    }
  }
end

// Intersperse on a singleton is the identity: nothing to interleave.
// Proof: unfold `intersperse` once; the tail is `empty`, so the `switch`
// takes the branch that returns the one-element list unchanged.
theorem intersperse_singleton: all T:type. all x:T, sep:T.
  intersperse(node(x, @empty<T>), sep) = node(x, @empty<T>)
proof
  arbitrary T:type
  arbitrary x:T, sep:T
  expand intersperse.
end