module UInt

import Base
import Nat
import UIntDefs
import UIntToFrom
import UIntLess

// toNat commutes with addition: converting a UInt sum to Nat yields the
// Nat sum of the converted operands.
// Proof: induction on x (constructors bzero, dub_inc, inc_dub); inside each
// inductive case y is case-split over the same three constructors.
theorem toNat_add: all x:UInt, y:UInt.
  toNat(x + y) = toNat(x) + toNat(y)
proof
  induction UInt
  // x = bzero: closed by unfolding operator+ and toNat.
  case bzero {
    arbitrary y:UInt
    expand operator+ | toNat.
  }
  // x = dub_inc(x'): IH is the statement for x' (for all y).
  case dub_inc(x') assume IH {
    arbitrary y:UInt
    expand operator+
    switch y {
      case bzero {
        expand toNat.
      }
      case dub_inc(y') {
        expand toNat
        replace toNat_inc | nat_suc_one_add | nat_suc_one_add | nat_suc_one_add
        // Goal restated explicitly before applying the IH and commutativity.
        show ℕ4 + ℕ2 * toNat(x' + y') = ℕ2 + ℕ2 * toNat(x') + ℕ2 + ℕ2 * toNat(y')
        replace IH[y'] | add_commute[ℕ2 * toNat(x'), ℕ2].
      }
      case inc_dub(y') {
        expand inc | toNat
        replace toNat_inc | nat_suc_one_add | nat_suc_one_add
        replace IH[y'] | add_commute[ℕ2 * toNat(x'), ℕ1].
      }
    }
  }
  // x = inc_dub(x'): same shape as the dub_inc case.
  case inc_dub(x') assume IH {
    arbitrary y:UInt
    switch y {
      case bzero {
        expand operator+ | toNat.
      }
      case dub_inc(y') {
        expand operator+ | toNat | inc
        replace toNat_inc | nat_suc_one_add | nat_suc_one_add
        replace IH[y'] | add_commute[ℕ2 * toNat(x'), ℕ2].
      }
      case inc_dub(y') {
        expand operator+ | toNat | inc
        replace nat_suc_one_add | nat_suc_one_add
        replace IH[y'] | add_commute[ℕ2 * toNat(x'), ℕ1].
      }
    }
  }
end

// fromNat commutes with addition: fromNat of a Nat sum is the UInt sum of
// the converted operands.
// Proof: reduce to an equation between toNat images (uint_toNat_injective),
// then rewrite with toNat_add and uint_toNat_fromNat.
theorem fromNat_add: all x:Nat, y:Nat.
  fromNat(x + y) = fromNat(x) + fromNat(y)
proof
  arbitrary x:Nat, y:Nat
  suffices toNat(fromNat(x + y)) = toNat(fromNat(x) + fromNat(y))
    by uint_toNat_injective
  replace toNat_add | uint_toNat_fromNat.
end

// Specialisation of fromNat_add (read right-to-left) to operands of the
// form lit(x) and lit(y): the UInt sum of the two converted literals equals
// fromNat of their Nat sum.
theorem lit_add_fromNat: all x:Nat, y:Nat.
  fromNat(lit(x)) + fromNat(lit(y)) = fromNat(lit(x) + lit(y))
proof
  arbitrary x:Nat, y:Nat
  symmetric fromNat_add
end

// Registered with `auto`, so later proofs in this file may rely on it
// implicitly.
auto lit_add_fromNat

/*
theorem lit_add_fromNat: all x:Nat, y:Nat.
  fromNat(lit(x)) + fromNat(y) = fromNat(lit(x) + y)
proof
  arbitrary x:Nat, y:Nat
  symmetric fromNat_add
end

auto lit_add_fromNat

theorem add_lit_fromNat: all x:Nat, y:Nat.
  fromNat(x) + fromNat(lit(y)) = fromNat(x + lit(y))
proof
  arbitrary x:Nat, y:Nat
  symmetric fromNat_add
end

auto add_lit_fromNat
*/

