import Base
import Nat
import UIntDefs
import UIntToFrom
import UIntLess
import UIntAdd
import UIntMult

// Arithmetic fact on Nat literals: ℕ1 + ℕ1 = ℕ2.
// Discharged purely by computation.
lemma two: ℕ1 + ℕ1 = ℕ2
proof
  evaluate
end

// When x < y, the monus of inc_dub(x) by dub_inc(y) collapses to bzero.
// Proof: unfold the definition of ∸ on UInt and rewrite the test `x < y`
// to `true` using the premise.
lemma inc_dub_monus_dub_inc_less: all x:UInt, y:UInt.
  if x < y
  then inc_dub(x) ∸ dub_inc(y) = bzero
proof
  arbitrary x:UInt, y:UInt
  assume prem
  expand operator∸
  // eq_true turns the premise `x < y` into the equation `(x < y) = true`.
  replace apply eq_true to prem.
end

// When x is not less than y, the monus of inc_dub(x) by dub_inc(y) is
// pred(dub(x ∸ y)).
// Proof: unfold the definition of ∸ on UInt and rewrite the test `x < y`
// to `false` using the premise.
lemma inc_dub_monus_dub_inc_greater: all x:UInt, y:UInt.
  if not (x < y)
  then inc_dub(x) ∸ dub_inc(y) = pred(dub(x ∸ y))
proof
  arbitrary x:UInt, y:UInt
  assume prem
  expand operator ∸
  // eq_false turns the premise `not (x < y)` into `(x < y) = false`.
  replace apply eq_false to prem.
end

// toNat commutes with monus: converting a UInt difference to Nat gives the
// Nat difference of the converted operands.
//
// Proof by induction on x (constructors bzero, dub_inc, inc_dub) with a
// case split on y in each inductive case. The mixed cases
// (dub_inc/inc_dub and inc_dub/dub_inc) additionally split on whether
// x' < y'.
theorem toNat_monus: all x:UInt, y:UInt.
  toNat(x ∸ y) = toNat(x) ∸ toNat(y)
