module UInt

import Nat
public import UIntDefs
import UIntAdd
import UIntMult

/*
  UInt parity predicates.

  `Even(n)` means `n` is twice another UInt, and `Odd(n)` means `n`
  is one plus twice another UInt.
*/

// Evenness predicate: `Even(n)` holds exactly when some UInt `k`
// satisfies `n = 2 * k`.
fun Even(n : UInt) { some k:UInt. n = 2 * k }

// Oddness predicate: `Odd(n)` holds exactly when some UInt `k`
// satisfies `n = 1 + (2 * k)`.
fun Odd(n : UInt) { some k:UInt. n = 1 + (2 * k) }

// Doubling any UInt gives an even number.
// The witness required by `Even` is the UInt being doubled.
theorem uint_two_even: all n:UInt. Even(2 * n)
proof
  arbitrary k:UInt
  // Unfold the predicate so the goal becomes an existential.
  expand Even
  // Witness: k itself.
  choose k.
end

// One more than a doubled UInt is odd.
// The witness required by `Odd` is the UInt being doubled.
theorem uint_one_two_odd: all n:UInt. Odd(1 + 2 * n)
proof
  arbitrary k:UInt
  // Unfold the predicate so the goal becomes an existential.
  expand Odd
  // Witness: k itself.
  choose k.
end

// The sum of two even UInts is even.
theorem uint_even_add_even: all x:UInt, y:UInt.
  if Even(x) and Even(y) then Even(x + y)
proof
  arbitrary x:UInt, y:UInt
  assume both_even: Even(x) and Even(y)
  // Name each half of the hypothesis, then unpack its witness.
  have even_x: Even(x) by conjunct 0 of both_even
  have even_y: Even(y) by conjunct 1 of both_even
  obtain i where x_def: x = 2 * i from expand Even in even_x
  obtain j where y_def: y = 2 * j from expand Even in even_y
  expand Even
  // x + y = 2 * i + 2 * j = 2 * (i + j), so the witness is i + j.
  choose i + j
  equations
      x + y
    = 2 * i + 2 * j by replace x_def | y_def.
  ... = #2 * (i + j)# by replace uint_dist_mult_add.
end

// An even UInt plus an odd UInt is odd.
theorem uint_even_add_odd: all x:UInt, y:UInt.
  if Even(x) and Odd(y) then Odd(x + y)
proof
  arbitrary x:UInt, y:UInt
  assume even_and_odd: Even(x) and Odd(y)
  // Name each half of the hypothesis, then unpack its witness.
  have even_x: Even(x) by conjunct 0 of even_and_odd
  have odd_y: Odd(y) by conjunct 1 of even_and_odd
  obtain i where x_def: x = 2 * i from expand Even in even_x
  obtain j where y_def: y = 1 + 2 * j from expand Odd in odd_y
  expand Odd
  // x + y = 2 * i + (1 + 2 * j) = 1 + 2 * (i + j), so the witness is i + j.
  choose i + j
  equations
      x + y
    = 2 * i + (1 + 2 * j) by replace x_def | y_def.
  ... = (2 * i + 1) + 2 * j by .
  // Move the 1 to the front so it can be peeled off the sum.
  ... = (1 + 2 * i) + 2 * j by replace uint_add_commute[2 * i, 1].
  ... = 1 + (2 * i + 2 * j) by .
  ... = 1 + #2 * (i + j)# by replace uint_dist_mult_add.
end

// An odd UInt plus an even UInt is odd.
// Obtained from uint_even_add_odd by swapping the summands.
theorem uint_odd_add_even: all x:UInt, y:UInt.
  if Odd(x) and Even(y) then Odd(x + y)
proof
  arbitrary x:UInt, y:UInt
  assume odd_and_even: Odd(x) and Even(y)
  have odd_x: Odd(x) by conjunct 0 of odd_and_even
  have even_y: Even(y) by conjunct 1 of odd_and_even
  // Apply the even-plus-odd theorem with the arguments in the other order.
  have swapped: Odd(y + x) by {
    apply uint_even_add_odd[y, x] to even_y, odd_x
  }
  // Rewrite the goal's x + y into y + x so it matches `swapped`.
  replace uint_add_commute[x, y]
  swapped
end

// odd + odd = even
//
// With x = 1 + 2*a and y = 1 + 2*b, the sum is rearranged into
// 2 * (1 + (a + b)), so the witness for `Even` is `1 + (a + b)`.
theorem uint_odd_add_odd: all x:UInt, y:UInt.
  if Odd(x) and Odd(y) then Even(x + y)