// Addition on UInt is commutative.
// Proof: compare toNat images (uint_toNat_injective) and use commutativity
// of Nat addition (add_commute) between two uses of toNat_add.
theorem uint_add_commute: all x:UInt, y:UInt.
  x + y = y + x
proof
  arbitrary x:UInt, y:UInt
  suffices toNat(x + y) = toNat(y + x) by uint_toNat_injective
  equations
	    toNat(x + y)
        = toNat(x) + toNat(y)   by toNat_add
    ... = toNat(y) + toNat(x)   by add_commute
    ... = #toNat(y + x)#        by replace toNat_add.
end

// Addition on UInt is associative.
// Proof: compare toNat images (uint_toNat_injective) and push toNat through
// every sum with toNat_add.
theorem uint_add_assoc: all x:UInt, y:UInt, z:UInt.
  (x + y) + z = x + (y + z)
proof
  arbitrary x:UInt, y:UInt, z:UInt
  suffices toNat((x + y) + z) = toNat(x + (y + z)) by uint_toNat_injective
  equations
	    toNat((x + y) + z)
        = toNat(x + y) + toNat(z)              by toNat_add
    ... = toNat(x) + toNat(y) + toNat(z)       by replace toNat_add.
    ... = toNat(x) + #toNat(y + z)#            by replace toNat_add.
    ... = #toNat(x + (y + z))#                 by replace toNat_add.
end

// Declare + on UInt associative to the checker. The check_assoc lemma
// that follows in this file is then provable with a bare `.`.
associative operator+ in UInt

// Sanity check: restates associativity of + on UInt and proves it with the
// trivial proof `.`.
// NOTE(review): presumably this only succeeds because + has been declared
// associative for UInt earlier in the file — confirm.
lemma check_assoc: all x:UInt, y:UInt, z:UInt.
  (x + y) + z = x + (y + z)
proof
  .
end

// bzero is a left identity for UInt addition.
// Proof: compare toNat images, split the sum with toNat_add, then unfold
// toNat and evaluate.
lemma uint_bzero_add: all x:UInt.
  bzero + x = x
proof
  arbitrary x:UInt
  suffices toNat(bzero + x) = toNat(x) by uint_toNat_injective
  equations
  	toNat(bzero + x) 
      = toNat(bzero) + toNat(x)    by toNat_add
  ... = toNat(x)                   by expand toNat evaluate
end

// bzero is a right identity for UInt addition.
// Proof: compare toNat images, split the sum with toNat_add, then unfold
// toNat and evaluate.
lemma uint_add_bzero: all x:UInt.
  x + bzero = x
proof
  arbitrary x:UInt
  suffices toNat(x + bzero) = toNat(x) by uint_toNat_injective
  equations
  	toNat(x + bzero) 
      = toNat(x) + toNat(bzero)    by toNat_add
  ... = toNat(x)                   by { expand toNat evaluate } // NOTE(review): flagged "bug?" by the original author; presumably about the braces, which the same step in uint_bzero_add omits — confirm
end

// Registered with `auto`, so later proofs in this file may rely on it
// implicitly.
auto uint_add_bzero

// bzero is strictly less than inc_dub(bzero) + n for every n.
// Proof: move to Nat with less_toNat and toNat_add, unfold toNat twice, and
// finish with nat_zero_less_one_add.
theorem uint_bzero_less_one_add: all n:UInt.
  bzero < inc_dub(bzero) + n
proof
  arbitrary n:UInt
  suffices toNat(bzero) < toNat(inc_dub(bzero) + n) by less_toNat
  replace toNat_add
  conclude toNat(bzero) < toNat(inc_dub(bzero)) + toNat(n) by {
    expand 2*toNat
    nat_zero_less_one_add[toNat(n)]
  }
end

