module UInt

import Base
import Nat
import UIntDefs
import UIntLess
import UIntAdd
import UIntMonus
import UIntMult

// Truncating division on UInt, defined by repeated subtraction:
//   n / m = 0                 if n < m
//         = 0                 if m = 0   (division by zero is defined as 0)
//         = 1 + (n ∸ m) / m   otherwise
// The measure is n; the `terminates` block shows that the recursive
// argument n ∸ m is strictly smaller than n whenever the third branch runs.
recfun operator /(n : UInt, m : UInt) -> UInt
  measure n of UInt
{
  if n < m then 0
  else if m = 0 then 0
  else 1 + ((n ∸ m) / m)
}
terminates {
  arbitrary n:UInt, m:UInt
  // The recursive branch is reached only when both earlier tests fail.
  assume cond: not (n < m) and not (m = 0)
  // Goal n ∸ m < n: adding m to both sides reduces it to m + (n ∸ m) < m + n.
  suffices m + (n ∸ m) < m + n by uint_add_both_sides_of_less[m,n∸m,n]
  // Since m ≤ n, m + (n ∸ m) = n, so it is enough to show n < m + n.
  suffices n < m + n by {
    have m_n: m ≤ n by apply uint_not_less_implies_less_equal to conjunct 0 of cond
    replace apply uint_monus_add_identity[n,m] to m_n.
  }
  // m ≠ 0 makes m positive, and adding a positive m to n strictly increases it.
  have m_pos: 0 < m by apply uint_not_zero_pos to conjunct 1 of cond
  conclude n < m + n by {
    replace uint_add_commute in
    apply uint_less_add_pos[n, m]
    to expand lit | fromNat in m_pos
  }
}

// Remainder of truncating division: n % m = n ∸ (n / m) * m.
// Since n / 0 = 0, this definition gives n % 0 = n.
fun operator % (n:UInt, m:UInt) {
  n ∸ (n / m) * m
}

// Zero divided by a positive number is zero.
theorem uint_zero_div: all x:UInt. if 0 < x then 0 / x = 0
proof
  arbitrary x:UInt
  assume zx: 0 < x
  // Unfolding operator/ exposes `if 0 < x then 0 else ...`;
  // rewriting the condition with zx selects the first branch.
  expand operator/
  replace (apply eq_true to zx).
end

// Induction predicate for the division algorithm: n can be written as
// (n / m) * m + r with r < m for every positive divisor m.
private fun UIntDivPred(n:UInt) {
  all m:UInt. if 0 < m then some r:UInt. (n / m) * m + r = n and r < m
}

// Restates UIntDivPred as a predicate on Nat (indexed by toNat(n)), so that
// Nat strong induction can be applied in uint_division.
private fun NatHasUIntDivPred(k:Nat) {
  all n:UInt. if toNat(n) = k then UIntDivPred(n)
}

// Division algorithm for UInt: with a positive divisor m, n equals
// (n / m) * m plus some remainder r < m.
// Proof by strong induction on toNat(n). When n < m the quotient is 0 and
// r = n. Otherwise n / m unfolds to 1 + (n ∸ m) / m, and the induction
// hypothesis for the smaller number n ∸ m supplies the remainder.
lemma uint_division: all n:UInt, m:UInt.
  if 0 < m
  then some r:UInt. (n/m)*m + r = n and r < m