proof
  arbitrary x:UInt, y:UInt
  assume prem: Odd(x) and Odd(y)
  // Unpack the witnesses a and b from the two oddness hypotheses.
  obtain a where x_1_2a: x = 1 + 2 * a from expand Odd in (conjunct 0 of prem)
  obtain b where y_1_2b: y = 1 + 2 * b from expand Odd in (conjunct 1 of prem)
  expand Even
  choose 1 + (a + b)
  // Steps justified with `by .` cite no lemma.
  // NOTE(review): presumably they are closed by the checker's automatic
  // rewriting (re-association / evaluation) -- confirm.
  equations
      x + y
    = (1 + 2 * a) + (1 + 2 * b) by replace x_1_2a | y_1_2b.
  ... = 1 + (2 * a + (1 + 2 * b)) by .
  ... = 1 + ((2 * a + 1) + 2 * b) by .
  // Commute the inner 1 to the front so the two 1s become adjacent.
  ... = 1 + ((1 + 2 * a) + 2 * b) by replace uint_add_commute[2 * a, 1].
  ... = 1 + (1 + (2 * a + 2 * b)) by .
  ... = 2 + (2 * a + 2 * b) by .
  // Factor 2 out of the marked subterm in each of the last two steps.
  ... = 2 + #2 * (a + b)# by replace uint_dist_mult_add.
  ... = #2 * (1 + (a + b))# by replace uint_dist_mult_add.
end

// If the left factor is even, so is the product.
theorem uint_even_mult_left: all x:UInt, y:UInt.
  if Even(x) then Even(x * y)
proof
  arbitrary x:UInt, y:UInt
  assume even_x: Even(x)
  // Unpack the witness for x.
  obtain i where x_def: x = 2 * i from expand Even in even_x
  expand Even
  // x * y = (2 * i) * y = 2 * (i * y), so the witness is i * y.
  choose i * y
  equations
      x * y
    = (2 * i) * y by replace x_def.
  ... = 2 * (i * y) by .
end

// If the right factor is even, so is the product.
// Obtained from uint_even_mult_left by swapping the factors.
theorem uint_even_mult_right: all x:UInt, y:UInt.
  if Even(y) then Even(x * y)
proof
  arbitrary x:UInt, y:UInt
  assume even_y: Even(y)
  // Apply the left-factor theorem with the arguments in the other order.
  have swapped: Even(y * x) by apply uint_even_mult_left[y, x] to even_y
  // Rewrite the goal's x * y into y * x so it matches `swapped`.
  replace uint_mult_commute[x, y]
  swapped
end

// odd * odd = odd
//
// With x = 1 + 2*a and y = 1 + 2*b, the product is rearranged into
// 1 + 2 * (b + a * (1 + 2*b)), so the witness for `Odd` is
// `b + a * (1 + 2 * b)`.
theorem uint_odd_mult_odd: all x:UInt, y:UInt.
  if Odd(x) and Odd(y) then Odd(x * y)
proof
  arbitrary x:UInt, y:UInt
  assume prem: Odd(x) and Odd(y)
  // Unpack the witnesses a and b from the two oddness hypotheses.
  obtain a where x_1_2a: x = 1 + 2 * a from expand Odd in (conjunct 0 of prem)
  obtain b where y_1_2b: y = 1 + 2 * b from expand Odd in (conjunct 1 of prem)
  expand Odd
  choose b + a * (1 + 2 * b)
  // Steps justified with `by .` cite no lemma.
  // NOTE(review): presumably they are closed by the checker's automatic
  // rewriting (re-association / evaluation) -- confirm.
  equations
      x * y
    = (1 + 2 * a) * y by replace x_1_2a.
  // Distribute y over the sum on the left.
  ... = 1 * y + (2 * a) * y by replace uint_dist_mult_add_right.
  ... = y + (2 * a) * y by .
  // Only now substitute y, keeping the earlier steps short.
  ... = (1 + 2 * b) + (2 * a) * (1 + 2 * b) by replace y_1_2b.
  ... = 1 + (2 * b + (2 * a) * (1 + 2 * b)) by .
  ... = 1 + (2 * b + 2 * (a * (1 + 2 * b))) by .
  // Factor 2 out of the marked subterm.
  ... = 1 + #2 * (b + a * (1 + 2 * b))# by replace uint_dist_mult_add.
end