// Left cancellation for UInt addition: x + y = x + z holds exactly when
// y = z.
// Proof: the two directions are proved separately and paired at the end.
// The forward direction maps the premise to Nat with toNat, cancels there
// (add_both_sides_of_equal) and returns through uint_toNat_injective; the
// backward direction is a direct rewrite.
theorem uint_add_both_sides_of_equal: all x:UInt, y:UInt, z:UInt.
  x + y = x + z ⇔ y = z
proof
  arbitrary x:UInt, y:UInt, z:UInt
  have fwd: if x + y = x + z then y = z by {
    assume prem
    have xy_xz: toNat(x + y) = toNat(x + z)
      by replace prem.
    have xy_xz2: toNat(x) + toNat(y) = toNat(x) + toNat(z)
      by replace toNat_add in xy_xz
    have yz: toNat(y) = toNat(z) by apply add_both_sides_of_equal to xy_xz2
    conclude y = z by apply uint_toNat_injective to yz
  }
  have bkwd: if y = z then x + y = x + z by {
    assume yz
    replace yz.
  }
  fwd, bkwd
end

// If a UInt sum equals bzero then both summands equal bzero.
// Proof: map the premise to Nat, apply nat_add_to_zero to get both toNat
// images equal to ℕ0, and return through uint_toNat_injective.
theorem uint_add_to_bzero: all n:UInt, m:UInt.
  if n + m = bzero
  then n = bzero and m = bzero
proof
  arbitrary n:UInt, m:UInt
  assume prem
  have nm_z: toNat(n + m) = toNat(bzero) by replace prem.
  have nm_z2: toNat(n) + toNat(m) = ℕ0 by expand toNat in replace toNat_add in nm_z
  have nz_mz: toNat(n) = ℕ0 and toNat(m) = ℕ0 by {
    apply nat_add_to_zero[toNat(n), toNat(m)] to nm_z2
  }
  have nz: toNat(n) = toNat(bzero) by expand toNat nz_mz
  have mz: toNat(m) = toNat(bzero) by expand toNat nz_mz
  (apply uint_toNat_injective to nz), (apply uint_toNat_injective to mz)
end

// A UInt is less than or equal to itself plus anything on the right.
// Proof: move to Nat with less_equal_toNat, split the sum with toNat_add,
// and finish with the Nat lemma less_equal_add.
theorem uint_less_equal_add: all x:UInt, y:UInt.
  x ≤ x + y
proof
  arbitrary x:UInt, y:UInt
  suffices toNat(x) ≤ toNat(x + y) by less_equal_toNat[x, x+y]
  replace toNat_add
  less_equal_add
end

// A UInt is less than or equal to anything plus itself (summand on the
// right). Proof: commute the sum and reuse uint_less_equal_add.
theorem uint_less_equal_add_left: all x:UInt, y:UInt. y ≤ x + y
proof
  arbitrary x:UInt, y:UInt
  replace uint_add_commute[x, y]
  uint_less_equal_add
end

// Adding a strictly positive y (bzero < y) to x gives something strictly
// greater than x.
// Proof: move the premise to Nat (toNat_less), apply nat_less_add_pos,
// fold the sum back with toNat_add, and return with less_toNat.
theorem uint_less_add_pos: all x:UInt, y:UInt.
  if bzero < y
  then x < x + y
proof
  arbitrary x:UInt, y:UInt
  assume y_pos
  have ny_pos: ℕ0 < toNat(y) by expand toNat in apply toNat_less to y_pos
  have A: toNat(x) < toNat(x) + toNat(y) by {
    apply nat_less_add_pos[toNat(x), toNat(y)] to ny_pos
  }
  have B: toNat(x) < toNat(x + y) by { replace toNat_add A }
  conclude x < x + y by apply less_toNat to B
end
  
// pred undoes inc on UInt.
// Proof: compare toNat images, rewrite with toNat_pred and toNat_inc, and
// unfold pred on the resulting Nat goal.
lemma uint_pred_inc: all b:UInt.
  pred(inc(b)) = b
