module Nat

// Unary natural numbers: a Nat is either zero or suc applied to another Nat.
union Nat {
  // Base constructor; takes no arguments.
  zero
  // Recursive constructor; wraps an existing Nat.
  suc(Nat)
}

// Addition on Nat, defined by recursion on the first argument.
recursive operator +(Nat,Nat) -> Nat {
  // Adding to zero yields the second argument unchanged.
  operator +(zero, y) = y
  // A suc on the first argument is moved outside the sum.
  operator +(suc(x), y) = suc(x + y)
}

// Multiplication on Nat, defined by recursion on the first argument.
recursive operator *(Nat,Nat) -> Nat {
  // zero times anything is zero.
  operator *(zero, y) = zero
  // Each suc on the first argument contributes one more copy of y.
  operator *(suc(x), y) = y + (x * y)
}

// Exponentiation with the exponent FIRST and the base second:
// expt(e, b) multiplies b together e times.
recursive expt(Nat, Nat) -> Nat {
  // An exponent of zero gives suc(zero), regardless of the base.
  expt(zero, b) = suc(zero)
  // One more suc on the exponent multiplies in one more factor of b.
  expt(suc(e), b) = b * expt(e, b)
}

// Infix power: a ^ b is a raised to the exponent b.
fun operator ^(a : Nat, b : Nat)  {
  // expt takes the exponent as its first argument, so the operands are swapped.
  expt(b, a)
}

// The larger of two Nats, by recursion on the first argument.
recursive max(Nat,Nat) -> Nat {
  // When the first argument is zero the second one is the result.
  max(zero, b) = b
  max(suc(a), b) =
    switch b {
      // The second argument ran out first, so the first one wins.
      case zero { suc(a) }
      // Both are non-zero: peel one suc off each and put it back on the result.
      case suc(b') { suc(max(a,b')) }
    }
}

// The smaller of two Nats, by recursion on the first argument.
recursive min(Nat,Nat) -> Nat {
  // When the first argument is zero the result is zero.
  min(zero, b) = zero
  min(suc(a), b) =
    switch b {
      // The second argument is zero, so the result is zero.
      case zero { zero }
      // Both are non-zero: peel one suc off each and put it back on the result.
      case suc(b') { suc(min(a,b')) }
    }
}

// Predecessor: removes one suc from n; pred of zero is zero.
fun pred(n : Nat) {
  switch n {
    // There is nothing below zero, so zero is returned as-is.
    case zero { zero }
    // Otherwise return the Nat that suc was applied to.
    case suc(k) { k }
  }
}

// Truncated subtraction (monus): n ∸ m removes one suc from n for each
// suc in m, stopping at zero instead of going negative.
recursive operator ∸(Nat,Nat) -> Nat {
  // Nothing can be removed from zero.
  operator ∸(zero, m) = zero
  operator ∸(suc(n), m) =
    switch m {
      // Subtracting zero leaves the first argument unchanged.
      case zero { suc(n) }
      // Cancel one suc on each side and keep subtracting.
      case suc(m') { n ∸ m' }
    }
}

