import Base

import Nat

// Binary representation of the natural numbers. The meaning of each
// constructor is fixed by toNat in this file:
//   bzero      represents 0
//   dub_inc(x) represents 2 * (x + 1)
//   inc_dub(x) represents 2 * x + 1
// toNat_injective together with to_fromNat says every natural number has
// exactly one UInt representation.
// NOTE(review): the numeral 0 used in patterns below presumably denotes
// bzero — confirm against how UInt literals are resolved.
union UInt {
  bzero
  dub_inc(UInt)
  inc_dub(UInt)
}

// Interprets a UInt as the natural number it denotes.
recursive toNat(UInt) -> Nat{
  toNat(0) = ℕ0
  // dub_inc(x) denotes 2 * (x + 1)
  toNat(dub_inc(x)) = ℕ2 * suc(toNat(x))
  // inc_dub(x) denotes 2 * x + 1
  toNat(inc_dub(x)) = suc(ℕ2 * toNat(x))
}

// Strict less-than on UInt. Recurses on the left argument and case-splits
// on the right. Writing a = toNat(x') and b = toNat(y'), each case follows
// from the meaning of the constructors given by toNat:
//   dub_inc vs dub_inc:  2(a+1) < 2(b+1)  iff  a < b
//   dub_inc vs inc_dub:  2(a+1) < 2b+1    iff  a < b
//   inc_dub vs dub_inc:  2a+1   < 2(b+1)  iff  a <= b
//   inc_dub vs inc_dub:  2a+1   < 2b+1    iff  a < b
recursive operator <(UInt,UInt) -> bool{
  // 0 is below every value except 0 itself.
  operator <(0, y) = 
    switch y {
      case 0 {
        false
      }
      case dub_inc(y') {
        true
      }
      case inc_dub(y') {
        true
      }
    }
  // Even positive left operand: nothing is below 0; otherwise compare the
  // halves x' and y'.
  operator <(dub_inc(x'), y) = 
    switch y {
      case 0 {
        false
      }
      case dub_inc(y') {
        x' < y'
      }
      case inc_dub(y') {
        x' < y'
      }
    }
  // Odd left operand: against dub_inc, equal halves also count, because
  // 2a+1 < 2(a+1).
  operator <(inc_dub(x'), y) = 
    switch y {
      case 0 {
        false
      }
      case dub_inc(y') {
        (x' < y' or x' = y')
      }
      case inc_dub(y') {
        x' < y'
      }
    }
}

// Non-strict order on UInt: x ≤ y holds when x < y or x = y.
fun operator ≤(x:UInt, y:UInt) {
  (x < y or x = y)
}

// Strict greater-than, defined as the converse of <: x > y means y < x.
fun operator >(x:UInt, y:UInt) {
  y < x
}

// Non-strict greater-or-equal, defined as the converse of ≤:
// x ≥ y means y ≤ x.
fun operator ≥(x:UInt, y:UInt) {
  y ≤ x
}

// Addition on UInt, by recursion on the left argument and a case split on
// the right. With a = toNat(x) and b = toNat(y'):
//   dub_inc + dub_inc:  2(a+1) + 2(b+1) = 2((a+b+1)+1)  -> dub_inc(inc(x + y'))
//   dub_inc + inc_dub:  2(a+1) + (2b+1) = 2(a+b+1)+1    -> inc(dub_inc(x + y'))
//   inc_dub + dub_inc:  (2a+1) + 2(b+1) = 2(a+b+1)+1    -> inc(dub_inc(x + y'))
//   inc_dub + inc_dub:  (2a+1) + (2b+1) = (2(a+b)+1)+1  -> inc(inc_dub(x + y'))
// inc is defined outside this view; the theorem inc_add_one states
// inc(n) = 1 + n.
recursive operator +(UInt,UInt) -> UInt{
  operator +(0, y) = y
  operator +(dub_inc(x), y) = 
    switch y {
      case 0 {
        dub_inc(x)
      }
      case dub_inc(y') {
        dub_inc(inc(x + y'))
      }
      case inc_dub(y') {
        inc(dub_inc(x + y'))
      }
    }
  operator +(inc_dub(x), y) = 
    switch y {
      case 0 {
        inc_dub(x)
      }
      case dub_inc(y') {
        inc(dub_inc(x + y'))
      }
      case inc_dub(y') {
        inc(inc_dub(x + y'))
      }
    }
}

// Truncated subtraction (monus) on UInt: x ∸ y is x - y when y ≤ x, and 0
// otherwise (toNat_monus relates it to monus on Nat). With a = toNat(x)
// and b = toNat(y'):
//   dub_inc ∸ dub_inc:  2(a+1) - 2(b+1) = 2(a ∸ b)
//   dub_inc ∸ inc_dub:  2(a+1) - (2b+1) = 2(a ∸ b) + 1, or 0 when a < b
//   inc_dub ∸ dub_inc:  (2a+1) - 2(b+1) = 2(a ∸ b) - 1, or 0 when a ≤ b
//   inc_dub ∸ inc_dub:  (2a+1) - (2b+1) = 2(a ∸ b)
// NOTE(review): dub and pred are defined outside this view (presumably
// doubling and predecessor). The a = b sub-case of inc_dub ∸ dub_inc
// relies on pred(dub(0)) being 0 — confirm.
recursive operator ∸(UInt,UInt) -> UInt{
  operator ∸(0, y) = 0
  operator ∸(dub_inc(x), y) = 
    switch y {
      case 0 {
        dub_inc(x)
      }
      case dub_inc(y') {
        dub(x ∸ y')
      }
      case inc_dub(y') {
        (if x < y' then 0 else inc_dub(x ∸ y'))
      }
    }
  operator ∸(inc_dub(x), y) = 
    switch y {
      case 0 {
        inc_dub(x)
      }
      case dub_inc(y') {
        (if x < y' then 0 else pred(dub(x ∸ y')))
      }
      case inc_dub(y') {
        dub(x ∸ y')
      }
    }
}

// Multiplication on UInt, by recursion on the left argument and a case
// split on the right. With a = toNat(x) and b = toNat(y'):
//   dub_inc * dub_inc:  2(a+1) * 2(b+1) = 2 * 2((a+b+ab)+1)
//                       -> dub(dub_inc((x + y') + x * y'))
//   dub_inc * inc_dub:  2(a+1) * (2b+1) = 2((a + 2(b+ab)) + 1)
//                       -> dub_inc(x + dub(y' + x * y'))
//   inc_dub * dub_inc:  (2a+1) * 2(b+1) = 2((2a + b + 2ab) + 1)
//                       -> dub_inc((dub(x) + y') + dub(x * y'))
//   inc_dub * inc_dub:  (2a+1) * (2b+1) = 2(a + b + 2ab) + 1
//                       -> inc_dub((x + y') + dub(x * y'))
// NOTE(review): dub is defined outside this view; the derivations above
// assume it doubles its argument — confirm.
recursive operator *(UInt,UInt) -> UInt{
  operator *(0, y) = 0
  operator *(dub_inc(x), y) = 
    switch y {
      case 0 {
        0
      }
      case dub_inc(y') {
        dub(dub_inc((x + y') + x * y'))
      }
      case inc_dub(y') {
        dub_inc(x + dub(y' + x * y'))
      }
    }
  operator *(inc_dub(x), y) = 
    switch y {
      case 0 {
        0
      }
      case dub_inc(y') {
        dub_inc((dub(x) + y') + dub(x * y'))
      }
      case inc_dub(y') {
        inc_dub((x + y') + dub(x * y'))
      }
    }
}

// Square of a, i.e. a * a.
fun sqr(a:UInt) {
  a * a
}

// Exponentiation a ^ b, delegated to expt with the exponent b passed as the
// first argument and the base a as the second.
// NOTE(review): expt is defined outside this view; the (exponent, base)
// argument order is inferred from this call alone — confirm.
fun operator ^(a:UInt, b:UInt) {
  expt(b, a)
}

// Converts a Nat to a UInt by applying inc once for every suc.
// to_fromNat and from_toNat state that it is the inverse of toNat.
recursive fromNat(Nat) -> UInt{
  fromNat(ℕ0) = 0
  fromNat(suc(n)) = inc(fromNat(n))
}

// Larger of x and y: returns y when x < y, and x otherwise (so x on ties).
fun max(x:UInt, y:UInt) {
  if x < y then
    y
  else
    x
}

// Smaller of x and y: returns x when x < y, and y otherwise (so y on ties).
fun min(x:UInt, y:UInt) {
  if x < y then
    x
  else
    y
}

// Conversions between UInt and Nat. fromNat and toNat are inverse to each
// other (to_fromNat, from_toNat), so both are injective, and toNat carries
// < on UInt to < on Nat.
from_zero: fromNat(ℕ0) = 0

from_one: fromNat(ℕ1) = 1

to_fromNat: (all x:Nat. toNat(fromNat(x)) = x)

toNat_less: (all x:UInt, y:UInt. (if x < y then toNat(x) < toNat(y)))

toNat_injective: (all x:UInt, y:UInt. (if toNat(x) = toNat(y) then x = y))

from_toNat: (all b:UInt. fromNat(toNat(b)) = b)

fromNat_injective: (all x:Nat, y:Nat. (if fromNat(x) = fromNat(y) then x = y))

// Order on UInt: transfer of < and ≤ to and from Nat via toNat, then
// reflexivity, irreflexivity, transitivity, totality (via
// uint_not_less_implies_less_equal), and facts about 0.
// uint_less_equal_refl and uint_le_refl state the same property.
toNat_less_equal: (all x:UInt, y:UInt. (if x ≤ y then toNat(x) ≤ toNat(y)))

less_toNat: (all x:UInt, y:UInt. (if toNat(x) < toNat(y) then x < y))

uint_less_equal_refl: (all n:UInt. n ≤ n)

uint_less_implies_less_equal: (all x:UInt, y:UInt. (if x < y then x ≤ y))

less_equal_toNat: (all x:UInt, y:UInt. (if toNat(x) ≤ toNat(y) then x ≤ y))

uint_less_irreflexive: (all x:UInt. not (x < x))

uint_less_trans: (all x:UInt, y:UInt, z:UInt. (if (x < y and y < z) then x < z))

uint_not_less_zero: (all x:UInt. not (x < 0))

uint_not_less_implies_less_equal: (all x:UInt, y:UInt. (if not (x < y) then y ≤ x))

uint_le_refl: (all x:UInt. x ≤ x)

uint_le_trans: (all x:UInt, y:UInt, z:UInt. (if (x ≤ y and y ≤ z) then x ≤ z))

uint_zero_le: (all x:UInt. 0 ≤ x)

uint_le_zero: (all x:UInt. (if x ≤ 0 then x = 0))

// Addition: agreement with Nat addition (toNat_add), commutativity,
// associativity, 0 as identity, left cancellation, and ordering facts.
toNat_add: (all x:UInt, y:UInt. toNat(x + y) = toNat(x) + toNat(y))

uint_add_commute: (all x:UInt, y:UInt. x + y = y + x)

uint_add_assoc: (all x:UInt, y:UInt, z:UInt. (x + y) + z = x + (y + z))

uint_zero_add: (all x:UInt. 0 + x = x)

uint_add_zero: (all x:UInt. x + 0 = x)

uint_zero_less_one_add: (all n:UInt. 0 < 1 + n)

uint_add_both_sides_of_equal: (all x:UInt, y:UInt, z:UInt. ((x + y = x + z) ⇔ (y = z)))

uint_add_to_zero: (all n:UInt, m:UInt. (if n + m = 0 then (n = 0 and m = 0)))

uint_less_equal_add: (all x:UInt, y:UInt. x ≤ x + y)

uint_less_add_pos: (all x:UInt, y:UInt. (if 0 < y then x < x + y))

// Multiplication: agreement with Nat multiplication (toNat_mult),
// commutativity, and associativity.
toNat_mult: (all x:UInt, y:UInt. toNat(x * y) = toNat(x) * toNat(y))

uint_mult_commute: (all m:UInt, n:UInt. m * n = n * m)

uint_mult_assoc: (all m:UInt, n:UInt, o:UInt. (m * n) * o = m * (n * o))

// Truncated subtraction (∸): agreement with monus on Nat (toNat_monus),
// behaviour at 0 and on equal arguments, and identities combining ∸ with +.
toNat_monus: (all x:UInt, y:UInt. toNat(x ∸ y) = toNat(x) ∸ toNat(y))

uint_zero_monus: (all x:UInt. 0 ∸ x = 0)

uint_monus_zero: (all n:UInt. n ∸ 0 = n)

uint_monus_cancel: (all n:UInt. n ∸ n = 0)

uint_add_monus_identity: (all m:UInt, n:UInt. (m + n) ∸ m = n)

uint_monus_monus_eq_monus_add: (all x:UInt, y:UInt, z:UInt. (x ∸ y) ∸ z = x ∸ (y + z))

uint_monus_order: (all x:UInt, y:UInt, z:UInt. (x ∸ y) ∸ z = (x ∸ z) ∸ y)

uint_add_both_monus: (all z:UInt, y:UInt, x:UInt. (z + y) ∸ (z + x) = y ∸ x)

// Induction over the 0 / 1 + m view of UInt, transfer of + and order from
// Nat via fromNat, and further lemmas relating 0, 1 + n, <, ≤ and ∸.
uint_induction: (all P:(fn UInt -> bool), n:UInt. (if (P(0) and (all m:UInt. (if P(m) then P(1 + m)))) then P(n)))

fromNat_add: (all x:Nat, y:Nat. fromNat(x + y) = fromNat(x) + fromNat(y))

less_fromNat: (all x:Nat, y:Nat. (if x < y then fromNat(x) < fromNat(y)))

less_equal_fromNat: (all x:Nat, y:Nat. (if x ≤ y then fromNat(x) ≤ fromNat(y)))

uint_zero_or_positive: (all x:UInt. (x = 0 or 0 < x))

uint_less_plus1: (all n:UInt. n < 1 + n)

uint_add_both_sides_of_less: (all x:UInt, y:UInt, z:UInt. ((x + y < x + z) ⇔ (y < z)))

less_is_less_equal: (all x:UInt, y:UInt. (x < y) = (1 + x ≤ y))

uint_monus_add_identity: (all n:UInt. (all m:UInt. (if m ≤ n then m + (n ∸ m) = n)))

uint_not_one_add_zero: (all n:UInt. not (1 + n = 0))

uint_not_zero_pos: (all n:UInt. (if not (n = 0) then 0 < n))

uint_pos_not_zero: (all n:UInt. (if 0 < n then not (n = 0)))

// Truncated division by repeated subtraction: n / m counts how many times
// m can be subtracted from n before the remainder drops below m.
// Division by zero yields 0: when m = 0, n < m is false and the m = 0
// branch returns 0.
// The recursive call is reached only when m ≤ n and m ≠ 0, so its first
// argument n ∸ m is strictly smaller than n, which is what the measure n
// tracks.
// NOTE(review): the measure clause reads "measure n" with no type or
// termination argument visible here; confirm it matches the recfun form
// the parser expects.
recfun operator /(n:UInt, m:UInt) -> UInt
measure	n  {
  if n < m then
    0
  else
    if m = 0 then
      0
    else
      1 + (n ∸ m) / m
}

// Remainder of truncated division: n % m = n ∸ (n / m) * m.
// Because n / 0 = 0, this gives n % 0 = n.
fun operator %(n:UInt, m:UInt) {
  n ∸ (n / m) * m
}

// inc is the successor (inc_add_one); positive values are at least 1 and
// have a predecessor; < is trichotomous and asymmetric; adding to a monus
// reassociates when the subtrahend does not exceed the minuend.
inc_add_one: (all n:UInt. inc(n) = 1 + n)

uint_pos_implies_one_le: (all n:UInt. (if 0 < n then 1 ≤ n))

uint_positive_add_one: (all n:UInt. (if 0 < n then some n':UInt. n = 1 + n'))

uint_trichotomy: (all x:UInt, y:UInt. (x < y or x = y or y < x))

uint_less_implies_not_greater: (all x:UInt, y:UInt. (if x < y then not (y < x)))

uint_monus_add_assoc: (all n:UInt, l:UInt, m:UInt. (if m ≤ n then l + (n ∸ m) = (l + n) ∸ m))

// Multiplication: 0 is absorbing, 1 is the identity, * distributes over +
// on both sides, and a zero product has a zero factor.
uint_zero_mult: (all n:UInt. 0 * n = 0)

uint_mult_zero: (all n:UInt. n * 0 = 0)

uint_one_mult: (all n:UInt. 1 * n = n)

uint_mult_one: (all n:UInt. n * 1 = n)

uint_dist_mult_add: (all a:UInt, x:UInt, y:UInt. a * (x + y) = a * x + a * y)

uint_dist_mult_add_right: (all x:UInt, y:UInt, a:UInt. (x + y) * a = x * a + y * a)

uint_mult_to_zero: (all n:UInt, m:UInt. (if n * m = 0 then (n = 0 or m = 0)))

// ≤ is transitive and antisymmetric (it is a partial order, and total by
// uint_not_less_implies_less_equal).
uint_less_equal_trans: (all m:UInt. (all n:UInt, o:UInt. (if (m ≤ n and n ≤ o) then m ≤ o)))

uint_less_equal_antisymmetric: (all x:UInt, y:UInt. (if (x ≤ y and y ≤ x) then x = y))

// Even(n) holds when n is twice some UInt m.
fun Even(n:UInt) {
  some m:UInt. n = 2 * m
}

// Odd(n) holds when n is one more than twice some UInt m.
fun Odd(n:UInt) {
  some m:UInt. n = 1 + 2 * m
}

// Parity: every UInt is even or odd, exactly one of the two holds
// (uint_Even_not_Odd), and parity flips between n and 1 + n.
uint_Even_or_Odd: (all n:UInt. (Even(n) or Odd(n)))

uint_odd_one_even: (all n:UInt. (if Odd(1 + n) then Even(n)))

uint_even_one_odd: (all n:UInt. (if Even(1 + n) then Odd(n)))

uint_Even_not_Odd: (all n:UInt. (Even(n) ⇔ not Odd(n)))

// Division with remainder: for a positive divisor m,
// n = (n / m) * m + n % m and n % m < m.
uint_div_mod: (all n:UInt, m:UInt. (if 0 < m then (n / m) * m + n % m = n))

uint_mod_less_divisor: (all n:UInt, m:UInt. (if 0 < m then n % m < m))

// divides(a, b) holds when b is a multiple of a, i.e. a * k = b for some k.
fun divides(a:UInt, b:UInt) {
  some k:UInt. a * k = b
}

// Divisibility through remainders, plus basic values of / and %.
// uint_mod_self_zero and uint_zero_mod need no positivity hypothesis
// because division by zero is defined to be 0.
uint_divides_mod: (all d:UInt, m:UInt, n:UInt. (if (divides(d, n) and divides(d, m % n) and 0 < n) then divides(d, m)))

uint_div_cancel: (all y:UInt. (if 0 < y then y / y = 1))

uint_mod_self_zero: (all y:UInt. y % y = 0)

uint_zero_mod: (all x:UInt. 0 % x = 0)

uint_zero_div: (all x:UInt. (if 0 < x then 0 / x = 0))

uint_mod_one: (all n:UInt. n % 1 = 0)

uint_div_one: (all n:UInt. n / 1 = n)

uint_add_div_one: (all n:UInt, m:UInt. (if 0 < m then (n + m) / m = 1 + n / m))

uint_mult_div_inverse: (all n:UInt, m:UInt. (if 0 < m then (n * m) / m = n))

uint_two_mult: (all n:UInt. 2 * n = n + n)

// Equal values are related by ≤.
uint_equal_implies_less_equal: (all x:UInt, y:UInt. (if x = y then x ≤ y))