proof
  arbitrary b:UInt
  suffices toNat(pred(inc(b))) = toNat(b) by uint_toNat_injective
  replace toNat_pred | toNat_inc
  show pred(suc(toNat(b))) = toNat(b)
  expand pred.
end

// inc(b) equals inc_dub(bzero) + b (constructor form of "one plus b").
// Proof: compare toNat images, rewrite with toNat_inc and toNat_add, unfold
// toNat twice, and evaluate the restated Nat goal.
lemma inc_one_add: all b:UInt.
  inc(b) = inc_dub(bzero) + b
proof
  arbitrary b:UInt
  suffices toNat(inc(b)) = toNat(inc_dub(bzero) + b) by uint_toNat_injective
  replace toNat_inc | toNat_add
  expand 2*toNat
  show suc(toNat(b)) = suc(ℕ2 * ℕ0) + toNat(b)
  evaluate
end

// inc(n) equals 1 + n, stated with the numeric literal 1.
// Proof: compare toNat images, rewrite with toNat_inc and toNat_add, then
// evaluate.
lemma inc_add_one: all n:UInt.
  inc(n) = 1 + n
proof
  arbitrary n:UInt
  suffices toNat(inc(n)) = toNat(1 + n)  by uint_toNat_injective
  replace toNat_inc | toNat_add
  evaluate
end

// Every UInt is strictly less than one plus itself.
// Proof: move to Nat with less_toNat and toNat_add; evaluation reduces the
// strict inequality to toNat(n) ≤ toNat(n), which is less_equal_refl.
theorem uint_less_plus1: all n:UInt.
  n < 1 + n
proof
  arbitrary n:UInt
  suffices toNat(n) < toNat(1 + n) by less_toNat[n, 1 + n]
  replace toNat_add
  suffices toNat(n) ≤ toNat(n) by evaluate
  less_equal_refl
end

// Equational form of uint_less_plus1: the proposition n < 1 + n equals
// true. Obtained by applying eq_true to uint_less_plus1.
theorem uint_less_plus1_true: all n:UInt.
  (n < 1 + n) = true
proof
  arbitrary n:UInt
  apply eq_true to uint_less_plus1[n]
end

// Registered with `auto`, so later proofs in this file may rely on it
// implicitly.
auto uint_less_plus1_true

// Left cancellation for strict order: x + y < x + z holds exactly when
// y < z.
// Proof: both directions go through Nat. toNat_less moves the premise to
// Nat, toNat_add splits or folds the sums, add_both_sides_of_less cancels
// or adds toNat(x), and less_toNat returns to UInt.
theorem uint_add_both_sides_of_less: all x:UInt, y:UInt, z:UInt.
  x + y < x + z ⇔ y < z
proof
  arbitrary x:UInt, y:UInt, z:UInt
  have fwd: if x + y < x + z then y < z by {
    assume prem
    have A: toNat(x + y) < toNat(x + z) by apply toNat_less[x+y,x+z] to prem
    have B: toNat(x) + toNat(y) < toNat(x) + toNat(z) by replace toNat_add in A
    have C: toNat(y) < toNat(z) by apply add_both_sides_of_less to B
    conclude y < z by apply less_toNat to C
  }
  have bkwd: if y < z then x + y < x + z by {
    assume prem
    have A: toNat(y) < toNat(z) by apply toNat_less[y,z] to prem
    have B: toNat(x) + toNat(y) < toNat(x) + toNat(z)
      by apply add_both_sides_of_less[toNat(x)] to A
    have C: toNat(x + y) < toNat(x + z) by {
      replace toNat_add
      B
    }
    conclude x + y < x + z by apply less_toNat to C
  }
  fwd, bkwd
end

