module Nat

import Option
import Base

/*
  Public entry point for natural numbers.

  This module re-exports the Nat definition, arithmetic, order,
  division, parity, powers/logarithms, and summation support, then
  adds cross-cutting facts such as boolean equality, gcd, and literal
  arithmetic rewrite support.
*/

public import NatDefs
public import NatAdd
public import NatMonus
public import NatMult
public import NatLess
public import NatDiv
public import NatEvenOdd
public import NatPowLog
public import NatSum

/*
 Properties of `equal`, the boolean equality test on Nat
*/
// `equal` is reflexive: every natural number is `equal` to itself.
// Proved by induction on n, unfolding `equal` in each case.
theorem equal_refl: all n:Nat. equal(n,n)
proof
  induction Nat
  case zero {
    expand equal.
  }
  case suc(n') suppose IH {
    // equal(suc(n'), suc(n')) unfolds to equal(n', n'), which is the IH.
    suffices equal(n',n')  by expand equal.
    IH
  }
end

// `equal` coincides with propositional equality on Nat: m = n holds exactly
// when equal(m, n) does. Proved by induction on m with case analysis on n;
// only the suc/suc case needs more than unfolding `equal`.
theorem equal_complete_sound : all m:Nat. all n:Nat.
  m = n ⇔ equal(m, n)
