module UInt

import Nat
public import UIntDefs
import UIntAdd
import UIntMult

// Even(n) holds when n equals twice some UInt witness k.
fun Even(n : UInt) {
  some k:UInt. n = 2 * k
}

// Odd(n) holds when n equals one more than twice some UInt witness k.
fun Odd(n : UInt) {
  some k:UInt. n = 1 + (2 * k)
}

// Every UInt is Even or Odd.
// Proved by induction over the UInt constructors (bzero, dub_inc, inc_dub).
// Each case supplies its witness directly; the induction hypotheses are
// not needed.
theorem uint_Even_or_Odd: all n:UInt. Even(n) or Odd(n)
proof
  induction UInt
  case bzero {
    // bzero is even with witness 0; the goal closes by evaluation.
    have e: Even(bzero) by {
      expand Even
      choose 0
      evaluate
    }
    e
  }
  case dub_inc(n') assume IH {
    // Witness 1 + n'; the remaining goal dub_inc(n') = 2 * (1 + n')
    // is discharged by the lemma dub_inc_mult2_add.
    have e: Even(dub_inc(n')) by {
      expand Even
      choose 1 + n'
      dub_inc_mult2_add[n']
    }
    e
  }
  case inc_dub(n') assume IH {
    // Witness n'; the remaining goal inc_dub(n') = 1 + (2 * n')
    // is discharged by the lemma inc_dub_add_mult2.
    have o: Odd(inc_dub(n')) by {
      expand Odd
      choose n'
      inc_dub_add_mult2[n']
    }
    o
  }
end

// If 1 + n is Odd then n is Even.
// Unpack the Odd witness m, cancel the leading 1 on both sides, and reuse m
// as the Even witness for n.
theorem uint_odd_one_even: all n:UInt. if Odd(1 + n) then Even(n)
proof
  arbitrary n:UInt
  assume prem: Odd(1 + n)
  obtain m where eq: 1 + n = 1 + 2 * m from expand Odd in prem
  // Cancel the common "1 +" prefix.
  have n_eq: n = 2 * m by apply uint_add_both_sides_of_equal to eq
  expand Even
  choose m
  n_eq
end

// If 1 + n is Even then n is Odd.
// From 1 + n = 2 * m, split on whether m is 0 or 1 + m'.
// The zero case is contradictory, since 1 + n cannot be 0.
// Otherwise 2 * (1 + m') = 1 + (1 + 2 * m'); cancelling the leading 1
// gives n = 1 + 2 * m', so m' is the Odd witness.
theorem uint_even_one_odd: all n:UInt. if Even(1 + n) then Odd(n)
proof
  arbitrary n:UInt
  assume prem: Even(1 + n)
  obtain m where eq: 1 + n = 2 * m from expand Even in prem
  cases uint_zero_or_add_one[m]
  case mz {
    // m = 0 would force 1 + n = 0, which uint_not_one_add_zero rules out.
    have h: 1 + n = 0 by replace mz in eq
    conclude false by apply uint_not_one_add_zero to h
  }
  case mp {
    obtain m' where m_eq: m = 1 + m' from mp
    have eq1: 1 + n = 2 * (1 + m') by replace m_eq in eq
    // Distribute the multiplication, then evaluate 2 * 1 into 1 + 1 form.
    have step: 2 * (1 + m') = 1 + (1 + 2 * m') by {
      replace uint_dist_mult_add
      evaluate
    }
    have eq2: 1 + n = 1 + (1 + 2 * m') by transitive eq1 step
    // Cancel the common "1 +" prefix.
    have n_eq: n = 1 + 2 * m' by apply uint_add_both_sides_of_equal to eq2
    expand Odd
    choose m'
    n_eq
  }
end

// A UInt is Even exactly when it is not Odd.
// The biconditional is proved as its two implications, `fwd` and `bkwd`,
// which are paired at the end of the proof.
theorem uint_Even_not_Odd: all n:UInt. Even(n) <=> not Odd(n)
proof
  arbitrary n:UInt
  // No UInt is both Even and Odd. Transport both witnesses through toNat
  // and reuse the Nat-level Even/Odd result.
  have not_both: not (Even(n) and Odd(n)) by {
    assume both: Even(n) and Odd(n)
    obtain a where ea: n = 2 * a from expand Even in (conjunct 0 of both)
    obtain b where ob: n = 1 + 2 * b from expand Odd in (conjunct 1 of both)
    have eq: 2 * a = 1 + 2 * b by {
      transitive (symmetric ea) ob
    }
    have nat_eq: toNat(2 * a) = toNat(1 + 2 * b) by replace eq.
    // toNat(2 * a) is EvenNat, with witness toNat(a).
    have even_nat: EvenNat(toNat(2 * a)) by {
      expand EvenNat
      choose toNat(a)
      replace toNat_mult
      evaluate
    }
    // toNat(1 + 2 * b) is OddNat, with witness toNat(b).
    have odd_nat: OddNat(toNat(1 + 2 * b)) by {
      expand OddNat
      choose toNat(b)
      replace toNat_add | toNat_mult
      evaluate
    }
    // Rewrite along nat_eq so that one Nat value is both EvenNat and OddNat.
    have ev: EvenNat(toNat(1 + 2 * b)) by replace nat_eq in even_nat
    // NOTE(review): relies on the Nat-level Even_not_Odd accepting `apply`
    // in the EvenNat -> not OddNat direction; confirm against its statement.
    conclude false by apply (apply Even_not_Odd to ev) to odd_nat
  }
  // Forward direction: an Even number cannot also be Odd.
  have fwd: if Even(n) then not Odd(n) by {
    assume e
    assume o
    have both: Even(n) and Odd(n) by e, o
    apply not_both to both
  }
  // Backward direction: by the Even-or-Odd split, "not Odd" leaves only Even.
  have bkwd: if not Odd(n) then Even(n) by {
    assume not_o
    cases uint_Even_or_Odd[n]
    case e { e }
    case o { conclude false by apply not_o to o }
  }
  fwd, bkwd
end