// Each UInt is even or odd.
// Proved by case analysis on the three UInt constructors; the
// induction hypothesis is not needed in any case.
theorem uint_Even_or_Odd: all n:UInt. Even(n) or Odd(n)
proof
  induction UInt
  case bzero {
    // bzero is even, with witness 0.
    have zero_even: Even(bzero) by {
      expand Even
      choose 0
      evaluate
    }
    zero_even
  }
  case dub_inc(k') assume ih {
    // dub_inc(k') is even, with witness 1 + k'.
    have di_even: Even(dub_inc(k')) by {
      expand Even
      choose 1 + k'
      dub_inc_mult2_add[k']
    }
    di_even
  }
  case inc_dub(k') assume ih {
    // inc_dub(k') is odd, with witness k'.
    have id_odd: Odd(inc_dub(k')) by {
      expand Odd
      choose k'
      inc_dub_add_mult2[k']
    }
    id_odd
  }
end

// If the successor of `n` is odd, then `n` itself is even.
theorem uint_odd_one_even: all n:UInt. if Odd(1 + n) then Even(n)
proof
  arbitrary n:UInt
  assume succ_odd: Odd(1 + n)
  // Unpack the witness k with 1 + n = 1 + 2 * k.
  obtain k where succ_eq: 1 + n = 1 + 2 * k from expand Odd in succ_odd
  // Cancel the leading 1 on both sides.
  have n_def: n = 2 * k by apply uint_add_both_sides_of_equal to succ_eq
  expand Even
  choose k
  n_def
end

// If the successor of `n` is even, then `n` itself is odd.
theorem uint_even_one_odd: all n:UInt. if Even(1 + n) then Odd(n)
proof
  arbitrary n:UInt
  assume succ_even: Even(1 + n)
  // Unpack the witness k with 1 + n = 2 * k.
  obtain k where succ_eq: 1 + n = 2 * k from expand Even in succ_even
  // Split on whether k is zero or a successor.
  cases uint_zero_or_add_one[k]
  case k_zero {
    // k = 0 would force 1 + n = 0, which is contradictory.
    have succ_is_zero: 1 + n = 0 by replace k_zero in succ_eq
    conclude false by apply uint_not_one_add_zero to succ_is_zero
  }
  case k_succ {
    obtain p where k_def: k = 1 + p from k_succ
    have doubled: 1 + n = 2 * (1 + p) by replace k_def in succ_eq
    // 2 * (1 + p) = 1 + (1 + 2 * p)
    have dbl_step: 2 * (1 + p) = 1 + (1 + 2 * p) by {
      replace uint_dist_mult_add
      evaluate
    }
    have peeled: 1 + n = 1 + (1 + 2 * p) by transitive doubled dbl_step
    // Cancel the leading 1 on both sides.
    have n_def: n = 1 + 2 * p by apply uint_add_both_sides_of_equal to peeled
    expand Odd
    choose p
    n_def
  }
end

// Even and not-odd are equivalent: `Even(n)` holds if and only if
// `Odd(n)` does not.
//
// Proof outline:
//   1. `not_both`: no UInt is both even and odd. The two witnesses give
//      2 * a = 1 + 2 * b; both sides are mapped through `toNat` and the
//      contradiction is obtained from the Nat-level `Even_not_Odd`.
//      NOTE(review): `EvenNat`, `OddNat`, `Even_not_Odd`, `toNat_add` and
//      `toNat_mult` are defined outside this file; presumably they are the
//      Nat parity predicates/theorem and the toNat homomorphism lemmas --
//      confirm against the imported modules.
//   2. `fwd`:  Even(n) implies not Odd(n), directly from `not_both`.
//   3. `bkwd`: not Odd(n) implies Even(n), by cases on uint_Even_or_Odd.
// The pair `fwd, bkwd` proves the biconditional.
theorem uint_Even_not_Odd: all n:UInt. Even(n) <=> not Odd(n)
proof
  arbitrary n:UInt
  have not_both: not (Even(n) and Odd(n)) by {
    assume both: Even(n) and Odd(n)
    obtain a where ea: n = 2 * a from expand Even in (conjunct 0 of both)
    obtain b where ob: n = 1 + 2 * b from expand Odd in (conjunct 1 of both)
    // Eliminate n: 2 * a = n = 1 + 2 * b.
    have eq: 2 * a = 1 + 2 * b by {
      transitive (symmetric ea) ob
    }
    // Transport the equation to Nat.
    have nat_eq: toNat(2 * a) = toNat(1 + 2 * b) by replace eq.
    have even_nat: EvenNat(toNat(2 * a)) by {
      expand EvenNat
      choose toNat(a)
      replace toNat_mult
      evaluate
    }
    have odd_nat: OddNat(toNat(1 + 2 * b)) by {
      expand OddNat
      choose toNat(b)
      replace toNat_add | toNat_mult
      evaluate
    }
    // The same Nat is now both even and odd, which Nat rules out.
    have ev: EvenNat(toNat(1 + 2 * b)) by replace nat_eq in even_nat
    conclude false by apply (apply Even_not_Odd to ev) to odd_nat
  }
  have fwd: if Even(n) then not Odd(n) by {
    assume e
    assume o
    have both: Even(n) and Odd(n) by e, o
    apply not_both to both
  }
  have bkwd: if not Odd(n) then Even(n) by {
    assume not_o
    cases uint_Even_or_Odd[n]
    case e { e }
    case o { conclude false by apply not_o to o }
  }
  fwd, bkwd
end