proof
  // Strong-induction step, stated over the Nat measure toNat(n).
  have SI: all j:Nat. if (all i:Nat. if i < j then NatHasUIntDivPred(i))
                      then NatHasUIntDivPred(j) by {
    arbitrary j:Nat
    assume prem: all i:Nat. if i < j then NatHasUIntDivPred(i)
    expand NatHasUIntDivPred
    arbitrary nn:UInt
    assume tn_j: toNat(nn) = j
    expand UIntDivPred
    arbitrary m:UInt
    assume m_pos: 0 < m
    switch nn < m {
      case true assume n_m {
        // nn < m: the quotient is 0, so the remainder is nn itself.
        suffices some r:UInt. r = nn and r < m by {
          expand operator/
          simplify with n_m.
        }
        choose nn
        conclude nn = nn and nn < m by simplify with n_m.
      }
      case false assume not_n_m {
        // m ≤ nn and m ≠ 0: operator/ takes its recursive branch.
        have m_ne_z: not (m = 0) by apply uint_pos_not_zero to m_pos
        have m_le_n: m ≤ nn by apply uint_not_less_implies_less_equal to not_n_m
        // nn ∸ m is strictly smaller than nn, so the induction hypothesis applies.
        have nm_n: nn ∸ m < nn by {
          have step1: (nn ∸ m) + m = nn by {
            replace uint_add_commute[nn ∸ m, m]
            apply uint_monus_add_identity[nn, m] to m_le_n
          }
          have step2: nn ∸ m < (nn ∸ m) + m
            by apply uint_less_add_pos[nn ∸ m, m] to expand lit | fromNat in m_pos
          replace step1 in step2
        }
        have tnm_lt_tn: toNat(nn ∸ m) < toNat(nn) by apply toNat_less to nm_n
        have tnm_lt_j: toNat(nn ∸ m) < j by replace tn_j in tnm_lt_tn
        have IH_a: NatHasUIntDivPred(toNat(nn ∸ m)) by apply prem to tnm_lt_j
        have IH_b: UIntDivPred(nn ∸ m)
          by apply (expand NatHasUIntDivPred in IH_a)[nn ∸ m] to .
        have IH_c: some r:UInt. ((nn ∸ m) / m) * m + r = (nn ∸ m) and r < m
          by apply (expand UIntDivPred in IH_b)[m] to m_pos
        obtain r0 where R: ((nn ∸ m) / m) * m + r0 = (nn ∸ m) and r0 < m from IH_c
        // The remainder for nn ∸ m also serves as the remainder for nn.
        choose r0
        have conj0: (nn / m) * m + r0 = nn by {
          expand operator/
          replace (apply eq_false to not_n_m) | (apply eq_false to m_ne_z)
          show (1 + (nn ∸ m) / m) * m + r0 = nn
          replace uint_dist_mult_add_right[1, (nn ∸ m) / m, m]
          show m + ((nn ∸ m) / m) * m + r0 = nn
          replace conjunct 0 of R
          apply uint_monus_add_identity[nn, m] to m_le_n
        }
        conj0, conjunct 1 of R
      }
    }
  }
  arbitrary n:UInt
  have Q: NatHasUIntDivPred(toNat(n)) by apply strong_induction[NatHasUIntDivPred, toNat(n)] to SI
  have R: UIntDivPred(n) by apply (expand NatHasUIntDivPred in Q)[n] to .
  expand UIntDivPred in R
end

// Division identity: with a positive divisor, quotient times divisor plus
// remainder gives back the dividend.
theorem uint_div_mod: all n:UInt, m:UInt.
  if 0 < m
  then (n / m) * m + (n % m) = n
proof
  arbitrary n:UInt, m:UInt
  assume m_pos: 0 < m
  have ex: some r:UInt. (n/m)*m + r = n and r < m
    by apply uint_division[n, m] to m_pos
  obtain r where R: (n/m)*m + r = n and r < m from ex
  // Unfold % so the goal becomes (n/m)*m + (n ∸ (n/m)*m) = n.
  expand operator%
  define a = (n/m) * m
  // a ≤ n because n = a + r.
  have a_le_n: a ≤ n by {
    have a_le_a_r: a ≤ a + r by uint_less_equal_add
    have eq_n_ar: n = a + r by symmetric (conjunct 0 of R)
    replace eq_n_ar
    a_le_a_r
  }
  // With a ≤ n, monus is the exact inverse of addition.
  have id: a + (n ∸ a) = n by apply uint_monus_add_identity[n, a] to a_le_n
  id
end

// The remainder is strictly smaller than a positive divisor.
theorem uint_mod_less_divisor: all n:UInt, m:UInt. if 0 < m then n % m < m
proof
  arbitrary n:UInt, m:UInt
  assume m_pos: 0 < m
  expand operator%
  obtain r where R: (n/m)*m + r = n and r < m
    from apply uint_division[n, m] to m_pos
  define a = (n/m)*m
  have ar_n: a + r = n by R
  // The witness r from uint_division is exactly n ∸ (n/m)*m.
  have r_na: r = n ∸ a by {
    replace symmetric ar_n.
  }
  replace symmetric r_na
  conjunct 1 of R
end

