module Nat

import Base
import Pair
import NatDefs
import NatAdd
import NatMonus
import NatMult
import NatLess
import NatEvenOdd

/*
 Properties of div2
 */

// div2 maps zero to zero; the goal is closed by evaluating the definitions.
lemma div2_zero: div2(zero) = zero
proof
  evaluate
end
 
// Halving a number that is two successors larger gives one more than
// the half of the original number.
lemma div2_add_2: all n:Nat.  div2(suc(suc(n))) = suc(div2(n))
proof
  arbitrary n:Nat
  // Unfold div2 once and div2_helper twice; the goal closes after that.
  expand div2 | div2_helper | div2_helper.
end

// Halving n + n gives back n. Proved by induction on n.
lemma div2_double: all n:Nat.
  div2(n + n) = n
proof
  induction Nat
  case zero {
    // Base case: unfolding +, div2 and div2_helper settles the goal.
    expand operator+ | div2 | div2_helper.
  }
  case suc(n') suppose IH {
    // Step case: reshape the goal in stages until IH applies.
    // 1. Unfold + on the outer suc.
    suffices div2(suc(n' + suc(n'))) = suc(n')  by expand operator+.
    // 2. Move the inner suc out of the right summand with add_suc.
    suffices div2(suc(suc(n' + n'))) = suc(n')  by replace add_suc.
    // 3. Use div2_add_2 to push div2 beneath the two successors.
    suffices suc(div2(n' + n')) = suc(n')       by replace div2_add_2.
    // 4. Rewrite div2(n' + n') with the induction hypothesis.
    replace IH.
  }
end

// Halving 2 * n gives n. Reduces to div2_double after unfolding the
// multiplication.
lemma div2_times2: all n:Nat.
  div2(suc(suc(zero)) * n) = n
proof
  arbitrary n:Nat
  // Unfolding * turns 2 * n into n + (n + zero).
  suffices div2(n + (n + zero)) = n  by expand operator*.
  // The trailing "+ zero" is simplified away by this step.
  suffices div2(n + n) = n        by .
  div2_double[n]
end

// Halving suc(n + n) also gives n. Proved by induction on n.
lemma div2_suc_double: all n:Nat.
  div2(suc(n + n)) = n
proof
  induction Nat
  case zero {
    // Base case closes after unfolding +, div2 and div2_helper (twice).
    expand operator+ | div2 | div2_helper | div2_helper.
  }
  case suc(n') suppose IH {
    // Unfold the definitions and normalise the sum with add_suc; the
    // resulting goal is restated explicitly by the `show` below.
    expand div2 | div2_helper | operator+ | div2_helper
    replace add_suc
    show suc(div2_helper(suc(n' + n'), true)) = suc(n')
    // One more unfolding of div2_helper yields the form stated here.
    suffices suc(div2_helper(n' + n', false)) = suc(n')
      by expand div2_helper.
    // Finish by rewriting with IH, itself unfolded (div2, div2_helper)
    // so that it is phrased in terms of div2_helper like the goal.
    replace (expand div2 | div2_helper in IH).
  }
end

// Halving suc(2 * n) gives n. Reduces to div2_suc_double after unfolding
// the multiplication.
lemma div2_suc_times2: all n:Nat.
  div2(suc(suc(suc(zero)) * n)) = n
proof
  arbitrary n:Nat
  // Unfolding * turns 2 * n into n + (n + zero).
  suffices div2(suc(n + (n + zero))) = n  by expand operator*.
  // The trailing "+ zero" is simplified away by this step.
  suffices div2(suc(n + n)) = n        by .
  div2_suc_double[n]
end

// Removing half of n from n leaves at most one more than that half:
//   n ∸ div2(n) ≤ suc(div2(n)).
// Proved by splitting on whether n is even (n = 2 * k) or odd
// (n = suc(2 * k)), using Even_or_Odd.
lemma monus_div2: all n:Nat. n ∸ div2(n) ≤ suc(div2(n))
proof
  arbitrary n:Nat
  cases Even_or_Odd[n]
  case even {
    obtain k where n_2k: n = suc(suc(zero)) * k from expand EvenNat in even
    // With n = 2 * k the goal rewrites to k ≤ suc(k).
    suffices k ≤ suc(k) by replace n_2k | div2_times2[k] | two_mult[k] | add_monus_identity[k,k].
    less_equal_suc
  }
  case odd {
    obtain k where n_s2k: n = suc(suc(suc(zero)) * k) from expand OddNat in odd
    // With n = suc(2 * k) the half is k, so the goal becomes the line below.
    suffices suc(k + k) ∸ k ≤ suc(k)
      by replace n_s2k | div2_suc_times2[k] | two_mult[k].
    // Write suc(k + k) as 1 + (k + k) so the subtraction can be moved inward.
    have eq_1: suc(k + k) = suc(zero) + (k + k) by expand 2* operator+.
    suffices (suc(zero) + (k + k)) ∸ k ≤ suc(k) by replace eq_1.
    // k ≤ k + k is the side condition needed by monus_add_assoc.
    have k_kk: k ≤ k + k by less_equal_add
    have eq_2: (suc(zero) + (k + k)) ∸ k = suc(zero) + ((k + k) ∸ k) by symmetric apply monus_add_assoc[k + k, suc(zero), k] to k_kk
    suffices suc(zero) + ((k + k) ∸ k) ≤ suc(k) by replace eq_2.
    // Strip the leading successor from both sides of ≤.
    suffices ((k + k) ∸ k) ≤ k by expand 1* operator+ | operator≤.
    suffices k ≤ k by replace add_monus_identity[k,k].
    less_equal_refl
  }
end

// For positive n, what remains after removing half of n is still positive:
//   if zero < n then zero < n ∸ div2(n).
// Proved by splitting on whether n is even or odd, using Even_or_Odd.
lemma monus_div2_pos: all n:Nat. if zero < n then zero < n ∸ div2(n)
proof
  arbitrary n:Nat
  cases Even_or_Odd[n]
  case even {
    obtain k where n_2k: n = suc(suc(zero)) * k from expand EvenNat in even
    // With n = 2 * k the goal rewrites to a statement about k alone.
    suffices if zero < k + k then zero < k
      by replace n_2k | div2_times2[k] | two_mult[k] | add_monus_identity[k,k].
    switch k {
      case zero {
        // For k = zero the goal is settled by simplification alone.
        conclude if zero < zero + zero then zero < zero
           by .
      }
      case suc(k') {
        // zero < suc(k') follows by unfolding < and ≤.
        assume _
        expand operator< | 2*operator≤.
      }
    }
  }
  case odd {
    obtain k where n_s2k: n = suc(suc(suc(zero)) * k) from expand OddNat in odd
    suffices if zero < suc(k + k) then zero < suc(k + k) ∸ k
      by replace n_s2k | div2_suc_times2[k] | two_mult[k].
    assume _
    // Write suc(k + k) as 1 + (k + k) so the subtraction can be moved inward.
    have eq_1: suc(k + k) = suc(zero) + (k + k) by expand 2* operator+.
    suffices zero < (suc(zero) + (k + k)) ∸ k by replace eq_1.
    // k ≤ k + k is the side condition needed by monus_add_assoc.
    have k_kk: k ≤ k + k by less_equal_add
    have eq_2: (suc(zero) + (k + k)) ∸ k = suc(zero) + ((k + k) ∸ k) by symmetric apply monus_add_assoc[k + k, suc(zero), k] to k_kk
    suffices zero < suc(zero) + ((k + k) ∸ k) by replace eq_2.
    // The right-hand side is a successor, so unfolding +, < and ≤ finishes.
    expand operator+ | operator< | 2* operator≤.
  }
end

// For positive n: div2(n) is strictly below n, and div2_aux(n) is at most n.
// The two facts are proved together by induction on n because each step
// of one relies on the previous step of the other.
lemma div2_aux_pos_less: all n:Nat. if zero < n then div2(n) < n and div2_aux(n) ≤ n
proof
  induction Nat
  case zero {
    // The premise zero < zero is impossible.
    assume: zero < zero
    conclude false by apply less_irreflexive to recall zero < zero
  }
  case suc(n') assume IH {
    assume: zero < suc(n')
    switch n' {
      case zero {
        // n = suc(zero): both conjuncts hold by evaluation.
        evaluate
      }
      case suc(n'') assume np_eq {
        // n' is a successor, hence positive, so IH is applicable to it.
        have: zero < n' by {
            suffices zero < suc(n'') by replace np_eq.
            expand operator< | 2* operator≤.
        }
        have IH': div2(n') < n' and div2_aux(n') ≤ n'
          by apply IH to (recall zero < n')
        // Evaluating the goal reduces it to two facts about div2_helper.
        suffices (div2_helper(n'', true) ≤ n''
                 and div2_helper(n'', false) ≤ suc(n''))
          by evaluate
        // First fact: obtained from the div2_aux half of IH'.
        have _1: div2_helper(n'', true) ≤ n'' by {
          expand div2_aux | div2_helper | operator≤ in
          replace np_eq in
          conjunct 1 of IH'
        }
        // Second fact: the div2 half of IH' gives the bound n'', which is
        // then weakened to suc(n'') by transitivity.
        have _2: div2_helper(n'', false) ≤ suc(n'') by {
          have lt1: div2_helper(n'', false) ≤ n'' by {
            expand div2 | div2_helper | operator< | operator≤ in
            replace np_eq in
            conjunct 0 of IH'
          }
          have lt2: n'' ≤ suc(n'') by less_equal_suc
          apply less_equal_trans to lt1, lt2
        }
        _1, _2
      }
    }
  }
end

// For positive n, div2(n) is strictly less than n.
// This is the first conjunct of div2_aux_pos_less, which is given directly
// as the proof.
lemma div2_pos_less: all n:Nat. if zero < n then div2(n) < n
proof
  div2_aux_pos_less
end

// Monotonicity of div2 and div2_aux with respect to ≤, proved together by
// induction on x (with y universally quantified in each case).
lemma div2_aux_mono: all x:Nat, y:Nat.
  if x ≤ y then div2(x) ≤ div2(y) and div2_aux(x) ≤ div2_aux(y)
proof
  induction Nat
  case zero {
    arbitrary y:Nat
    assume _
    // For x = zero both conjuncts hold by evaluation.
    conclude div2(zero) ≤ div2(y) and div2_aux(zero) ≤ div2_aux(y)
        by evaluate
  }
  case suc(x') assume IH {
    arbitrary y:Nat
    assume sx_y
    switch y {
      case zero assume y_eq {
        // suc(x') ≤ zero is contradictory once ≤ is unfolded.
        conclude false by expand operator≤ in replace y_eq in sx_y
      }
      case suc(y') assume y_eq {
        // Evaluating the goal reduces it to two div2_helper comparisons.
        suffices div2_helper(x', false) ≤ div2_helper(y', false)
             and div2_helper(x', true) ≤ div2_helper(y', true)
            by evaluate
        // Peel one successor off both sides of the premise.
        have xy: x' ≤ y' by {
          apply suc_less_equal_iff_less_equal_suc to
          replace y_eq in sx_y
        }
        // The induction hypothesis at y', with div2 and div2_aux unfolded,
        // is used as the proof of the reduced goal.
        expand div2 | div2_aux in
        apply IH[y'] to xy
      }
    }
  }
end

// div2 is monotone with respect to ≤.
// This is the first conjunct of div2_aux_mono, which is given directly as
// the proof.
lemma div2_mono: all x:Nat, y:Nat.
  if x ≤ y then div2(x) ≤ div2(y)
proof
  div2_aux_mono
end

/*
 Properties of Division and Modulo
 */

// Division by iterated subtraction.
//   n / m = zero                    when n < m
//   n / m = zero                    when m = zero
//   n / m = 1 + (n ∸ m) / m         otherwise
// The recursion is justified by the measure n: in the recursive branch
// m is nonzero and m ≤ n, so n ∸ m < n.
recfun operator /(n : Nat, m : Nat) -> Nat
  measure n of Nat
{
  if n < m then zero
  else if m = zero then zero
  else suc(zero) + (n ∸ m) / m
}
terminates {
  arbitrary n:Nat, m:Nat
  // Conditions under which the recursive call is reached.
  assume cond: not (n < m) and not (m = zero)
  // Goal: n ∸ m < n. Add m to both sides first.
  suffices m + (n ∸ m) < m + n by add_both_sides_of_less[m,n∸m,n]
  // Since m ≤ n, m + (n ∸ m) collapses to n.
  suffices n < m + n by {
    have m_n: m ≤ n by apply not_less_implies_less_equal to conjunct 0 of cond
    replace apply monus_add_identity[n,m] to m_n.
  }
  // m is nonzero, so it is a successor; that makes m + n strictly above n.
  obtain m' where m_sm: m = suc(m') from apply not_zero_suc to conjunct 1 of cond
  suffices n < suc(m') + n by replace m_sm.
  suffices n ≤ m' + n by expand operator+ | operator< | operator≤.
  replace add_commute
  show n ≤ n + m'
  less_equal_add
}

// Remainder, defined from division: what is left of n after removing
// (n / m) copies of m.
fun operator % (n:Nat, m:Nat) {
  n ∸ (n / m) * m
}

// Auxiliary form of strong induction: if P(n) follows from P holding at
// every k < n, then P holds at every k below any bound j.
// Proved by ordinary induction on the bound j.
lemma strong_induction_aux: all P: fn Nat -> bool.
  if (all n:Nat. if (all k:Nat. if k < n then P(k)) then P(n))
  then (all j:Nat, k:Nat. if k < j then P(k))
proof
  arbitrary P: fn Nat -> bool
  assume H
  induction Nat
  case zero {
    // Nothing is below zero, so the premise k < zero is contradictory.
    arbitrary k:Nat
    assume neg_k
    conclude false by apply not_less_zero to neg_k
  }
  case suc(j') assume IH {
    arbitrary k:Nat
    assume: k < suc(j')
    // k < suc(j') means k ≤ j', i.e. either k < j' or k = j'.
    have: k ≤ j' by evaluate in recall k < suc(j')
    cases (apply less_equal_implies_less_or_equal to recall k ≤ j')
    case: k < j' {
      // Covered directly by the induction hypothesis.
      conclude P(k) by apply IH to recall k < j'
    }
    case: k = j' {
      // Establish P everywhere below j', then use H to obtain P(j').
      have A: all m:Nat. if m < j' then P(m) by {
        arbitrary m:Nat
        assume: m < j'
        // Everything below m is also below j', so IH covers it.
        have B: all k':Nat. (if k' < m then P(k')) by {
          arbitrary k':Nat
          assume: k' < m
          have: k' < j' by apply less_trans to (recall k' < m), (recall m < j')
          conclude P(k') by apply IH to recall k' < j'
        }
        conclude P(m) by apply H[m] to B
      }
      have: P(j') by apply H[j'] to A
      // Transport P(j') back to P(k) along k = j'.
      conclude P(k) by replace symmetric (recall k = j') in recall P(j')
    }
  }
end
  
// Strong (course-of-values) induction: to prove P(n) it suffices to show
// that P(j) follows from P holding at every i < j.
theorem strong_induction: all P: fn Nat -> bool, n:Nat.
  if (all j:Nat. if (all i:Nat. if i < j then P(i)) then P(j))
  then P(n)
proof
  arbitrary P: fn Nat -> bool, n:Nat
  assume prem
  // The auxiliary lemma gives P at everything below any bound.
  have X: all j:Nat, k:Nat. if k < j then P(k)
    by apply strong_induction_aux[P] to prem
  // Instantiate the bound with suc(n), which n is strictly below.
  have: n < suc(n) by {
    suffices n ≤ n  by evaluate
    less_equal_refl
  }
  conclude P(n) by apply X[suc(n),n] to recall n < suc(n)
end

// Predicate on the dividend n used as the induction motive in `division`:
// for every positive divisor m there is a remainder r with
// (n / m) * m + r = n and r < m.
private fun DivPred(n:Nat) {
  all m:Nat. if zero < m then some r:Nat. (n / m) * m + r = n and r < m
}

// Existence of a remainder: for a positive divisor m there is some r < m
// with (n / m) * m + r = n.
// Proved by strong induction on the dividend, using DivPred as the motive
// and following the case structure of the definition of operator /.
lemma division: all n:Nat, m:Nat.
  if zero < m
  then some r:Nat. (n/m)*m + r = n and r < m
proof
  arbitrary n:Nat
  // SI is the step required by strong_induction for the predicate DivPred.
  have SI: all j:Nat. if (all i:Nat. if i < j then DivPred(i))
                      then DivPred(j) by {
    arbitrary j:Nat
    assume prem: all i:Nat. if i < j then DivPred(i)
    suffices all m:Nat. if zero < m then some r:Nat. ((j/m)*m + r = j and r < m)
      by expand DivPred.
    arbitrary m:Nat
    assume m_pos
    switch j < m {
      case true assume j_m_true {
        // j < m: after unfolding /, the goal reduces to finding r = j with r < m.
        suffices some r:Nat. r = j and r < m by {
          expand operator /
          replace j_m_true.
        }
        choose j
        conclude j = j and j < m   by replace j_m_true.
      }
      case false assume j_m_false {
        switch m = zero {
          case true assume m_z_true {
            // m = zero contradicts the assumption zero < m.
            have: m = zero by replace m_z_true.
            conclude false
              by apply less_irreflexive to replace (recall m = zero) in m_pos
          }
          case false assume m_z_false {
            // Recursive branch of /: m is nonzero and m ≤ j.
            have not_m_z: not (m = zero) by replace m_z_false.
            have not_j_m: not (j < m) by replace j_m_false.
            have m_j: m ≤ j by apply not_less_implies_less_equal to not_j_m
            obtain m' where m_sm: m = suc(m') from apply not_zero_suc to not_m_z
            expand operator / replace j_m_false | m_z_false
            show some r:Nat. ((suc(zero) + (j ∸ m) / m) * m + r = j and r < m)
            replace dist_mult_add_right
            show some r:Nat. (m + ((j ∸ m) / m) * m + r = j and r < m)
            // j ∸ m is strictly smaller than j, so the strong induction
            // hypothesis applies to it (same argument as in the termination
            // proof of operator /).
            have jm_j: j ∸ m < j by {
              suffices m + (j ∸ m) < m + j by add_both_sides_of_less[m,j∸m,j]
              suffices j < m + j by replace apply monus_add_identity[j,m] to m_j.
              replace m_sm
              suffices j ≤ m' + j by expand operator+ | operator< | operator≤.
              replace add_commute
              less_equal_add
            }
            have div_mp: DivPred(j ∸ m) by apply prem[j ∸ m] to jm_j
            obtain r where div_jm: ((j ∸ m) / m) * m + r = j ∸ m and r < m
              from apply (expand DivPred in div_mp)[m] to m_pos
            // The remainder for j ∸ m also serves as the remainder for j.
            choose r
            suffices (m + (j ∸ m) = j and r < m)
              by replace conjunct 0 of div_jm.
            have mjm_j: m + (j ∸ m) = j by apply monus_add_identity[j, m] to m_j
            mjm_j, conjunct 1 of div_jm
          }
        }
      }
    }
  }
  have: DivPred(n) by apply strong_induction[DivPred, n] to SI
  expand DivPred in recall DivPred(n)
end

// Division identity: for a positive divisor m,
//   (n / m) * m + (n % m) = n.
// Uses `division` to obtain a remainder r, which shows (n / m) * m ≤ n;
// the identity then follows from monus_add_identity after unfolding %.
lemma div_mod: all n:Nat, m:Nat.
  if zero < m
  then (n / m) * m + (n % m) = n
proof
  arbitrary n:Nat, m:Nat
  assume m_pos: zero < m
  // NOTE(review): p1 and p2 are not referenced by name anywhere later in
  // this proof; confirm whether they can be dropped.
  have p1: zero * m ≤ n by evaluate
  have p2: n < (suc(zero) + (zero + (suc(zero) + n) * m)) * m by {
    obtain m' where m_sm: m = suc(m') from apply positive_suc to m_pos
    replace m_sm | dist_mult_add_right[suc(zero),n] | mult_commute[n]
    | suc_one_add[m'] | dist_mult_add_right[suc(zero),m',n]
    | dist_mult_add | add_commute[m',n] | add_commute[suc(zero),n]
    expand operator<
    replace suc_one_add[n]
    less_equal_add[suc(zero)+n]
  }
  have ex: some r:Nat. (n/m)*m + r = n and r < m
    by apply division[n,m] to m_pos
  obtain r where R: (n/m)*m + r = n and r < m from ex
  expand operator%
  // Abbreviate the product so the remaining steps read a + r = n.
  define a = (n / m) * m
  have ar_n: a + r = n by R
  have a_le_a_r: a ≤ a + r by less_equal_add
  have n_eq_a_r: n = a + r by symmetric (conjunct 0 of R)
  // a ≤ n because n = a + r.
  have a_le_n: a ≤ n by {
    replace n_eq_a_r
    show a ≤ a + r
    a_le_a_r
  }
  conclude a + (n ∸ a) = n
    by apply monus_add_identity to a_le_n
end

// For a positive divisor m, the remainder n % m is strictly below m.
// The remainder r supplied by `division` is shown to equal n ∸ (n / m) * m,
// which is what % unfolds to.
lemma mod_less_divisor: all n:Nat, m:Nat. if zero < m then n % m < m
proof
  arbitrary n:Nat, m:Nat
  assume m_pos: zero < m
  expand operator%
  obtain r where R: (n / m) * m + r = n and r < m
    from apply division[n, m] to m_pos
  // Abbreviate the product so the remaining steps read a + r = n.
  define a = (n / m)*m
  have ar_n: a + r = n by R
  have ar_a_a: (a + r) ∸ a = r  by add_monus_identity[a][r]
  // Rewrite n as a + r, then cancel a.
  have r_na: r = n ∸ a by {
    replace symmetric ar_n
    symmetric ar_a_a
  }
  replace symmetric r_na
  show r < m
  conjunct 1 of R
end

// For positive y: y / y = 1 and y % y = 0.
// Strategy: instantiate div_mod and mod_less_divisor at (y, y), then split
// on the value of the quotient y / y.
lemma mult_div_mod_self: all y:Nat.
  if zero < y
  then y / y = suc(zero) and y % y = zero
proof
  arbitrary y:Nat
  assume y_pos: zero < y
  have div_rem: (y / y) * y + (y % y) = y
    by apply div_mod[y,y] to y_pos
  have rem_less: y % y < y
    by apply mod_less_divisor[y, y] to y_pos
  switch (y / y) {
    case zero assume yyz {
      // A zero quotient would force y % y = y, contradicting y % y < y.
      have rem_eq: y % y = y by replace yyz in div_rem
      conclude false by apply (apply less_not_equal to rem_less) to rem_eq
    }
    case suc(y') assume ysy {
      // Quotient suc(y'): unfold * in div_rem so that y can be cancelled
      // on the left of both sides.
      have eq1: y + (y' * y + y % y) = y + zero by {
        expand operator* in replace ysy in div_rem
      }
      have eq2: y' * y + y%y = zero by apply left_cancel to eq1
      // A sum equal to zero yields the remainder being zero.
      have rem_zero: y % y = zero by apply add_to_zero to eq2
      suffices suc(y') = suc(zero) by replace rem_zero | ysy.
      // It remains to show y' = zero: y' * y = zero with y a successor.
      have yz: y' = zero by {
        have yyz: y' * y = zero by apply add_to_zero to eq2
        obtain y'' where ysy2: y = suc(y'') from apply positive_suc to y_pos
        have eq3: y' + y'' * y' = zero by {
          expand operator* in
          replace ysy2 | mult_commute in yyz
        }
        conclude y' = zero by apply add_to_zero to eq3
      }
      conclude suc(y') = suc(zero) by replace yz.
    }
  }
end

// divides(a, b) holds when b is a multiple of a, i.e. a * k = b for some k.
fun divides(a : Nat, b : Nat) {
  some k:Nat. a * k = b
}

// If d divides both n and m % n, and n is positive, then d divides m.
// The witness is built from m = (m / n) * n + m % n by substituting the
// two multiples of d and factoring d out.
lemma divides_mod: all d:Nat, m:Nat, n:Nat.
  if divides(d, n) and divides(d, m % n) and zero < n then divides(d, m)
proof
  arbitrary d:Nat, m:Nat, n:Nat
  assume prem
  // Witnesses for the two divisibility premises.
  obtain k1 where dk1_n: d*k1 = n from expand divides in conjunct 0 of prem
  obtain k2 where dk2_mn: d*k2 = m % n from expand divides in conjunct 1 of prem
  have n_pos: zero < n by prem
  have eq1: (m / n) * n + m % n = m by apply div_mod[m,n] to n_pos
  expand divides
  show some k:Nat. d * k = m
  // Replace m % n by d * k2.
  have eq2: (m / n) * n + d*k2 = m by replace symmetric dk2_mn in eq1
  // X names m / n; it is unfolded again via `expand X` in the next step.
  define X = m / n
  // Replace n by d * k1.
  have eq3: (m / n)*d*k1 + d*k2 = m by expand X in replace symmetric dk1_n in eq2
  // Bring d to the front of the first product, then factor it out.
  have eq4: d*(m/n)*k1 + d*k2 = m by replace mult_commute[m/n,d] in eq3
  have eq5: d*((m/n)*k1 + k2) = m by replace symmetric dist_mult_add[d,(m/n)*k1,k2] in eq4
  choose (m/n)*k1 + k2
  eq5
end

// A positive number divided by itself is one.
// This is the first conjunct of mult_div_mod_self.
lemma div_cancel: all y:Nat. if zero < y then y / y = suc(zero)
proof
  arbitrary y:Nat
  assume y_pos: zero < y
  apply mult_div_mod_self[y] to y_pos
end

// y % y = zero for every y, including y = zero (no positivity premise).
lemma mod_self_zero: all y:Nat. y % y = zero
proof
  arbitrary y:Nat
  switch y {
    case zero {
      // Unfold % and finish with zero_monus.
      expand operator%
      zero_monus
    }
    case suc(y') assume y_sy {
      // A successor is positive, so mult_div_mod_self applies.
      have y_pos: zero < y by {
        replace y_sy
        evaluate
      }
      // Fold suc(y') back into y so the goal matches mult_div_mod_self.
      replace symmetric y_sy
      show y % y = zero
      apply mult_div_mod_self[y] to y_pos
    }
  }
end

// The remainder of zero by any x is zero.
lemma zero_mod: all x:Nat. zero % x = zero
proof
  arbitrary x:Nat
  // Unfold % and finish with zero_monus.
  expand operator%
  zero_monus
end

// Zero divided by a positive x is zero.
// From div_mod at (zero, x) the product (zero / x) * x must be zero; since
// x is a successor, the quotient itself must be zero.
lemma zero_div: all x:Nat. if zero < x then zero / x = zero
proof
  arbitrary x:Nat
  assume xpos
  have eq1: (zero/x)*x + (zero % x) = zero by apply div_mod[zero,x] to xpos
  have eq2: (zero/x)*x = zero  by apply add_to_zero to eq1
  switch x {
    case zero assume xz {
      // x = zero contradicts zero < x.
      conclude false by apply (apply less_not_equal to xpos) to replace xz.
    }
    case suc(x') assume xsx {
      // Commute and unfold the product so the quotient appears as the
      // first summand of a sum equal to zero.
      have eq3: zero / suc(x') + x' * (zero / suc(x')) = zero
        by expand operator* in replace xsx | mult_commute in eq2
      conclude zero / suc(x') = zero by apply add_to_zero to eq3
    }
  }
end

// The remainder of any n by one is zero.
// The remainder is below one (mod_less_divisor), so it cannot be a successor.
lemma mod_one: all n:Nat. n % suc(zero) = zero
proof
  arbitrary n:Nat
  have one_pos: zero < suc(zero) by evaluate
  have n1_1:  n % suc(zero) < suc(zero) by apply mod_less_divisor[n,suc(zero)] to one_pos
  switch n % suc(zero) {
    case zero assume n1z { n1z }
    case suc(n') assume n1_sn {
      // suc(n') < suc(zero) evaluates to false.
      conclude false by evaluate in replace n1_sn in n1_1
    }
  }
end

// Dividing any n by one gives n.
// Follows from div_mod at divisor one together with mod_one.
lemma div_one: all n:Nat. n / suc(zero) = n
proof
  arbitrary n:Nat
  have one_pos: zero < suc(zero) by evaluate
  have eq1: (n / suc(zero)) * suc(zero) + (n % suc(zero)) = n by apply div_mod[n,suc(zero)] to one_pos
  // Same equation with the multiplication by one simplified away.
  have eq2: (n / suc(zero)) + (n % suc(zero)) = n by eq1
  // Rewriting the remainder with mod_one leaves the goal.
  replace mod_one in eq2
end

// Adding one copy of a positive divisor m to the dividend raises the
// quotient by one: (n + m) / m = 1 + n / m.
// Both guards in the definition of / are shown false for (n + m, m), so
// unfolding / once lands in the recursive branch.
lemma add_div_one: all n:Nat, m:Nat.
  if zero < m
  then (n + m) / m = suc(zero) + n / m
proof
  arbitrary n:Nat, m:Nat
  assume m_pos
  // First guard: n + m is not below m, because m ≤ n + m.
  have not_nm_m: not (n + m < m) by {
    assume nm_m: n + m < m
    have not_m_nm: not (m ≤ n + m) by apply not_less_equal_iff_greater to nm_m
    have: m ≤ m + n by less_equal_add
    have m_nm: m ≤ n + m by replace add_commute in recall m ≤ m + n
    conclude false by apply not_m_nm to m_nm
  }
  // Second guard: m is not zero, because zero < m.
  have m_nz: not (m = zero) by {
    assume: m = zero
    conclude false by evaluate in replace (recall m = zero) in m_pos
  }
  equations
          (n + m) / m
        = suc(zero) + ((n + m) ∸ m) / m  by expand operator/
                                    replace (apply eq_false to not_nm_m)
                                           | (apply eq_false to m_nz).
    ... = suc(zero) + n / m              by replace add_commute[n,m]
                                           | add_monus_identity.
end
  
// Division undoes multiplication by a positive m: (n * m) / m = n.
// Proved by induction on n, using zero_div for the base case and
// add_div_one for the step.
lemma mult_div_inverse: all n:Nat, m:Nat.
  if zero < m then (n * m) / m = n
proof
  induction Nat
  case zero {
    arbitrary m:Nat
    assume prem
    show zero / m = zero
    apply zero_div to prem
  }
  case suc(n') assume IH {
    arbitrary m:Nat
    assume prem
    // Unfold suc(n') * m and commute the sum so the extra m is on the
    // right, matching the shape expected by add_div_one.
    expand operator* replace add_commute
    show (n' * m + m) / m = suc(n')
    equations
        (n'*m + m) / m = suc(zero) + (n'*m) / m  by apply add_div_one[n'*m, m] to prem
                   ... = suc(zero) + n'            by replace apply IH to prem.
                   ... = suc(n')           by evaluate
  }
end

/*
Alternative division algorithm.
*/

// Alternative division routine returning a pair of naturals.
//   - b = zero:            return pair(zero, a).
//   - a = y and zero < a:  recurse with a reset to zero and add one to the
//                          first component of the result.
//   - otherwise:           recurse with a incremented and b decremented.
// NOTE(review): presumably the first component is a quotient by y and the
// second a remainder, with a acting as a counter up to y - confirm against
// the callers.
// Termination measure: a + b + b.
recfun div_alt(a: Nat, b: Nat, y: Nat) -> Pair<Nat,Nat>
  measure a + b + b of Nat
{
  if b = zero then pair(zero, a)
  else if a = y and zero < a then
      (define f = div_alt(zero, b, y);
      pair(suc(first(f)), second(f)))
  else
      div_alt(suc(a), b ∸ suc(zero), y)
}
terminates {
  arbitrary a:Nat, b:Nat, y:Nat
  // A: the call that resets a to zero decreases the measure, because the
  // a being dropped is positive in that branch.
  have A: if (not (b = zero) and (a = y and zero < a)) then (zero + b) + b < (a + b) + b by {
    assume prem
    have X: b + b < b + b + a by apply less_add_pos[b+b, a] to prem
    //replace add_commute[b+b, a] in X  bug regarding associativity! -Jeremy
    replace add_commute[a, b+b]
    X
  }
  // B: the call that increments a and decrements b decreases the measure,
  // because b is counted twice in it.
  have B: if (not (b = zero) and not (a = y and zero < a)) then (suc(a) + (b ∸ suc(zero))) + (b ∸ suc(zero)) < (a + b) + b by {
    assume prem
    // b is nonzero in this branch, so write it as suc(b').
    obtain b' where bs: b = suc(b') from apply not_zero_suc[b] to prem
    replace bs
    expand operator∸
    replace monus_zero | add_suc | add_suc
    expand 2* operator+ | operator< | 2*operator≤
    less_equal_refl
  }
  A, B
}