proof
  induction Nat
  case zero {
    arbitrary n:Nat
    switch n {
      case zero { expand equal. }
      case suc(n') { expand equal. }
    }
  }
  case suc(m') suppose IH {
    arbitrary n:Nat
    switch n {
      case zero { expand equal. }
      case suc(n') {
        // (⇒) Equal numbers are `equal`: strip suc, then use reflexivity.
        have right : (if suc(m') = suc(n') then equal(suc(m'), suc(n')))
          by suppose sm_sn: suc(m') = suc(n')
             suffices equal(m', n')  by expand equal.
             have m_n: m' = n' by injective suc sm_sn
             suffices equal(n', n')  by replace m_n.
             equal_refl[n']
        // (⇐) `equal` numbers are equal: unfold `equal`, then apply the IH.
        have left : (if equal(suc(m'), suc(n')) then suc(m') = suc(n'))
          by suppose sm_sn : equal(suc(m'), suc(n'))
             have e_m_n : equal(m', n') by expand equal in sm_sn
             have m_n : m' = n' by apply IH to e_m_n
             replace m_n.
        right, left
      }
    }
  }
end

// If `equal` reports that m and n differ, they really are different.
// Proof by contradiction: assuming m = n gives equal(m, n) via equal_refl.
theorem not_equal_not_eq: all m:Nat, n:Nat.
  if not equal(m, n) then not (m = n)
proof
  arbitrary m:Nat, n:Nat
  suppose not_m_n
  suppose m_n
  have eq_m_n: equal(m, n) by {
    suffices equal(n,n)  by replace m_n.
    equal_refl[n]
  }
  apply not_m_n to eq_m_n
end

/*
  Greatest Common Divisor
*/

/*
  gcd(a, b): Euclid's algorithm. Returns a when b is zero; otherwise
  recurses on gcd(b, a % b). Termination uses the measure b, which
  strictly decreases because a % b < b whenever b is nonzero.
*/
recfun gcd(a : Nat, b : Nat) -> Nat
  measure b of Nat
{
  if b = zero then a
  else gcd(b, a % b)
}
terminates {
  arbitrary a:Nat, b:Nat
  assume bnz: not (b = zero)
  // b is either zero or positive; bnz rules out zero.
  have b_pos: zero < b  by apply or_not to zero_or_positive[b], bnz
  // The recursive call's measure a % b is smaller than the current measure b.
  conclude a % b < b by apply mod_less_divisor[a,b] to b_pos
}

// gcd(a, b) divides both a and b.
// Proved by strong induction on b (strong_induction with predicate P),
// mirroring the recursion of gcd: the inductive hypothesis is used at
// a % b, which is smaller than b.
theorem gcd_divides: all b:Nat, a:Nat. divides(gcd(a,b), a) and divides(gcd(a,b), b)
proof
  define P = fun b':Nat {all a:Nat. divides(gcd(a,b'), a) and divides(gcd(a,b'), b')}
  // Strong-induction step: if P holds for every i < j, then P holds for j.
  have X: all j:Nat. (if (all i:Nat. (if i < j then P(i))) then P(j)) by {
    arbitrary j:Nat
    assume IH: all i:Nat. (if i < j then P(i))
    expand P
    switch j {
      case zero {
        // gcd(a, zero) unfolds to a; a divides itself (witness 1) and
        // divides zero (witness 0).
        arbitrary a:Nat
        have A: divides(gcd(a, zero), a) by {
          expand divides | gcd
          choose suc(zero)
          conclude a * suc(zero) = a  by mult_one
        }
        have B: divides(gcd(a, zero), zero) by {
          expand divides | gcd
          choose zero
          conclude a * zero = zero  by mult_zero
        }
        A, B
      }
      case suc(j') assume j_suc {
        arbitrary a:Nat
        replace symmetric j_suc
        have j_pos: zero < j by {
          replace recall j = suc(j')
          evaluate
        }
        // Use the IH at the smaller index a % j, instantiated with a := j.
        have smaller: a % j < j
          by apply mod_less_divisor[a,j] to j_pos
        have div_j_div_aj: divides(gcd(j, a % j), j) and divides(gcd(j, a % j), a % j)
          by (expand P in apply IH[a%j] to smaller)[j]
        // For positive j, gcd(a, j) unfolds to gcd(j, a % j). divides_mod lifts
        // divisibility of j and a % j to divisibility of a.
        have A: divides(gcd(a, j), a) by {
          replace j_suc expand gcd replace symmetric j_suc
          conclude divides(gcd(j, a % j), a)
            by apply divides_mod[gcd(j, a % j), a, j] to div_j_div_aj, j_pos
        }
        have B: divides(gcd(a, j), j) by {
          replace j_suc expand gcd replace symmetric j_suc
          conclude divides(gcd(j, a % j), j) by div_j_div_aj
        }
        A, B
      }
    }
  }
  arbitrary b:Nat
  expand P in apply strong_induction[P,b] to X
end

/*
  Support for automatic arithmetic on literals.
*/

// Monus of any literal from literal zero is zero.
theorem nat_zero_monus: all m:Nat.
  lit(zero) ∸ lit(m) = lit(zero)
proof
  arbitrary m:Nat
  expand lit | operator∸.
end

// Subtracting literal zero (with monus) leaves n unchanged. This is the
// literal form of monus_zero and is registered as an automatic rewrite.
theorem nat_monus_zero: all n:Nat.
  n ∸ lit(zero) = n
proof
  arbitrary n:Nat
  expand lit
  monus_zero
end

auto nat_monus_zero

// Monus on two successor literals strips one suc from each side. Registered
// as an automatic rewrite so literal monus reduces step by step.
theorem lit_suc_monus_suc: all n:Nat, m:Nat.
  lit(suc(n)) ∸ lit(suc(m)) = lit(n) ∸ lit(m)
proof
  arbitrary n:Nat, m:Nat
  expand lit | operator∸.
end

auto lit_suc_monus_suc

// Left distributivity of a literal factor over addition (literal form of
// dist_mult_add), registered as an automatic rewrite.
theorem lit_dist_mult_add:
  all a:Nat, x:Nat, y:Nat.
  lit(a) * (x + y) = lit(a) * x + lit(a) * y
proof
  arbitrary a:Nat, x:Nat, y:Nat
  expand lit
  dist_mult_add
end

auto lit_dist_mult_add

// Right distributivity of a literal factor over addition (literal form of
// dist_mult_add_right), registered as an automatic rewrite.
theorem lit_dist_mult_add_right:
  all x:Nat, y:Nat, a:Nat.
  (x + y) * lit(a) = x * lit(a) + y * lit(a)
proof
  arbitrary x:Nat, y:Nat, a:Nat
  expand lit
  dist_mult_add_right
end

auto lit_dist_mult_add_right

// Doubling: n + n equals literal two (lit(suc(suc(zero)))) times n.
// Derived from two_mult after unfolding lit. Not registered as auto.
theorem mult_two: all n:Nat.
  n + n = lit(suc(suc(zero))) * n
proof
  arbitrary n:Nat
  expand lit
  replace two_mult.
end

// Moves a suc applied to a literal-headed sum into the literal:
// suc(lit(x) + y) = lit(suc(x)) + y. Follows from suc_add.
theorem lit_suc_add2: all x:Nat, y:Nat.
  suc(lit(x) + y) = lit(suc(x)) + y
proof
  arbitrary x:Nat, y:Nat
  expand lit
  replace suc_add.
end

// Registering this as auto causes problems with pattern matching in switch. -Jeremy
//auto lit_suc_add2

// The following auto rule is disabled because it causes an infinite
// rewriting loop on lit(x) * lit(y).

// theorem lit_mult_commute: all n:Nat, m:Nat.
//   n * lit(m) = lit(m) * n
// proof
//   arbitrary n:Nat, m:Nat
//   expand lit
//   mult_commute
// end

// auto lit_mult_commute

// Moves a suc from the right operand of a literal-headed sum into the
// literal: lit(n) + suc(m) = lit(suc(n)) + m. Registered as auto.
theorem lit_add_suc: all n:Nat, m:Nat.
  lit(n) + suc(m) = lit(suc(n)) + m
proof
  arbitrary n:Nat, m:Nat
  expand lit
  replace add_suc
  expand operator+.
end

auto lit_add_suc

// A nonzero literal factor (a successor literal) can be cancelled on the
// left of an equation. Literal form of mult_left_cancel.
theorem lit_mult_left_cancel : all m : Nat, a : Nat, b : Nat.
  if lit(suc(m)) * a = lit(suc(m)) * b then a = b
proof
  arbitrary m : Nat, a : Nat, b : Nat
  expand lit
  mult_left_cancel
end

/*
  More Properties of Summation
  */

// Gauss sum, exclusive form: twice the sum 0 + 1 + ... + (n - 1) equals
// n * (n ∸ 1). Here summation(n, s, f) adds the n values f(s), ...,
// f(s + n - 1) (cf. summation_next). Proved by induction on n.
theorem sum_n : all n : Nat.
    ℕ2 * summation(n, ℕ0, λ x {x}) = n * (n ∸ ℕ1)
proof
    induction Nat
    case zero {
        evaluate
    }
    case suc(n') suppose IH {
      // Premise consumed by summation_add below when summation(n' + ℕ1, ...)
      // is split into the first n' terms and the final term.
      have step1: (all i:Nat. (if i < ℕ1 then n' + i = ℕ0 + (n' + i))) by {
        arbitrary i:Nat
        suppose prem : i < ℕ1
        evaluate
      }
      replace nat_suc_one_add[n']
      | add_commute[ℕ1, n']
      | apply summation_add[n', ℕ1, ℕ0, n', λn{n}, λn{n}, λn{n}] to step1
      | IH
      expand lit | 2* summation
      replace nat_suc_one_add | add_commute[n', ℕ1]
      replace add_monus_identity
      replace nat_suc_one_add | dist_mult_add_right
      replace mult_commute[n', n' ∸ ℕ1]
      replace symmetric dist_mult_add_right[n' ∸ ℕ1, ℕ2, n']
      // Finish by cases on n' so that n' ∸ ℕ1 can be simplified.
      switch n' {
        case zero {
          .
        }
        case suc(n'') {
          replace nat_suc_one_add | add_monus_identity
          replace dist_mult_add | dist_mult_add_right
          replace mult_two
          replace add_commute[n'', ℕ1] | add_commute[n'', ℕ2] | add_commute[n'' * n'', ℕ2 * n''].
        }
      }
    }
end

// Gauss sum, inclusive form: twice the sum 0 + 1 + ... + n equals
// n * (n + 1). Here summation(suc(n), ℕ0, f) covers f(0) through f(n).
// Proved by induction on n.
theorem sum_n' : all n : Nat. 
    ℕ2 * summation(suc(n), ℕ0, λ x {x}) = n * (n + ℕ1)
proof
    induction Nat
    case zero {
      expand 2* summation.
    }
    case suc(n') suppose IH {
      // Index-shift premise for summation_add (trivially true by computation).
      have step1: (all i:Nat. (if i < ℕ1 then suc(n') + i = ℕ0 + (suc(n') + i))) by {
        arbitrary i:Nat
        suppose prem : i < ℕ1
        .
      }
      replace nat_suc_one_add
      // Split off the last term of the sum, then rewrite with the IH.
      replace (replace add_commute[n',ℕ1] in
          (apply summation_add[ℕ1 + n'][ℕ1, ℕ0, suc(n'), λn{n}, λn{n}, λn{n}]
           to replace nat_suc_one_add.))
      replace (replace nat_suc_one_add in IH)
      expand lit | 2*summation
      // Remaining goal is polynomial arithmetic; normalize both sides.
      replace nat_suc_one_add | dist_mult_add | dist_mult_add_right
      replace dist_mult_add[n', n', ℕ1] | add_commute[ℕ1, n' + n'] | mult_two[n']
      | add_commute[ℕ2 * n', ℕ2] | add_commute[ℕ2 + ℕ2 * n', n' * n' + n'].
    }
end

/*
  Public versions of theorems involving literals
*/

// A positive left factor can be cancelled: if 0 < m and m * a = m * b then a = b.
// The m = zero case contradicts the positivity premise; the suc case uses
// mult_left_cancel.
theorem pos_mult_left_cancel : all m : Nat, a : Nat, b : Nat.
  if ℕ0 < m and m * a = m * b then a = b
proof
  arbitrary m : Nat, a : Nat, b : Nat
  assume prem
  switch m {
    case zero assume mz {
      conclude false by evaluate in replace mz in prem
    }
    case suc(m') assume ms {
      apply mult_left_cancel[m', a, b] to replace ms in prem
    }
  }
end

// A positive right factor can be cancelled from a strict inequality:
// if 0 < c and a * c < b * c then a < b. Wraps mult_lt_mono_r.
theorem pos_mult_right_cancel_less : all c : Nat, a : Nat, b : Nat.
  if ℕ0 < c and a * c < b * c then a < b
proof
  arbitrary c : Nat, a : Nat, b : Nat
  expand lit
  assume prem
  apply (apply mult_lt_mono_r[c,a,b] to prem) to prem
end

// A positive left factor can be cancelled from ≤:
// if 0 < n and n * x ≤ n * y then x ≤ y.
theorem pos_mult_left_cancel_less_equal : all n : Nat, x : Nat, y : Nat.
  if ℕ0 < n and n * x ≤ n * y then x ≤ y
proof
  arbitrary n : Nat, x : Nat, y : Nat
  expand lit
  assume prem
  // Write the positive n as suc(n') so mult_nonzero_mono_le applies.
  obtain n' where ns: n = suc(n') from apply positive_suc[n] to prem
  apply mult_nonzero_mono_le[n', x, y] to (replace ns in prem)
end
  
// Multiplying both sides of a strict inequality by a positive factor
// preserves it: if 0 < n and x < y then n * x < n * y.
theorem pos_mult_both_sides_of_less : all n : Nat, x : Nat, y : Nat.
  if ℕ0 < n and x < y then n * x < n * y
proof
  arbitrary n : Nat, x : Nat, y : Nat
  expand lit
  assume prem
  // Write the positive n as suc(n') so mono_nonzero_mult_le applies.
  obtain n' where ns: n = suc(n') from apply positive_suc[n] to prem
  replace ns
  apply mono_nonzero_mult_le[n', x, y] to (replace ns in prem)
end
  
// Literal form of zero_less_one_add: 1 + n is always positive.
theorem nat_zero_less_one_add: all n:Nat.
  ℕ0 < ℕ1 + n
proof
  arbitrary n:Nat
  expand lit
  zero_less_one_add
end

// Literal form of add_to_zero: a sum is zero only if both summands are zero.
theorem nat_add_to_zero: all n:Nat, m:Nat.
  if n + m = ℕ0
  then n = ℕ0 and m = ℕ0
proof
  arbitrary n:Nat, m:Nat
  expand lit
  add_to_zero
end

// Literal form of less_add_pos: adding a positive y strictly increases x.
theorem nat_less_add_pos: all x:Nat, y:Nat.
  if ℕ0 < y
  then x < x + y
proof
  arbitrary x:Nat, y:Nat
  expand lit
  less_add_pos
end

// Literal form of monus_zero_iff_less_eq: x ≤ y exactly when x ∸ y is zero.
theorem nat_monus_zero_iff_less_eq : all x : Nat, y : Nat.
  x ≤ y  ⇔  x ∸ y = ℕ0
proof
  arbitrary x : Nat, y : Nat
  expand lit
  monus_zero_iff_less_eq[x, y]
end

// Literal form of monus_one_pred: x ∸ 1 is the predecessor of x.
theorem nat_monus_one_pred : all x : Nat. x ∸ ℕ1 = pred(x)
proof
  arbitrary x:Nat
  expand lit
  monus_one_pred
end

// Literal form of monus_cancel: n ∸ n is zero.
theorem nat_monus_cancel: all n:Nat. n ∸ n = ℕ0
proof
  arbitrary n:Nat
  expand lit
  monus_cancel
end

// Literal form of zero_or_positive: every natural is zero or positive.
theorem nat_zero_or_positive: all x:Nat. x = ℕ0 or ℕ0 < x
proof
  arbitrary x:Nat
  expand lit
  zero_or_positive[x]
end

// Literal form of not_one_add_zero: 1 + n is never zero.
theorem nat_not_one_add_zero: all n:Nat.
  not (ℕ1 + n = ℕ0)
proof
  arbitrary n:Nat
  expand lit
  not_one_add_zero[n]
end

// Literal form of positive_suc: a positive n can be written as 1 + n'.
theorem nat_positive_suc: all n:Nat.
  if ℕ0 < n
  then some n':Nat. n = ℕ1 + n'
proof
  arbitrary n:Nat
  expand lit
  assume prem
  // Unfold + so that the goal takes the form concluded by positive_suc.
  expand operator+
  apply positive_suc[n] to prem
end

// Literal form of zero_le_zero: the only natural ≤ 0 is 0 itself.
theorem nat_zero_le_zero: all x:Nat. if x ≤ ℕ0 then x = ℕ0
proof
  arbitrary x:Nat
  expand lit
  zero_le_zero
end
  
// Literal form of summation_suc_add: extending a sum of n terms starting at s
// by one more term adds f(s + n).
theorem summation_next: all n:Nat, s:Nat, f:fn Nat->Nat.
  summation(ℕ1 + n, s, f) = summation(n, s, f) + f(s + n)
proof
  arbitrary n:Nat, s:Nat, f:fn Nat->Nat
  expand lit
  summation_suc_add
end

// Rewrite rule: x < zero simplifies to false (from not_less_zero).
theorem less_zero_false: all x:Nat. (x < zero) = false
proof
  arbitrary x:Nat
  apply eq_false to not_less_zero[x]
end

auto less_zero_false
  
// Rewrite rule: zero ≤ x simplifies to true (immediate from the definition of ≤).
theorem zero_less_equal_true: all x:Nat. (zero ≤ x) = true
proof
  arbitrary x:Nat
  expand operator≤.
end

auto zero_less_equal_true

// A successor is never ≤ zero (immediate from the definition of ≤).
theorem not_suc_less_equal_zero: all x:Nat. not (suc(x) ≤ zero)
proof
  arbitrary x:Nat
  expand operator≤.
end

// Rewrite rule on literals: a successor literal ≤ literal zero is false.
theorem lit_suc_less_equal_zero_false: all x:Nat. (lit(suc(x)) ≤ lit(zero)) = false
proof
  arbitrary x:Nat
  expand lit
  apply eq_false to not_suc_less_equal_zero[x]
end

auto lit_suc_less_equal_zero_false
  

// Rewrite rule on literals: comparing two successor literals with ≤ reduces
// to comparing their predecessors. Registered as auto, so literal ≤ is
// decided by repeatedly stripping suc.
theorem le_lit_suc: all x:Nat, y:Nat. (lit(suc(x)) ≤ lit(suc(y))) = (lit(x) ≤ lit(y))
proof
  arbitrary x:Nat, y:Nat
  expand lit
  replace apply iff_equal to suc_less_equal_iff_less_equal_suc[x,y].
end
auto le_lit_suc

// Rewrite rule on literals: comparing two successor literals with < reduces
// to comparing their predecessors.
theorem less_lit_suc: all x:Nat, y:Nat. (lit(suc(x)) < lit(suc(y))) = (lit(x) < lit(y))
proof
  arbitrary x:Nat, y:Nat
  expand lit
  replace apply iff_equal to less_suc_iff_suc_less[x,y].
end
auto less_lit_suc  

// Rewrite rule on literals: literal zero is less than any successor literal.
theorem less_lit_zero_suc: all y:Nat. (lit(zero) < lit(suc(y))) = true
proof
  arbitrary y:Nat
  expand lit | operator< | operator≤.
end
auto less_lit_zero_suc

// Rewrite rule on literals: zero divided by any successor literal is zero.
theorem lit_zero_div: all x:Nat. lit(zero) / lit(suc(x)) = lit(zero)
proof
  arbitrary x:Nat
  // zero_div needs a positive divisor; suc(x) is positive by definition of <.
  have divisor_pos: zero < suc(x) by expand operator< | operator≤.
  // After removing the lit wrappers the goal is zero_div at suc(x).
  expand lit
  apply zero_div[suc(x)] to divisor_pos
end
auto lit_zero_div

// Rewrite rule on literals: a successor literal divided by itself is one.
theorem lit_div_cancel: all y:Nat. lit(suc(y)) / lit(suc(y)) = lit(suc(zero))
proof
  arbitrary y:Nat
  // div_cancel needs a positive divisor; suc(y) is positive by definition of <.
  have divisor_pos: zero < suc(y) by expand operator< | operator≤.
  // After removing the lit wrappers the goal is div_cancel at suc(y).
  expand lit
  apply div_cancel[suc(y)] to divisor_pos
end
auto lit_div_cancel

// add_div(a, b, y) is the quotient (a + b) / y. It is an auxiliary form for
// the literal-division rewrite rules:
//   - lit_div starts from add_div(zero, x, y);
//   - lit_add_div_suc moves one unit at a time from b into the accumulator a;
//   - when a equals a successor divisor, lit_add_div resets a to zero and
//     adds one to the quotient.
// The invariant a ≤ y is not enforced by this definition.
fun add_div(a : Nat, b : Nat, y : Nat) { (a + b) / y } // invariant: a ≤ y

// Rewrite rule: start literal division x / y in the add_div form with an
// empty accumulator (a = zero).
theorem lit_div: all x:Nat, y:Nat. lit(x) / lit(y) = add_div(zero, x, y)
proof
  arbitrary x:Nat, y:Nat
  expand lit | add_div.
end
auto lit_div

// Rewrite rule: when the accumulator equals the (successor) divisor, emit one
// unit of quotient and reset the accumulator to zero. Follows from
// add_div_one: (b + d) / d = 1 + b / d for positive d.
theorem lit_add_div: all b:Nat, y:Nat. add_div(suc(y), b, suc(y)) = lit(suc(zero)) + add_div(zero, b, suc(y))
proof
  arbitrary b:Nat, y:Nat
  have pos: zero < suc(y) by expand operator< | operator≤.
  expand lit | add_div | operator+
  replace add_commute[y,b]
  have X: (b + suc(y)) / suc(y) = suc(zero) + b / suc(y) by apply add_div_one[b, suc(y)] to pos
  expand operator+ in replace add_suc in X
end
auto lit_add_div

// Rewrite rule: move one unit from the remaining dividend b into the
// accumulator a. Both sides unfold to the same sum divided by y.
theorem lit_add_div_suc: all a:Nat, b:Nat, y:Nat. add_div(a, suc(b), y) = add_div(suc(a), b, y)
proof
  arbitrary a:Nat, b:Nat, y:Nat
  expand add_div
  replace add_suc
  expand operator+.
end
auto lit_add_div_suc

// This postulate is not true. It would be true if we added the premise a < y.
// But the rewriting system does not yet handle conditional rewriting. -Jeremy
//
// NOTE(review): counterexample: add_div(suc(zero), zero, suc(zero)) unfolds
// to (1 + 0) / 1 = 1, not 0. Because this postulate is registered as auto,
// any proof that relies on automatic literal division may silently depend on
// this false fact. lit_add_div also matches add_div(suc(y), zero, suc(y)), so
// which rule fires there depends on rewrite order — TODO confirm.
postulate lit_add_div_zero: all a:Nat, y:Nat. add_div(a, zero, y) = lit(zero)
auto lit_add_div_zero

// Rewrite rule: x < literal zero simplifies to false. After unfolding lit,
// the goal is presumably closed by the earlier auto rule less_zero_false.
theorem lit_less_zero_false: all x:Nat. (x < lit(zero)) = false
proof
  arbitrary x:Nat
  expand lit.
end

auto lit_less_zero_false

// Rewrite rule: literal zero ≤ x simplifies to true. After unfolding lit,
// the goal is presumably closed by the earlier auto rule zero_less_equal_true.
theorem lit_zero_less_equal_true: all x:Nat. (lit(zero) ≤ x) = true
proof
  arbitrary x:Nat
  expand lit.
end

auto lit_zero_less_equal_true

// Literal form of expt_two: n squared equals n * n.
theorem lit_expt_two : all n : Nat.
  n ^ ℕ2 = n * n
proof
  arbitrary n:Nat
  expand lit
  expt_two
end

// Literal form of one_expt: 1 raised to any power is 1. Registered as auto.
theorem lit_one_expt : all n :Nat.
  ℕ1 ^ n = ℕ1
proof
  arbitrary n:Nat
  expand lit
  one_expt
end

auto lit_one_expt