// Strict order expressed through non-strict order: the proposition x < y
// equals the proposition 1 + x ≤ y.
// Proof: iff_equal reduces the equation to a bi-implication, and each
// direction is proved separately.
//   fwd:  move x < y to Nat, unfold operator< and rewrite with
//         nat_suc_one_add to get ℕ1 + toNat(x) ≤ toNat(y), then come back
//         to UInt with less_equal_fromNat, fromNat_add and
//         uint_fromNat_toNat.
//   bkwd: move 1 + x ≤ y to Nat with toNat_less_equal and toNat_add,
//         evaluate toNat(1), and fold back into operator< on Nat.
theorem uint_less_is_less_equal: all x:UInt, y:UInt.
  x < y = 1 + x ≤ y
proof
  arbitrary x:UInt, y:UInt
  suffices x < y ⇔ 1 + x ≤ y by iff_equal
  have fwd: if (x < y) then (1 + x ≤ y) by {
    assume prem
    have A: toNat(x) < toNat(y) by apply toNat_less to prem
    have C: ℕ1 + toNat(x) ≤ toNat(y) by {
      replace nat_suc_one_add in
      expand operator< in A
    }
    have D: fromNat(ℕ1 + toNat(x)) ≤ fromNat(toNat(y))
      by apply less_equal_fromNat to C
    have E: 1 + fromNat(toNat(x)) ≤ fromNat(toNat(y))
      by replace fromNat_add[ℕ1, toNat(x)] | from_one in D
    conclude 1 + x ≤ y by replace uint_fromNat_toNat in E
  }
  have bkwd: if (1 + x ≤ y) then (x < y) by {
    assume prem
    have A: toNat(1 + x) ≤ toNat(y) by apply toNat_less_equal to prem
    have B: toNat(1) + toNat(x) ≤ toNat(y) by replace toNat_add in A
    have C: toNat(1) = ℕ1 by evaluate
    have D: ℕ1 + toNat(x) ≤ toNat(y) by replace C in B
    have E: toNat(x) < toNat(y) by {
      expand operator<
      replace nat_suc_one_add
      D
    }
    conclude x < y by apply less_toNat to E
  }
  fwd, bkwd
end

// Left cancellation for ≤, as a bi-implication.
// NOTE: this is a postulate — it is assumed here, not proved.
postulate uint_add_both_sides_of_less_equal: all x:UInt. all y:UInt, z:UInt. x + y ≤ x + z ⇔ y ≤ z

// The same fact as an equation between propositions, which is the form
// registered with `auto` below.
// NOTE: this is a postulate — it is assumed here, not proved.
postulate uint_add_both_sides_of_le_equal: all x:UInt. all y:UInt, z:UInt. (x + y ≤ x + z) = (y ≤ z)

auto uint_add_both_sides_of_le_equal

// If x is strictly below 1 + y then x ≤ y.
// Proof: rewriting the strict inequality with uint_less_is_less_equal
// leaves a goal that the checker closes with the trivial proof `.`.
theorem uint_less_add_one_implies_less_equal: all x:UInt, y:UInt. (if x < 1 + y then x ≤ y)
proof
  arbitrary x:UInt, y:UInt
  replace uint_less_is_less_equal.
end

// 1 + n is never equal to 0.
// Proof: assume the equation, map it to Nat with toNat and toNat_add,
// evaluate, and derive false from nat_not_one_add_zero.
theorem uint_not_one_add_zero: all n:UInt.
    not (1 + n = 0)
proof
  arbitrary n:UInt
  assume prem
  have eq1: toNat(1 + n) = toNat(0) by replace prem.
  have eq2: toNat(1) + toNat(n) = toNat(0) by replace toNat_add in eq1
  have eq3: ℕ1 + toNat(n) = ℕ0 by evaluate in eq2
  conclude false by apply nat_not_one_add_zero to eq3
end

// Anything less than or equal to 0 is 0.
// NOTE: this is a postulate — it is assumed here, not proved.
postulate uint_zero_le_zero: all x:UInt. (if x ≤ 0 then x = 0)

// 1 + n is never less than or equal to 0.
// Proof: from the assumption, uint_zero_le_zero gives 1 + n = 0, which
// uint_not_one_add_zero refutes.
theorem uint_not_one_add_le_zero: all n:UInt.
  not (1 + n ≤ 0)