// Less-than-or-equal on Nat, by recursion on the first argument.
recursive operator ≤(Nat,Nat) -> bool {
  // zero is less than or equal to every Nat.
  operator ≤(zero, m) = true
  operator ≤(suc(n'), m) =
    switch m {
      // A suc is never less than or equal to zero.
      case zero { false }
      // Strip one suc from each side and compare what remains.
      case suc(m') { n' ≤ m' }
    }
}

// Strict less-than, expressed through ≤: x < y holds exactly when suc(x) ≤ y.
fun operator < (x:Nat, y:Nat) {
  suc(x) ≤ y
}

// Strict greater-than, defined as < with the operands swapped.
fun operator > (x:Nat, y:Nat) {
  y < x
}

// Greater-than-or-equal, defined as ≤ with the operands swapped.
fun operator ≥ (x:Nat, y:Nat) {
  y ≤ x
}

// summation(count, start, g) adds up g at `count` consecutive Nats,
// beginning with `start`.
recursive summation(Nat, Nat, fn Nat->Nat) -> Nat {
  // No terms left: the sum is zero.
  summation(zero, start, g) = zero
  // Take the term at `start`, then sum the remaining terms from suc(start).
  summation(suc(j), start, g) = g(start) + summation(j, suc(start), g)
}

// iterate(count, seed, g) applies g to seed `count` times.
recursive iterate<T>(Nat, T, fn T -> T) -> T {
  // Zero applications: the seed is returned untouched.
  iterate(zero, seed, g) = seed
  // One more application of g wrapped around the result for the smaller count.
  iterate(suc(k), seed, g) = g(@iterate<T>(k, seed, g))
}

// Two raised to the given Nat.
recursive pow2(Nat) -> Nat {
  // Exponent zero gives suc(zero).
  pow2(zero) = suc(zero)
  // Each suc on the exponent doubles the result.
  pow2(suc(k)) = suc(suc(zero)) * pow2(k)
}

// Halving helper. Walks down the Nat one suc at a time while flipping the
// boolean flag; a suc is counted in the result only when the flag is false.
private recursive div2_helper(Nat, bool) -> Nat {
  div2_helper(zero, drop_next) = zero
  div2_helper(suc(k), drop_next) =
    // Flag set: drop this suc. Flag clear: keep it. Either way, flip the flag.
    if drop_next then div2_helper(k, not drop_next)
    else suc(div2_helper(k, not drop_next))
}

// Halves n by calling div2_helper with the flag initially true, so the
// first suc is dropped and only every second suc is counted (rounds down).
private fun div2(n : Nat) {
  div2_helper(n, true)
}

// Halves n by calling div2_helper with the flag initially false, so the
// first suc is counted and then every second one after it (rounds up).
private fun div2_aux(n : Nat) {
  div2_helper(n, false)
}

// Boolean equality test on Nat, by recursion on the first argument.
recursive equal(Nat, Nat) -> bool {
  // zero is equal only to zero.
  equal(zero, b) =
    switch b {
      case zero { true }
      case suc(b') { false }
    }
  // A suc is equal only to another suc whose contents are equal.
  equal(suc(a), b) =
    switch b {
      case zero { false }
      case suc(b') { equal(a, b') }
    }
}

// Distance between two Nats: strips matching suc's from both arguments
// and returns whatever is left over on either side.
recursive dist(Nat, Nat) -> Nat {
  // The first argument is exhausted; the remainder of the second is the answer.
  dist(zero, b) = b
  dist(suc(a), b) =
    switch b {
      // The second argument is exhausted; the remainder of the first is the answer.
      case zero { suc(a) }
      // Both still have a suc: cancel one from each side and continue.
      case suc(b') { dist(a, b') }
    }
}

// Helper for log. The first argument is a countdown, the second counts up
// in lock-step, and the third is the candidate result, which is bumped
// whenever pow2 of it is still below the suc of the up-counter.
private recursive find_log(Nat, Nat, Nat) -> Nat {
  // Countdown finished: the current candidate is the result.
  find_log(zero, up, cand) = cand
  find_log(suc(down), up, cand) =
    if pow2(cand) < suc(up) then
      // Candidate is too small for suc(up); advance it.
      find_log(down, suc(up), suc(cand))
    else
      // Candidate is still large enough; only the counters move.
      find_log(down, suc(up), cand)
}

// log runs find_log with the argument as the countdown and both the
// up-counter and the candidate starting at zero.
// NOTE(review): appears to yield the least l with x ≤ pow2(l) (base-2 log, rounded up) — confirm against the accompanying theorems.
define log : fn Nat -> Nat = λx{ find_log(x, zero, zero) }

// Identity function on Nat: returns its argument unchanged.
// Invariant: only apply lit to a literal Nat.
fun lit(a : Nat) { a }