// Lemmas and theorems relating the UInt representation to Nat through the
// toNat and fromNat conversions (round trips, injectivity, literal equality).
module UInt

import Base
import Nat
import UIntDefs

// fromNat sends the natural literal ℕ0 to the literal 0.
// Both sides are closed terms, so evaluation settles the equation.
lemma from_zero: fromNat(ℕ0) = 0
proof
  evaluate
end

// fromNat sends the natural literal ℕ1 to the literal 1.
// Both sides are closed terms, so evaluation settles the equation.
lemma from_one: fromNat(ℕ1) = 1
proof
  evaluate
end

// toNat commutes with doubling: converting dub(x) to Nat gives ℕ2 * toNat(x).
// Proved by induction on the three UInt constructors.
lemma toNat_dub : all x:UInt. toNat(dub(x)) = ℕ2 * toNat(x)
proof
  induction UInt
  case bzero {
    // Base case: closed terms, decided by evaluation.
    evaluate
  }
  case dub_inc(x) assume IH {
    // Unfolding dub once and toNat twice is enough to close the goal;
    // the induction hypothesis is not needed in this case.
    expand dub | 2* toNat.
  }
  case inc_dub(x) assume IH {
    // Unfold both functions, then rewrite the recursive occurrence with IH.
    expand dub | toNat
    replace IH.
  }
end

// toNat commutes with increment: converting inc(x) to Nat gives the
// successor of toNat(x). Proved by induction on the UInt constructors.
lemma toNat_inc: all x:UInt. toNat(inc(x)) = suc(toNat(x))
proof
  induction UInt
  case bzero {
    // Base case: closed terms, decided by evaluation.
    evaluate
  }
  case dub_inc(x) assume IH {
    // Unfold inc and toNat, then rewrite the recursive occurrence with IH.
    expand inc | toNat
    replace IH.
  }
  case inc_dub(x) assume IH {
    // Let evaluation simplify the goal first (`__` leaves the new goal for
    // Deduce to fill in), then normalise the remaining suc/+ terms with the
    // Nat lemmas. IH is not used here.
    suffices __ by evaluate
    replace nat_suc_one_add | nat_suc_one_add | add_commute[toNat(x), ℕ1].
  }
end

// toNat commutes with predecessor: converting pred(x) on UInt gives pred of
// toNat(x) on Nat. Proved by induction on the UInt constructors; the
// induction hypothesis is not used in any case.
lemma toNat_pred: all x:UInt. toNat(pred(x)) = pred(toNat(x))
proof
  induction UInt
  case bzero {
    // Base case: closed terms, decided by evaluation.
    evaluate
  }
  case dub_inc(x) assume IH {
    // Simplify by evaluation, then normalise suc/+ with the Nat lemmas.
    suffices __ by evaluate
    replace nat_suc_one_add | add_commute[toNat(x), ℕ1].
  }
  case inc_dub(x) assume IH {
    // Simplify by evaluation, then rewrite toNat(dub(..)) using toNat_dub
    // and finish with mult_two.
    suffices __ by evaluate
    replace toNat_dub | mult_two.
  }
end

// Round trip Nat -> UInt -> Nat is the identity: toNat(fromNat(x)) = x.
// Proved by induction on Nat.
theorem uint_toNat_fromNat: all x:Nat. toNat(fromNat(x)) = x
proof
  induction Nat
  case 0 {
    // Base case: closed terms, decided by evaluation.
    evaluate
  }
  case suc(x') assume IH {
    // Simplify by evaluation, push toNat through inc with toNat_inc,
    // then apply the induction hypothesis.
    suffices __ by evaluate
    replace toNat_inc | IH.
  }
end

// toNat is injective: two UInt values with the same Nat image are equal.
//
// Proof: induction on x, then a case split on y. When x and y have the same
// constructor, the premise is reduced to toNat(x') = toNat(y') and the
// induction hypothesis applies. When the constructors differ, the two sides
// of the premise have different parity (EvenNat vs OddNat), which contradicts
// Even_not_Odd.
theorem uint_toNat_injective: all x:UInt, y:UInt.
  if toNat(x) = toNat(y) then x = y