proof
  induction UInt
  case bzero {
    arbitrary y:UInt
    evaluate
  }
  case dub_inc(x') assume IH {
    arbitrary y:UInt
    switch y {
      case bzero {
        expand operator∸ | toNat
        evaluate
      }
      case dub_inc(y') {
        expand operator∸ | toNat
        replace toNat_dub
        replace IH[y']
        replace nat_suc_one_add | dist_mult_monus
        // Remaining goal: add ℕ2 to both operands of the Nat monus.
        // Done by inserting `+ ℕ0`, rewriting ℕ0 as ℕ2 ∸ ℕ2, and
        // re-associating.
        equations
            ℕ2 * toNat(x') ∸ ℕ2 * toNat(y')
              = #ℕ2 * toNat(x') + ℕ0# ∸ ℕ2 * toNat(y')          by evaluate
          ... = (ℕ2 * toNat(x') + #ℕ2 ∸ ℕ2#) ∸ ℕ2 * toNat(y')    by .
          ... = ((ℕ2 * toNat(x') + ℕ2) ∸ ℕ2) ∸ ℕ2 * toNat(y')    by replace (apply monus_add_assoc[ℕ2, ℕ2 * toNat(x'), ℕ2]
                                                                         to less_equal_refl).
          ... = ((ℕ2 + ℕ2 * toNat(x')) ∸ ℕ2) ∸ ℕ2 * toNat(y')    by replace add_commute[ℕ2 * toNat(x'), ℕ2].
          ... = (ℕ2 + ℕ2 * toNat(x')) ∸ (ℕ2 + ℕ2 * toNat(y'))    by monus_monus_eq_monus_add
      }
      case inc_dub(y') {
        expand operator∸ | toNat
        switch x' < y' {
          case true assume x_l_y_true {
            // x' < y': show the Nat-side difference is ℕ0 by deriving
            // the ordering fact E and applying nat_monus_zero_iff_less_eq.
            replace nat_suc_one_add
            have x_y: x' < y' by replace x_l_y_true.
            have nx_ny: toNat(x') < toNat(y') by apply toNat_less to x_y
            have C: ℕ2*toNat(x') < ℕ2*toNat(y')
              by apply pos_mult_both_sides_of_less[ℕ2,toNat(x'),toNat(y')] to (evaluate, nx_ny)
            have D: ℕ1 + ℕ2*toNat(x') < ℕ1 + ℕ2*toNat(y') by apply add_both_sides_of_less[ℕ1] to C
            have E: ℕ2 + ℕ2*toNat(x') ≤ ℕ1 + ℕ2*toNat(y') by replace nat_suc_one_add in expand operator< in D
            have F: (ℕ2 + ℕ2 * toNat(x')) ∸ (ℕ1 + ℕ2 * toNat(y')) = ℕ0
              by apply nat_monus_zero_iff_less_eq to E
            replace F.
          }
          case false assume x_l_y_false {
            // not (x' < y'): use the induction hypothesis, then move the
            // leading ℕ1 across the monus with monus_add_assoc, which
            // needs ℕ2 * toNat(y') ≤ ℕ2 * toNat(x').
            replace IH[y']
            replace dist_mult_monus | nat_suc_one_add
            replace symmetric monus_monus_eq_monus_add[ℕ2 + ℕ2*toNat(x'), ℕ1, ℕ2*toNat(y')]
            have: ℕ1 ≤ ℕ2 by evaluate
            replace symmetric (replace add_commute[ℕ2 * toNat(x'), ℕ2] in
                               apply monus_add_assoc[ℕ2, ℕ2*toNat(x'), ℕ1] to recall ℕ1 ≤ ℕ2)
            have: ℕ2 ∸ ℕ1 = ℕ1 by evaluate
            have xy1: not (x' < y') by replace x_l_y_false.
            have xy2: y' ≤ x' by apply uint_not_less_implies_less_equal to xy1
            have xy3: toNat(y') ≤ toNat(x') by apply toNat_less_equal to xy2
            have y2_x2: ℕ2 * toNat(y') ≤ ℕ2 * toNat(x')  by apply mult_mono_le[ℕ2] to xy3
            replace add_commute[ℕ2 * toNat(x'), ℕ1]
            conclude ℕ1 + (ℕ2 * toNat(x') ∸ ℕ2 * toNat(y')) = (ℕ1 + ℕ2 * toNat(x')) ∸ ℕ2 * toNat(y')
              by apply monus_add_assoc[ℕ2 * toNat(x'), ℕ1, ℕ2 * toNat(y')] to y2_x2
          }
        }
      }
    }
  }
  case inc_dub(x') assume IH {
    arbitrary y:UInt
    switch y {
      case bzero {
        evaluate
      }
      case dub_inc(y') {
        switch x' < y' {
          case true assume x_l_y_true {
            // x' < y': the Nat-side difference is ℕ0, which follows from
            // the ordering fact xx_l_syy established below.
            expand operator∸ | toNat
            replace nat_suc_one_add
            have xx_l_syy: toNat(x') + toNat(x') ≤ ℕ1 + toNat(y') + toNat(y') by {
              have x_y: x' < y' by replace x_l_y_true.
              have nx_l_ny: toNat(x') < toNat(y') by apply toNat_less to x_y
              have snx_ny: ℕ1 + toNat(x') ≤ toNat(y')
                by replace nat_suc_one_add in expand operator< in nx_l_ny
              have nx_snx: toNat(x') ≤ ℕ1 + toNat(x')
                by less_equal_add_left
              have nx_ny: toNat(x') ≤ toNat(y')
                by apply less_equal_trans to nx_snx, snx_ny
              have nx2_ny2: toNat(x') + toNat(x') ≤ toNat(y') + toNat(y')
                by apply add_mono to nx_ny, nx_ny
              have ny2_sny2: toNat(y') + toNat(y') ≤ ℕ1 + toNat(y') + toNat(y')
                by less_equal_add_left
              conclude toNat(x') + toNat(x') ≤ ℕ1 + toNat(y') + toNat(y')
                by apply less_equal_trans to nx2_ny2, ny2_sny2
              }
            replace x_l_y_true
            expand lit | operator+
            replace nat_suc_one_add
            replace (replace mult_two in
              apply nat_monus_zero_iff_less_eq[toNat(x') + toNat(x'), ℕ1 + toNat(y') + toNat(y')]
              to xx_l_syy)
            expand lit.
          }
          case false assume x_l_y_false {
            // not (x' < y'): rewrite the UInt side with
            // inc_dub_monus_dub_inc_greater, push toNat through pred and
            // dub, apply the induction hypothesis, then rearrange the
            // Nat-side monus.
            expand toNat
            have x_g_y: not (x' < y') by replace x_l_y_false.
            replace (apply inc_dub_monus_dub_inc_greater to x_g_y)
            replace toNat_pred | toNat_dub | nat_suc_one_add
            replace IH[y']
            replace dist_mult_monus
            replace symmetric monus_monus_eq_monus_add[ℕ1+ℕ2*toNat(x'), ℕ1, ℕ1 + (ℕ1 + ℕ1) * toNat(y')]
            replace add_monus_identity[ℕ1,ℕ2*toNat(x')]
            replace symmetric nat_monus_one_pred[(ℕ2 * toNat(x') ∸ ℕ2 * toNat(y'))]
            replace nat_suc_one_add
            show ℕ2 * toNat(x') ∸ ℕ2 * toNat(y') ∸ ℕ1 = ℕ2 * toNat(x') ∸ (ℕ1 + ℕ2 * toNat(y'))
            replace symmetric (replace add_commute[ℕ2 * toNat(y'), ℕ1] in
                               monus_monus_eq_monus_add[ℕ2*toNat(x'), ℕ2*toNat(y'),ℕ1]).
          }
        }
      }
      case inc_dub(y') {
        expand toNat | operator∸
        replace toNat_dub
        replace IH[y']
        // What remains is distributivity of multiplication over monus.
        show ℕ2 * (toNat(x') ∸ toNat(y')) = ℕ2 * toNat(x') ∸ ℕ2 * toNat(y')
        replace dist_mult_monus.
      }
    }
  }
end

// Subtracting anything from zero yields zero.
// Follows directly from unfolding the definition of ∸ on UInt.
theorem uint_zero_monus: all x:UInt. 0 ∸ x = 0
proof
  arbitrary x:UInt
  expand operator∸.
end

// Subtracting zero leaves a UInt unchanged.
// Proved on the Nat side via toNat_monus, then transferred back to UInt
// with toNat_injective.
theorem uint_monus_zero: all n:UInt. n ∸ 0 = n
proof
  arbitrary n:UInt
  have X: toNat(n ∸ bzero) = toNat(n) by {
    replace toNat_monus
    expand toNat
    .
  }
  conclude n ∸ bzero = n by apply toNat_injective to X
end

// A UInt minus itself is zero.
// Proved on the Nat side (toNat_monus + nat_monus_cancel), then transferred
// back to UInt with toNat_injective.
theorem uint_monus_cancel: all n:UInt. n ∸ n = 0
proof
  arbitrary n:UInt
  have X: toNat(n ∸ n) = toNat(bzero) by {
    replace toNat_monus
    expand toNat
    replace nat_monus_cancel.
  }
  conclude n ∸ n = bzero by apply toNat_injective to X
end

// Adding m and then subtracting m gives back n.
// Proved on the Nat side (toNat_monus, toNat_add, add_monus_identity), then
// transferred back to UInt with toNat_injective.
theorem uint_add_monus_identity: all m:UInt, n:UInt.
  (m + n) ∸ m = n
proof
  arbitrary m:UInt, n:UInt
  have X: toNat((m + n) ∸ m) = toNat(n) by {
    replace toNat_monus | toNat_add
    replace add_monus_identity.
  }
  conclude (m + n) ∸ m = n by apply toNat_injective to X
end

// Two successive subtractions equal one subtraction of the sum.
// Proved on the Nat side via monus_monus_eq_monus_add, then transferred
// back to UInt with toNat_injective.
theorem uint_monus_monus_eq_monus_add : all x:UInt, y:UInt, z:UInt.
  (x ∸ y) ∸ z = x ∸ (y + z)
proof
  arbitrary x:UInt, y:UInt, z:UInt
  have X: toNat((x ∸ y) ∸ z) = toNat(x ∸ (y + z)) by {
    replace toNat_monus | toNat_add | toNat_monus
    replace monus_monus_eq_monus_add.
  }
  conclude (x ∸ y) ∸ z = x ∸ (y + z) by apply toNat_injective to X
end

// The order of two successive subtractions does not matter.
// Proved on the Nat side via monus_order, then transferred back to UInt
// with toNat_injective.
theorem uint_monus_order : all x : UInt, y : UInt, z : UInt.
  (x ∸ y) ∸ z = (x ∸ z) ∸ y
proof
  arbitrary x:UInt, y:UInt, z:UInt
  have X: toNat((x ∸ y) ∸ z) = toNat((x ∸ z) ∸ y) by {
    replace toNat_monus | toNat_monus
    equations
      (toNat(x) ∸ toNat(y)) ∸ toNat(z) = (toNat(x) ∸ toNat(z)) ∸ toNat(y)
         by replace monus_order.
  }
  conclude (x ∸ y) ∸ z = (x ∸ z) ∸ y by apply toNat_injective to X
end

// A common left summand z cancels under monus.
// Reduces the goal to the Nat statement through toNat_injective, rewrites
// with toNat_monus and toNat_add, and closes with add_both_monus.
theorem uint_add_both_monus: all z:UInt, y:UInt, x:UInt.
  (z + y) ∸ (z + x) = y ∸ x
proof
  arbitrary z:UInt, y:UInt, x:UInt
  suffices toNat((z + y) ∸ (z + x)) = toNat(y ∸ x) by toNat_injective
  replace toNat_monus | toNat_add
  add_both_monus
end