proof
  arbitrary n:UInt
  assume prem
  apply uint_not_one_add_zero[n] to
  apply uint_zero_le_zero to prem
end

// A UInt that is not 0 is strictly positive.
// Proof: case split on n. The bzero case contradicts the premise; the
// other two constructor cases are closed by evaluation.
theorem uint_not_zero_pos: all n:UInt.
  if not (n = 0) then 0 < n
proof
  arbitrary n:UInt
  assume prem
  switch n {
    case bzero assume nz {
      // With n = bzero the premise `not (n = 0)` unfolds to false.
      conclude false by expand lit | fromNat in replace nz in prem
    }
    case dub_inc(n') {
      evaluate
    }
    case inc_dub(n') {
      evaluate
    }
  }
end

// A strictly positive UInt is not 0 (converse of uint_not_zero_pos).
// Proof: assuming n = 0 turns the premise into 0 < 0, which
// uint_less_irreflexive refutes.
theorem uint_pos_not_zero: all n:UInt.
  if 0 < n then not (n = 0)
proof
  arbitrary n:UInt
  assume n_pos
  assume prem
  have zz: 0 < 0 by replace prem in n_pos
  conclude false by apply uint_less_irreflexive to zz
end

    
// A strictly positive UInt is at least 1.
// Proof: rewrite 0 < n with uint_less_is_less_equal to get 1 + 0 ≤ n,
// which the checker accepts as the goal 1 ≤ n.
theorem uint_pos_implies_one_le: all n:UInt.
  if 0 < n
  then 1 ≤ n
proof
  arbitrary n:UInt
  assume n_pos
  have A: 1 + 0 ≤ n by replace uint_less_is_less_equal in n_pos
  A
end

// Every strictly positive UInt can be written as 1 + n' for some n'.
// Proof: move positivity to Nat, obtain a Nat witness x with
// toNat(n) = ℕ1 + x from nat_positive_suc, and use fromNat(x) as the UInt
// witness.
theorem uint_positive_add_one: all n:UInt.
  if 0 < n
  then some n':UInt. n = 1 + n'
proof
  arbitrary n:UInt
  assume n_pos
  have pos_n: ℕ0 < toNat(n) by expand lit | fromNat | toNat in apply toNat_less to n_pos
  obtain x where eq: toNat(n) = ℕ1 + x
    from apply nat_positive_suc[toNat(n)] to pos_n
  have eq2: fromNat(toNat(n)) = fromNat(ℕ1 + x) by replace eq.
  have eq3: n = fromNat(ℕ1) + fromNat(x)
    by { replace uint_fromNat_toNat | fromNat_add in eq2 }
  choose fromNat(x)
  conclude n = 1 + fromNat(x) by {
    // Unfold the literal 1 in the goal and fromNat(ℕ1) in eq3 the same way
    // so that the two sides line up.
    expand lit | 2*fromNat | inc
    expand lit | 2*fromNat | inc in eq3
  }
end

// Every UInt is either 0 or of the form 1 + x'.
// Proof: case analysis on uint_zero_or_positive; the positive case uses
// uint_positive_add_one to produce the witness.
theorem uint_zero_or_add_one: all x:UInt. x = 0 or (some x':UInt. x = 1 + x')
proof
  arbitrary x:UInt
  have zero_or_pos: x = 0 or 0 < x by uint_zero_or_positive[x]
  cases zero_or_pos
  case xz {
    xz
  }
  case xp {
    obtain x' where xs: x = 1 + x' from apply uint_positive_add_one to xp
    have G: some y:UInt. x = 1 + y by {
      choose x'
      xs
    }
    G
  }
end

// 0 is strictly less than 1 + n, stated with numeric literals.
// Proof: unfold the literals (lit, fromNat twice, inc) and reuse the
// constructor-level theorem uint_bzero_less_one_add.
theorem uint_zero_less_one_add: all n:UInt.
  0 < 1 + n