proof
  induction UInt
  case bzero {
    arbitrary y:UInt
    expand toNat
    switch y {
      // Same constructor: goal is trivial.
      case bzero { . }
      // Different constructors: settled by evaluation.
      case dub_inc(y') { evaluate }
      case inc_dub(y') { evaluate }
    }
  }
  case dub_inc(x') assume IH {
    arbitrary y:UInt
    switch y {
      case bzero { evaluate }
      case dub_inc(y') {
        // Same constructor on both sides.
        expand toNat
        replace nat_suc_one_add // | two | dist_mult_add
        assume prem
        // Drop the common summand from both sides of the premise.
        have A: ℕ2 * toNat(x') = ℕ2 * toNat(y')
          by apply add_both_sides_of_equal to prem
        // Cancel the common factor ℕ2.
        have B: toNat(x') = toNat(y')
          by apply lit_mult_left_cancel to A
        have: x' = y' by apply IH to B
        conclude dub_inc(x') = dub_inc(y')
          by replace recall x' = y'.
      }
      case inc_dub(y') {
        // Mixed constructors: derive a parity contradiction.
        expand toNat
        replace nat_suc_one_add// | dist_mult_add
        assume prem
        // The dub_inc side is even, witnessed by ℕ1 + toNat(x').
        have even: EvenNat(ℕ2 + ℕ2*toNat(x')) by {
          expand EvenNat
          choose ℕ1 + toNat(x')
          //replace dist_mult_add.
          replace dist_mult_add[suc(suc(zero)), ℕ1, toNat(x')]
          evaluate
        }
        // Transport evenness across the premise to the inc_dub side.
        have even2: EvenNat(ℕ1 + ℕ2*toNat(y')) by {
          replace prem in even
        }
        // The inc_dub side is odd, witnessed by toNat(y').
        have odd: OddNat(ℕ1 + ℕ2*toNat(y')) by {
          expand OddNat
          choose toNat(y')
          evaluate
        }
        // The same number cannot be both even and odd.
        conclude false by apply (apply Even_not_Odd to even2) to odd
      }
    }
  }
  case inc_dub(x') assume IH {
    arbitrary y:UInt
    switch y {
      case bzero { evaluate }
      case dub_inc(y') {
        // Mixed constructors: derive a parity contradiction (mirror image of
        // the dub_inc/inc_dub case above).
        expand toNat replace nat_suc_one_add// | two | dist_mult_add
        assume prem
        // The inc_dub side is odd, witnessed by toNat(x').
        have odd: OddNat(ℕ1 + ℕ2*toNat(x')) by {
          expand OddNat
          choose toNat(x')
          evaluate
        }
        // The dub_inc side is even, witnessed by ℕ1 + toNat(y').
        have even: EvenNat(ℕ2 + ℕ2*toNat(y')) by {
          expand EvenNat
          choose ℕ1 + toNat(y')
          replace dist_mult_add[suc(suc(zero)), ℕ1, toNat(y')]
          evaluate
        }
        // Transport oddness across the premise to the dub_inc side.
        have odd2: OddNat(ℕ2 + ℕ2*toNat(y')) by replace prem in odd
        conclude false by apply (apply Even_not_Odd to even) to odd2
      }
      case inc_dub(y') {
        // Same constructor on both sides.
        expand toNat replace nat_suc_one_add
        assume prem
        // Drop the common summand from both sides of the premise.
        have A: ℕ2 * toNat(x') = ℕ2 * toNat(y')
          by apply add_both_sides_of_equal to prem
        // Cancel the factor ℕ2; this cancellation lemma needs positivity.
        have pos: ℕ0 < ℕ2 by evaluate
        have B: toNat(x') = toNat(y')
          by apply pos_mult_left_cancel[ℕ2, toNat(x'), toNat(y')] to pos, A
        have: x' = y' by apply IH to B
        conclude inc_dub(x') = inc_dub(y') by replace recall x' = y'.
      }
    }
  }
end

// Round trip UInt -> Nat -> UInt is the identity: fromNat(toNat(b)) = b.
//
// Proof: induction on b. In the two non-zero cases the goal is reduced, via
// injectivity of toNat (uint_toNat_injective), to an equation between Nat
// images, which is then simplified with uint_toNat_fromNat. The induction
// hypothesis is not used.
theorem uint_fromNat_toNat: all b:UInt. fromNat(toNat(b)) = b
proof
  induction UInt
  case bzero {
    // Base case: closed terms, decided by evaluation.
    evaluate
  }
  case dub_inc(b') assume IH {
    expand toNat
    replace nat_suc_one_add// | dist_mult_add | two
    // It is enough to show both sides have the same toNat image.
    suffices toNat(fromNat(ℕ2 + ℕ2 * toNat(b'))) = toNat(dub_inc(b'))
      by uint_toNat_injective[fromNat(ℕ2 + ℕ2 * toNat(b')), dub_inc(b')]
    // Collapse toNat(fromNat(..)) on the left-hand side.
    replace uint_toNat_fromNat
    show ℕ2 + ℕ2 * toNat(b') = toNat(dub_inc(b'))
    expand toNat.
  }
  case inc_dub(b') assume IH {
    expand toNat
    replace nat_suc_one_add
    // It is enough to show both sides have the same toNat image; `lit` is
    // unfolded in both the goal and the instantiated injectivity theorem so
    // that the two statements line up.
    suffices toNat(fromNat(ℕ1 + ℕ2 * toNat(b'))) = toNat(inc_dub(b'))
      by expand lit expand lit in uint_toNat_injective[fromNat(ℕ1 + ℕ2 * toNat(b')), inc_dub(b')]
    // Collapse toNat(fromNat(..)) on the left-hand side.
    replace uint_toNat_fromNat
    expand toNat
    show ℕ1 + ℕ2 * toNat(b') = suc(ℕ2 * toNat(b'))
    replace nat_suc_one_add.
  }
end

// fromNat is injective: naturals with the same UInt image are equal.
// Follows from toNat being a left inverse of fromNat (uint_toNat_fromNat).
theorem fromNat_injective: all x:Nat, y:Nat.
  if fromNat(x) = fromNat(y) then x = y
proof
  arbitrary x:Nat, y:Nat
  assume images_equal
  // Apply toNat to both sides of the assumed equation.
  have round_trip: toNat(fromNat(x)) = toNat(fromNat(y))
    by replace images_equal.
  // Collapse each toNat(fromNat(..)) to recover x = y.
  conclude x = y
    by replace uint_toNat_fromNat in round_trip
end

// Equality of UInt values built by fromNat from Nat literals is the same
// proposition as equality of the underlying naturals. Stated as an equation
// between propositions so that it can be used as a rewrite rule.
theorem uint_equal: all x:Nat, y:Nat.
  (fromNat(lit(x)) = fromNat(lit(y))) = (x = y)
proof
  arbitrary x:Nat, y:Nat
  expand lit
  // By iff_equal it suffices to prove the two propositions equivalent.
  suffices (fromNat(x) = fromNat(y)) ⇔ (x = y) by iff_equal
  // Forward direction: injectivity of fromNat.
  have fwd: if fromNat(x) = fromNat(y) then x = y by fromNat_injective
  // Backward direction: rewrite with the assumed equation.
  have bkwd: if x = y then fromNat(x) = fromNat(y) by assume eq replace eq.
  // The equivalence is proved by giving both implications.
  fwd , bkwd
end

// Register uint_equal with `auto` so that equalities between fromNat-literals
// are rewritten to equalities on Nat without an explicit `replace`.
auto uint_equal