// divides(a, b) holds when b is a multiple of a, i.e. a * k = b for some k.
fun divides(a : UInt, b : UInt) {
  some k:UInt. a * k = b
}

// If d divides a positive n and also divides m % n, then d divides m.
// Proof: write m = (m / n) * n + m % n, substitute n = d*k1 and
// m % n = d*k2, and factor d out of the sum.
theorem uint_divides_mod: all d:UInt, m:UInt, n:UInt.
  if divides(d, n) and divides(d, m % n) and 0 < n then divides(d, m)
proof
  arbitrary d:UInt, m:UInt, n:UInt
  assume prem
  obtain k1 where dk1_n: d*k1 = n from expand divides in conjunct 0 of prem
  obtain k2 where dk2_mn: d*k2 = m % n from expand divides in conjunct 1 of prem
  have n_pos: 0 < n by prem
  // Division identity for m and the positive divisor n.
  have eq1: (m / n) * n + m % n = m by apply uint_div_mod[m, n] to n_pos
  expand divides
  // Substitute d*k2 for m % n, then d*k1 for n.
  have eq2: (m / n) * n + d*k2 = m by replace symmetric dk2_mn in eq1
  // NOTE(review): X is used only by `expand X in` below and looks redundant — confirm.
  define X = m / n
  have eq3: (m / n)*d*k1 + d*k2 = m by expand X in replace symmetric dk1_n in eq2
  // Move d to the front and factor it out with distributivity.
  have eq4: d*(m/n)*k1 + d*k2 = m by replace uint_mult_commute[m/n, d] in eq3
  have eq5: d*((m/n)*k1 + k2) = m by replace symmetric uint_dist_mult_add[d, (m/n)*k1, k2] in eq4
  // Witness for divides(d, m).
  choose (m/n)*k1 + k2
  eq5
end

// A positive number divided by itself is 1.
theorem uint_div_cancel: all y:UInt. if 0 < y then y / y = 1
proof
  arbitrary y:UInt
  assume y_pos
  have y_ne_zero: not (y = 0) by apply uint_pos_not_zero to y_pos
  // Unfold one step of operator/. Ruling out the zero-divisor branch
  // leaves 1 + 0 / y, as stated by the `show` below.
  expand operator/
  replace (apply eq_false to y_ne_zero)
  show 1 + 0 / y = 1
  replace (apply uint_zero_div to y_pos).
end

// Zero modulo any x (including x = 0) is zero.
theorem uint_zero_mod: all x:UInt. 0 % x = 0
proof
  arbitrary x:UInt
  // Unfold % and the literal encoding so the goal is stated on bzero,
  // where uint_bzero_monus applies directly.
  expand operator%
  expand lit | fromNat
  uint_bzero_monus
end

// Any number modulo itself is zero, including 0 % 0.
theorem uint_mod_self_zero: all y:UInt. y % y = 0
proof
  arbitrary y:UInt
  have y_z_p: y = 0 or 0 < y by uint_zero_or_positive[y]
  cases y_z_p
  case y_z {
    // y = 0: this is an instance of uint_zero_mod.
    replace y_z
    uint_zero_mod
  }
  case y_pos {
    // y > 0: y / y = 1, so y % y unfolds to y minus (monus) 1 * y.
    expand operator%
    have yyc: y/y = 1 by apply uint_div_cancel to y_pos
    replace yyc.
  }
end

// Every number leaves remainder 0 when divided by 1.
theorem uint_mod_one: all n:UInt. n % 1 = 0
proof
  arbitrary n:UInt
  have one_pos: 0 < 1 by .
  // n % 1 < 1 gives n % 1 ≤ 0, which forces n % 1 = 0.
  have nm_lt_1: n % 1 < 1 by apply uint_mod_less_divisor[n, 1] to one_pos
  have nm_le_0: n % 1 ≤ 0
    by apply uint_less_add_one_implies_less_equal[n % 1, 0] to nm_lt_1
  apply uint_less_equal_zero to nm_le_0
end

// Dividing by 1 leaves a number unchanged.
theorem uint_div_one: all n:UInt. n / 1 = n
proof
  arbitrary n:UInt
  have one_pos: 0 < 1 by .
  // Division identity with m = 1, then drop the remainder using uint_mod_one.
  have eq1: (n / 1) * 1 + (n % 1) = n by apply uint_div_mod[n, 1] to one_pos
  have eq2: (n / 1) + (n % 1) = n by eq1
  replace uint_mod_one in eq2