proof
  expand lit | 2*fromNat | inc
  uint_bzero_less_one_add
end

// Equational form of uint_zero_less_one_add: the proposition 0 < 1 + n
// equals true. Obtained by applying eq_true.
theorem uint_zero_less_one_add_true: all n:UInt. (0 < 1 + n) = true
proof
  arbitrary n:UInt
  apply eq_true to uint_zero_less_one_add[n]
end

// Registered with `auto`, so later proofs in this file may rely on it
// implicitly.
auto uint_zero_less_one_add_true

// 0 is a left identity for UInt addition, stated with the literal 0.
// Proof: unfold the literal and reuse uint_bzero_add.
theorem uint_zero_add: all x:UInt.
  0 + x = x
proof
  expand lit | fromNat
  uint_bzero_add
end

// Registered with `auto`, so later proofs in this file may rely on it
// implicitly.
auto uint_zero_add

// 0 is a right identity for UInt addition, stated with the literal 0.
// Proof: unfold the literal and reuse uint_add_bzero.
theorem uint_add_zero: all x:UInt.
  x + 0 = x
proof
  expand lit | fromNat
  uint_add_bzero
end
  
// Registered with `auto`, so later proofs in this file may rely on it
// implicitly.
auto uint_add_zero

// If a UInt sum equals 0 then both summands equal 0 (literal form of
// uint_add_to_bzero).
// Proof: unfold the literal and reuse uint_add_to_bzero.
theorem uint_add_to_zero: all n:UInt, m:UInt.
  if n + m = 0
  then n = 0 and m = 0
proof
  expand lit | fromNat
  uint_add_to_bzero
end
  
// Equational form of uint_not_one_add_zero: the proposition 1 + n = 0
// equals false. Obtained by applying eq_false.
theorem uint_one_add_zero_false: all n:UInt.
    (1 + n = 0) = false
proof
  arbitrary n:UInt
  apply eq_false to uint_not_one_add_zero[n]
end

// Registered with `auto`, so later proofs in this file may rely on it
// implicitly.
auto uint_one_add_zero_false

// Addition is monotone with respect to ≤ in both arguments.
// NOTE: this is a postulate — it is assumed here, not proved.
postulate uint_add_mono_less_equal: all a:UInt, b:UInt, c:UInt, d:UInt.
  (if a ≤ c and b ≤ d then a + b ≤ c + d)

// Addition is monotone with respect to < in both arguments.
// NOTE: this is a postulate — it is assumed here, not proved.
postulate uint_add_mono_less: all a:UInt, b:UInt, c:UInt, d:UInt.
  (if a < c and b < d then a + b < c + d)

// For strictly positive x, 1 + x is at most x + x.
// Proof: uint_less_add_pos (after unfolding the literal 0 in the premise)
// gives x < x + x; rewriting that with uint_less_is_less_equal yields the
// goal.
theorem uint_add_one_le_double: all x:UInt. if 0 < x then 1 + x ≤ x + x
proof
  arbitrary x:UInt
  assume pos
  have A: x < x + x by apply uint_less_add_pos[x,x] to expand lit | fromNat in pos
  replace uint_less_is_less_equal in A
end

// Register uint_bzero_add (bzero + x = x) with `auto`.
// NOTE(review): this is registered far below the lemma's definition;
// presumably deliberate so that earlier proofs are unaffected — confirm
// before moving it.
auto uint_bzero_add

// Helper for uint_induction. For a predicate P and a UInt n given as
// fromNat(k): if P(0) holds and P is preserved by m ↦ 1 + m, then P(n).
// Proof: induction on the Nat index k.
lemma uint_ind: all P:fn UInt -> bool, k : Nat, n : UInt.
  if n = fromNat(k) then
  if P(0) and (all m:UInt. if P(m) then P(1 + m)) then P(n)
proof
  arbitrary P:fn UInt -> bool
  induction Nat
  // k = zero: n is fromNat(zero), so the goal follows from the P(0) half of
  // the hypothesis.
  case zero {
    arbitrary n:UInt
    assume nz
    assume base_IH
    replace nz

    suffices __ by evaluate
    expand lit | fromNat in conjunct 0 of base_IH
  }
  // k = suc(k'): n = inc(fromNat(k')), hence pred(n) = fromNat(k'). Apply
  // the IH to pred(n), then the step hypothesis, then rewrite
  // 1 + pred(n) back to n.
  case suc(k') assume IH {
    arbitrary n:UInt
    assume n_sk
    assume base_IH
    have n_ik: n = inc(fromNat(k')) by expand fromNat in n_sk
    have eq0: pred(n) = pred(inc(fromNat(k'))) by replace n_ik.
    have eq1: pred(n) = fromNat(k') by replace uint_pred_inc in eq0
    have: P(pred(n)) by apply (apply IH[pred(n)] to eq1) to base_IH
    have eq2: P(1 + pred(n))
      by apply (conjunct 1 of base_IH)[pred(n)] to recall P(pred(n))
    have eq3: 1 + pred(n) = n by expand lit | 2*fromNat replace eq1 | n_ik | inc_one_add.
    conclude P(n) by replace eq3 in eq2
  }
end

// Induction principle for UInt in terms of 0 and 1 + m: if P(0) holds and
// P(m) implies P(1 + m) for all m, then P holds for every UInt.
// Proof: write n as fromNat(toNat(n)) and apply uint_ind with k = toNat(n).
theorem uint_induction: all P:fn UInt -> bool.
  if P(0) and (all m:UInt. if P(m) then P(1 + m))
  then all n : UInt. P(n)
proof
  arbitrary P:fn UInt -> bool
  assume prem
  arbitrary n : UInt
  have eq: n = fromNat(toNat(n)) by replace uint_fromNat_toNat.
  apply (apply uint_ind[P, toNat(n), n] to eq) to prem
end

// Induction starting at k: if P(k) holds and P(m) implies P(1 + m) for all
// m, then P(n) holds for every n with k ≤ n.
// Proof: apply uint_induction to the auxiliary predicate
//   Q(m) := if k ≤ m then P(m).
//   base:     k ≤ 0 forces k = 0 (uint_less_equal_zero), so P(0) is P(k).
//   ind_case: from k ≤ 1 + m, unfolding operator≤ gives k < 1 + m or
//             k = 1 + m. The first yields k ≤ m, then P(m) by the IH and
//             P(1 + m) by the step hypothesis; the second is P(k) itself.
theorem uint_k_induction: all P: fn UInt -> bool, k:UInt.
  if (P(k) and (all m:UInt. (if P(m) then P(1 + m))))
  then all n:UInt. if k ≤ n then P(n)
proof
  arbitrary P: fn UInt -> bool, k:UInt
  assume prem
  define Q = fun m:UInt { if k ≤ m then P(m) }
  have base: Q(0) by {
    expand Q
    assume kle0
    have kz: k = 0 by apply uint_less_equal_zero to kle0
    conclude P(0) by replace kz in prem
  }
  have ind_case: all m:UInt. (if Q(m) then Q(1 + m)) by {
    arbitrary m:UInt
    expand Q
    assume IH: if k ≤ m then P(m)
    assume k_1m: k ≤ 1 + m
    have kless_or_eq: k < 1 + m or k = 1 + m by expand operator≤ in k_1m
    cases kless_or_eq
    case k_l_1m {
      have: k ≤ m by apply uint_less_add_one_implies_less_equal to k_l_1m
      have: P(m) by apply IH to recall k ≤ m
      conclude P(1 + m) by apply (conjunct 1 of prem)[m] to recall P(m)
    }
    case k_e_1m {
      conclude P(1 + m) by replace k_e_1m in prem
    }
  }
  expand Q in apply uint_induction[Q] to base, ind_case
end