end

// Adding the (positive) divisor to the dividend adds one to the quotient.
theorem uint_add_div_one: all n:UInt, m:UInt.
  if 0 < m
  then (n + m) / m = 1 + n / m
proof
  arbitrary n:UInt, m:UInt
  assume m_pos
  have m_nz: not (m = 0) by apply uint_pos_not_zero to m_pos
  have m_le_nm: m ≤ n + m by {
    have h: m ≤ m + n by uint_less_equal_add[m, n]
    replace uint_add_commute[m, n] in h
  }
  // n + m < m would contradict m ≤ n + m (it would force m < m).
  have not_nm_m: not (n + m < m) by {
    assume nm_m: n + m < m
    have nm_le_m: n + m ≤ m by apply uint_less_implies_less_equal to nm_m
    have eq_nm: n + m = m by apply uint_less_equal_antisymmetric to nm_le_m, m_le_nm
    have m_lt_m: m < m by replace eq_nm in nm_m
    apply uint_less_irreflexive to m_lt_m
  }
  // Both non-recursive branches of operator/ are ruled out, so (n + m) / m
  // takes the recursive branch, and (n + m) ∸ m simplifies to n.
  equations
          (n + m) / m
        = 1 + ((n + m) ∸ m) / m  by {
            expand operator/
            replace (apply eq_false to not_nm_m)
                  | (apply eq_false to m_nz).
          }
    ... = 1 + n / m              by {
            replace uint_add_commute[n, m].
          }
end

// Multiplying by a positive m and then dividing by m returns the original
// number. Proved by induction on n, using uint_add_div_one in the step.
theorem uint_mult_div_inverse: all n:UInt, m:UInt.
  (if 0 < m then (n * m) / m = n)
proof
  // Induction predicate: the statement for a fixed n, over all m.
  define P = fun n:UInt { all m:UInt. if 0 < m then (n * m) / m = n }
  // Base case: (0 * m) / m = 0 via uint_zero_div.
  have base: P(0) by {
    expand P
    arbitrary m:UInt
    assume m_pos
    show (0 * m) / m = 0
    apply uint_zero_div to m_pos
  }
  // Inductive step: (1 + n) * m = n * m + m, and adding m to the dividend
  // adds 1 to the quotient.
  have ind: all n:UInt. if P(n) then P(1 + n) by {
    arbitrary n:UInt
    expand P
    assume IH: all m:UInt. if 0 < m then (n * m) / m = n
    arbitrary m:UInt
    assume m_pos
    show ((1 + n) * m) / m = 1 + n
    have eq1: (1 + n) * m = n * m + m by {
      replace uint_dist_mult_add_right[1, n, m]
      replace uint_add_commute[m, n * m].
    }
    replace eq1
    show (n * m + m) / m = 1 + n
    equations
      (n * m + m) / m = 1 + (n * m) / m  by apply uint_add_div_one[n * m, m] to m_pos
                  ... = 1 + n             by replace apply IH[m] to m_pos.
  }
  expand P in apply uint_induction[P] to base, ind
end

// Links UInt division to Nat division on numerals: dividing two literals
// as UInts gives the same result as dividing them as Nats. It is
// declared `auto` below so it can be applied automatically.
//
// DIFFICULTY: A direct proof would first prove the helper lemma
//   fromNat_div: fromNat(x) / fromNat(y) = fromNat(x / y)
// by strong induction on x. The x < y case and the recursive case
// (x >= y > 0) both go through. The y = 0 case is awkward: expanding
// `operator/` mixes the UInt literal `0` from the recfun body with the
// substituted argument `bzero`, and the `replace` tactic fails to unify
// `0` with `bzero` even though they are equal. The statement is
// therefore left as a postulate.
// NOTE(review): soundness at y = 0 depends on Nat `/` also returning 0
// for a zero divisor, matching UInt `/` in this file. Confirm against Nat.
postulate uint_lit_div: all x:Nat, y:Nat. (fromNat(lit(x)) / fromNat(lit(y))) = fromNat(lit(x) / lit(y))

auto uint_lit_div