module Nat

import Option
import Base
public import NatDefs
public import NatAdd
public import NatMonus
public import NatMult
public import NatLess
public import NatDiv
public import NatEvenOdd
public import NatPowLog
public import NatSum

/*
 Properties of equal
*/
// Reflexivity of the `equal` test on Nat: equal(n, n) holds for every n.
// Proved by induction on n.
theorem equal_refl: all n:Nat. equal(n,n)
proof
  induction Nat
  case zero {
    // Base case: unfolding equal at zero closes the goal.
    expand equal.
  }
  case suc(k) assume hyp {
    // Unfolding equal on two successors leaves the same question for k,
    // which is exactly the induction hypothesis.
    suffices equal(k, k)  by expand equal.
    hyp
  }
end

// The `equal` test agrees with propositional equality on Nat:
// m = n holds exactly when equal(m, n) holds.
// Proved by induction on m with a case split on n.
theorem equal_complete_sound : all m:Nat. all n:Nat.
  m = n ⇔ equal(m, n)
proof
  induction Nat
  case zero {
    arbitrary n:Nat
    switch n {
      case zero { expand equal. }
      case suc(n') { expand equal. }
    }
  }
  case suc(m') suppose IH {
    arbitrary n:Nat 
    switch n {
      case zero { expand equal. }
      case suc(n') {
        // Forward direction: from suc(m') = suc(n') get m' = n' by
        // injectivity of suc, then finish with reflexivity of equal.
        have right : (if suc(m') = suc(n') then equal(suc(m'), suc(n')))
          by suppose sm_sn: suc(m') = suc(n')
             suffices equal(m', n')  by expand equal.
             have m_n: m' = n' by injective suc sm_sn
             suffices equal(n', n')  by replace m_n.
             equal_refl[n']
        // Backward direction: unfold equal to get equal(m', n'), turn it
        // into m' = n' with the induction hypothesis, and rewrite.
        have left : (if equal(suc(m'), suc(n')) then suc(m') = suc(n'))
          by suppose sm_sn : equal(suc(m'), suc(n'))
             have e_m_n : equal(m', n') by expand equal in sm_sn
             have m_n : m' = n' by apply IH to e_m_n
             replace m_n.
        right, left
      }
    }
  }
end

// If the `equal` test fails on m and n, then m and n are not equal.
theorem not_equal_not_eq: all m:Nat, n:Nat.
  if not equal(m, n) then not (m = n)
proof
  arbitrary m:Nat, n:Nat
  assume test_fails
  assume same
  // Under m = n the test reduces to equal(n, n), which holds by reflexivity;
  // that contradicts the assumption that the test fails.
  have test_holds: equal(m, n) by {
    suffices equal(n,n)  by replace same.
    equal_refl[n]
  }
  apply test_fails to test_holds
end

/*
  Greatest Common Divisor
*/

// gcd(a, b): returns a when b is zero, otherwise recurses on (b, a % b).
// The recursion is justified by the measure b: the termination proof shows
// that the new second argument a % b is strictly smaller than b.
recfun gcd(a : Nat, b : Nat) -> Nat
  measure b of Nat
{
  if b = zero then a
  else gcd(b, a % b)
}
terminates {
  arbitrary a:Nat, b:Nat
  suppose b_nonzero: not (b = zero)
  // b is either zero or positive; the first alternative is ruled out.
  have divisor_pos: zero < b
    by apply or_not to zero_or_positive[b], b_nonzero
  conclude a % b < b
    by apply mod_less_divisor[a,b] to divisor_pos
}

// gcd(a, b) divides both a and b.
// The proof uses strong induction on b (via strong_induction) with the
// predicate P(b') = "for every a, gcd(a, b') divides a and divides b'".
theorem gcd_divides: all b:Nat, a:Nat. divides(gcd(a,b), a) and divides(gcd(a,b), b)
proof
  define P = fun b':Nat {all a:Nat. divides(gcd(a,b'), a) and divides(gcd(a,b'), b')}
  // X is the strong-induction step: P(j) follows from P(i) for all i < j.
  have X: all j:Nat. (if (all i:Nat. (if i < j then P(i))) then P(j)) by {
    arbitrary j:Nat
    assume IH: all i:Nat. (if i < j then P(i))
    expand P
    switch j {
      case zero {
        // gcd(a, zero) unfolds to a, so the divisibility witnesses are
        // suc(zero) (a * 1 = a) and zero (a * 0 = 0).
        arbitrary a:Nat
        have A: divides(gcd(a, zero), a) by {
          expand divides | gcd
          choose suc(zero)
          conclude a * suc(zero) = a  by mult_one
        }
        have B: divides(gcd(a, zero), zero) by {
          expand divides | gcd
          choose zero
          conclude a * zero = zero  by mult_zero
        }
        A, B
      }
      case suc(j') assume j_suc {
        arbitrary a:Nat
        // Work with j rather than suc(j') in the goal.
        replace symmetric j_suc
        have j_pos: zero < j by {
          replace recall j = suc(j')
          evaluate
        }
        // a % j is smaller than j, so the induction hypothesis applies to it.
        have smaller: a % j < j
          by apply mod_less_divisor[a,j] to j_pos
        // Instantiate P(a % j) at first argument j.
        have div_j_div_aj: divides(gcd(j, a % j), j) and divides(gcd(j, a % j), a % j)
          by (expand P in apply IH[a%j] to smaller)[j]
        have A: divides(gcd(a, j), a) by {
          // Switch to suc(j') so gcd can unfold one step, then switch back.
          replace j_suc expand gcd replace symmetric j_suc
          // divides_mod (defined elsewhere) turns divisibility of j and
          // a % j into divisibility of a, given that j is positive.
          conclude divides(gcd(j, a % j), a)
            by apply divides_mod[gcd(j, a % j), a, j] to div_j_div_aj, j_pos
        }
        have B: divides(gcd(a, j), j) by {
          replace j_suc expand gcd replace symmetric j_suc
          conclude divides(gcd(j, a % j), j) by div_j_div_aj
        }
        A, B
      }
    }
  }
  arbitrary b:Nat
  expand P in apply strong_induction[P,b] to X
end

/*
  Support for automatic arithmetic on literals.
*/

// Monus from the literal zero is zero: lit(zero) ∸ lit(m) = lit(zero).
// Holds by unfolding lit and the definition of ∸.
theorem nat_zero_monus: all m:Nat.
  lit(zero) ∸ lit(m) = lit(zero)
proof
  arbitrary m:Nat
  expand lit | operator∸.
end

// Subtracting the literal zero leaves a number unchanged: n ∸ lit(zero) = n.
// After unfolding lit this is exactly monus_zero.
theorem nat_monus_zero: all n:Nat.
  n ∸ lit(zero) = n
proof
  arbitrary n:Nat
  expand lit
  monus_zero
end

// Register the equation as an automatic rewrite rule.
auto nat_monus_zero

// Monus of two successor literals strips one suc from each side:
// lit(suc(n)) ∸ lit(suc(m)) = lit(n) ∸ lit(m).
// Holds by unfolding lit and the definition of ∸.
theorem lit_suc_monus_suc: all n:Nat, m:Nat.
  lit(suc(n)) ∸ lit(suc(m)) = lit(n) ∸ lit(m)
proof
  arbitrary n:Nat, m:Nat
  expand lit | operator∸.
end

// Register the equation as an automatic rewrite rule.
auto lit_suc_monus_suc

// Left distributivity of * over + when the multiplier is a literal:
// lit(a) * (x + y) = lit(a) * x + lit(a) * y.
// Proved by unfolding lit and appealing to dist_mult_add.
theorem lit_dist_mult_add:
  all a:Nat, x:Nat, y:Nat.
  lit(a) * (x + y) = lit(a) * x + lit(a) * y
proof
  arbitrary a:Nat, x:Nat, y:Nat
  expand lit
  dist_mult_add
end

// Register the equation as an automatic rewrite rule.
auto lit_dist_mult_add

// Right distributivity of * over + when the multiplier is a literal:
// (x + y) * lit(a) = x * lit(a) + y * lit(a).
// Proved by unfolding lit and appealing to dist_mult_add_right.
theorem lit_dist_mult_add_right:
  all x:Nat, y:Nat, a:Nat.
  (x + y) * lit(a) = x * lit(a) + y * lit(a)
proof
  arbitrary x:Nat, y:Nat, a:Nat
  expand lit
  dist_mult_add_right
end

// Register the equation as an automatic rewrite rule.
auto lit_dist_mult_add_right

// Doubling expressed as multiplication by the literal two:
// n + n = lit(suc(suc(zero))) * n.
// After unfolding lit the goal is closed by rewriting with two_mult.
theorem mult_two: all n:Nat.
  n + n = lit(suc(suc(zero))) * n
proof
  arbitrary n:Nat
  expand lit
  replace two_mult.
end

// Moves an outer successor into a literal left summand:
// suc(lit(x) + y) = lit(suc(x)) + y.
// Proved by unfolding lit and rewriting with suc_add.
theorem lit_suc_add2: all x:Nat, y:Nat.
  suc(lit(x) + y) = lit(suc(x)) + y
proof
  arbitrary x:Nat, y:Nat
  expand lit
  replace suc_add.
end

// Enabling this auto rule causes problems with pattern matching in switch. -Jeremy
//auto lit_suc_add2

// The following causes an infinite loop for terms of the form
// lit(x) * lit(y)

// theorem lit_mult_commute: all n:Nat, m:Nat.
//   n * lit(m) = lit(m) * n
// proof
//   arbitrary n:Nat, m:Nat
//   expand lit
//   mult_commute
// end

// auto lit_mult_commute

// Shifts a successor from the right summand into the literal on the left:
// lit(n) + suc(m) = lit(suc(n)) + m.
theorem lit_add_suc: all n:Nat, m:Nat.
  lit(n) + suc(m) = lit(suc(n)) + m
proof
  arbitrary n:Nat, m:Nat
  expand lit
  // Rewrite the left side with add_suc, then unfold + to close the goal.
  replace add_suc
  expand operator+.
end

// Register the equation as an automatic rewrite rule.
auto lit_add_suc

// Left cancellation of a successor literal factor:
// if lit(suc(m)) * a = lit(suc(m)) * b then a = b.
// Proved by unfolding lit and appealing to mult_left_cancel.
theorem lit_mult_left_cancel : all m : Nat, a : Nat, b : Nat.
  if lit(suc(m)) * a = lit(suc(m)) * b then a = b
proof
  arbitrary m : Nat, a : Nat, b : Nat
  expand lit
  mult_left_cancel
end

/*
  More Properties of Summation
  */

// Closed form for twice the summation of the identity function over n terms
// starting at ℕ0:
//   ℕ2 * summation(n, ℕ0, λ x {x}) = n * (n ∸ ℕ1)
// The right-hand side uses monus (∸), so the equation also holds at n = zero.
// Proved by induction on n.
theorem sum_n : all n : Nat. 
    ℕ2 * summation(n, ℕ0, λ x {x}) = n * (n ∸ ℕ1)
proof
    induction Nat
    case zero {
        evaluate
    }
    case suc(n') suppose IH {
      // Side condition needed by summation_add for the one-element tail.
      have step1: (all i:Nat. (if i < ℕ1 then n' + i = ℕ0 + (n' + i))) by {
        arbitrary i:Nat
        suppose prem : i < ℕ1
        evaluate
      }
      // Rewrite suc(n') as n' + ℕ1, split the summation into the first n'
      // terms plus one more term, and replace the first part using IH.
      replace nat_suc_one_add[n']
      | add_commute[ℕ1, n']
      | apply summation_add[n', ℕ1, ℕ0, n', λn{n}, λn{n}, λn{n}] to step1
      | IH
      expand lit | 2* summation
      replace nat_suc_one_add | add_commute[n', ℕ1]
      replace add_monus_identity
      replace nat_suc_one_add | dist_mult_add_right
      // Bring both sides into the shape (... ) * n' so they can be compared.
      replace mult_commute[n', n' ∸ ℕ1]
      replace symmetric dist_mult_add_right[n' ∸ ℕ1, ℕ2, n']
      switch n' {
        case zero {
          .
        }
        case suc(n'') {
          replace nat_suc_one_add | add_monus_identity
          replace dist_mult_add | dist_mult_add_right
          replace mult_two
          replace add_commute[n'', ℕ1] | add_commute[n'', ℕ2] | add_commute[n'' * n'', ℕ2 * n''].
        }
      }
    }
end

// Variant of the summation closed form stated without monus:
//   ℕ2 * summation(suc(n), ℕ0, λ x {x}) = n * (n + ℕ1)
// Proved by induction on n.
theorem sum_n' : all n : Nat. 
    ℕ2 * summation(suc(n), ℕ0, λ x {x}) = n * (n + ℕ1)
proof
    induction Nat
    case zero {
      expand 2* summation.
    }
    case suc(n') suppose IH {
      // NOTE(review): step1 is not referenced by name below; the premise of
      // summation_add is instead discharged inline. Confirm whether it is
      // still needed.
      have step1: (all i:Nat. (if i < ℕ1 then suc(n') + i = ℕ0 + (suc(n') + i))) by {
        arbitrary i:Nat
        suppose prem : i < ℕ1
        .
      }
      replace nat_suc_one_add
      // Split the summation with summation_add (after normalising the
      // resulting equation with add_commute) and rewrite the goal with it.
      replace (replace add_commute[n',ℕ1] in
          (apply summation_add[ℕ1 + n'][ℕ1, ℕ0, suc(n'), λn{n}, λn{n}, λn{n}]
           to replace nat_suc_one_add.))
      // Use the induction hypothesis, with suc rewritten as ℕ1 + _.
      replace (replace nat_suc_one_add in IH)
      expand lit | 2*summation
      // The remaining steps are arithmetic normalisation of both sides.
      replace nat_suc_one_add | dist_mult_add | dist_mult_add_right
      replace dist_mult_add[n', n', ℕ1] | add_commute[ℕ1, n' + n'] | mult_two[n']
      | add_commute[ℕ2 * n', ℕ2] | add_commute[ℕ2 + ℕ2 * n', n' * n' + n'].
    }
end

/*
  Public versions of theorems involving literals
*/

// Left cancellation of a positive factor:
// if ℕ0 < m and m * a = m * b then a = b.
theorem pos_mult_left_cancel : all m : Nat, a : Nat, b : Nat.
  if ℕ0 < m and m * a = m * b then a = b
proof
  arbitrary m : Nat, a : Nat, b : Nat
  suppose hyp
  switch m {
    case zero suppose m_zero {
      // With m = zero the hypothesis evaluates to false.
      conclude false by evaluate in replace m_zero in hyp
    }
    case suc(k) suppose m_suc {
      // With m = suc(k) the hypothesis is in the form mult_left_cancel expects.
      apply mult_left_cancel[k, a, b] to replace m_suc in hyp
    }
  }
end

// Right cancellation of a positive factor under <:
// if ℕ0 < c and a * c < b * c then a < b.
theorem pos_mult_right_cancel_less : all c : Nat, a : Nat, b : Nat.
  if ℕ0 < c and a * c < b * c then a < b
proof
  arbitrary c : Nat, a : Nat, b : Nat
  expand lit
  assume prem
  // prem is supplied twice to mult_lt_mono_r: presumably once for the
  // positivity premise and once for the inequality (its statement is
  // defined elsewhere -- confirm there).
  apply (apply mult_lt_mono_r[c,a,b] to prem) to prem
end

// Left cancellation of a positive factor under ≤:
// if ℕ0 < n and n * x ≤ n * y then x ≤ y.
theorem pos_mult_left_cancel_less_equal : all n : Nat, x : Nat, y : Nat.
  if ℕ0 < n and n * x ≤ n * y then x ≤ y
proof
  arbitrary n : Nat, x : Nat, y : Nat
  expand lit
  assume prem  
  // A positive n is a successor; rewrite it so mult_nonzero_mono_le applies.
  obtain n' where ns: n = suc(n') from apply positive_suc[n] to prem
  apply mult_nonzero_mono_le[n', x, y] to (replace ns in prem)
end
  
// Multiplying both sides of a strict inequality by a positive number:
// if ℕ0 < n and x < y then n * x < n * y.
theorem pos_mult_both_sides_of_less : all n : Nat, x : Nat, y : Nat.
  if ℕ0 < n and x < y then n * x < n * y
proof
  arbitrary n : Nat, x : Nat, y : Nat
  expand lit
  assume prem
  // A positive n is a successor; rewrite n in both the goal and the premise
  // before applying mono_nonzero_mult_le.
  obtain n' where ns: n = suc(n') from apply positive_suc[n] to prem
  replace ns
  apply mono_nonzero_mult_le[n', x, y] to (replace ns in prem)
end
  
// ℕ1 + n is always positive. Literal-notation version of zero_less_one_add.
theorem nat_zero_less_one_add: all n:Nat.
  ℕ0 < ℕ1 + n
proof
  arbitrary n:Nat
  expand lit
  zero_less_one_add
end

// If a sum is ℕ0 then both summands are ℕ0.
// Literal-notation version of add_to_zero.
theorem nat_add_to_zero: all n:Nat, m:Nat.
  if n + m = ℕ0
  then n = ℕ0 and m = ℕ0
proof
  arbitrary n:Nat, m:Nat
  expand lit
  add_to_zero
end

// Adding a positive number strictly increases a value.
// Literal-notation version of less_add_pos.
theorem nat_less_add_pos: all x:Nat, y:Nat.
  if ℕ0 < y
  then x < x + y
proof
  arbitrary x:Nat, y:Nat
  expand lit
  less_add_pos
end

// x ≤ y holds exactly when the monus x ∸ y is ℕ0.
// Literal-notation version of monus_zero_iff_less_eq.
theorem nat_monus_zero_iff_less_eq : all x : Nat, y : Nat.
  x ≤ y  ⇔  x ∸ y = ℕ0
proof
  arbitrary x : Nat, y : Nat
  expand lit
  monus_zero_iff_less_eq[x, y]
end

// Subtracting ℕ1 with monus is the predecessor function.
// Literal-notation version of monus_one_pred.
theorem nat_monus_one_pred : all x : Nat. x ∸ ℕ1 = pred(x)
proof
  arbitrary x:Nat
  expand lit
  monus_one_pred
end

// A number minus itself is ℕ0. Literal-notation version of monus_cancel.
theorem nat_monus_cancel: all n:Nat. n ∸ n = ℕ0
proof
  arbitrary n:Nat
  expand lit
  monus_cancel
end

// Every natural number is either ℕ0 or positive.
// Literal-notation version of zero_or_positive.
theorem nat_zero_or_positive: all x:Nat. x = ℕ0 or ℕ0 < x
proof
  arbitrary x:Nat
  expand lit
  zero_or_positive[x]
end

// ℕ1 + n is never ℕ0. Literal-notation version of not_one_add_zero.
theorem nat_not_one_add_zero: all n:Nat.
  not (ℕ1 + n = ℕ0)
proof
  arbitrary n:Nat
  expand lit
  not_one_add_zero[n]
end

// A positive number can be written as ℕ1 + n' for some n'.
// Literal-notation version of positive_suc.
theorem nat_positive_suc: all n:Nat.
  if ℕ0 < n
  then some n':Nat. n = ℕ1 + n'
proof
  arbitrary n:Nat
  expand lit
  assume prem
  // Unfold + so the goal matches the conclusion of positive_suc
  // (used above in this file to obtain n = suc(n')).
  expand operator+
  apply positive_suc[n] to prem
end

// The only number less than or equal to ℕ0 is ℕ0 itself.
// Literal-notation version of zero_le_zero.
theorem nat_zero_le_zero: all x:Nat. if x ≤ ℕ0 then x = ℕ0
proof
  arbitrary x:Nat
  expand lit
  zero_le_zero
end
  
// Peels the last term off a summation of ℕ1 + n terms starting at s:
// summation(ℕ1 + n, s, f) = summation(n, s, f) + f(s + n).
// Literal-notation version of summation_suc_add.
theorem summation_next: all n:Nat, s:Nat, f:fn Nat->Nat.
  summation(ℕ1 + n, s, f) = summation(n, s, f) + f(s + n)
proof
  arbitrary n:Nat, s:Nat, f:fn Nat->Nat
  expand lit
  summation_suc_add
end

// Nothing is less than zero, stated as an equation with false so that it
// can be registered as an auto rule. Derived from not_less_zero via eq_false.
theorem less_zero_false: all x:Nat. (x < zero) = false
proof
  arbitrary x:Nat
  apply eq_false to not_less_zero[x]
end

// Register the equation as an automatic rewrite rule.
auto less_zero_false
  
// zero is less than or equal to every natural number, stated as an
// equation with true so that it can be registered as an auto rule.
// Holds by unfolding the definition of ≤.
theorem zero_less_equal_true: all x:Nat. (zero ≤ x) = true
proof
  arbitrary x:Nat
  expand operator≤.
end

// Register the equation as an automatic rewrite rule.
auto zero_less_equal_true

// A successor is never less than or equal to zero.
// Holds by unfolding the definition of ≤.
theorem not_suc_less_equal_zero: all x:Nat. not (suc(x) ≤ zero)
proof
  arbitrary x:Nat
  expand operator≤.
end

// A successor literal is never ≤ the zero literal, stated as an equation
// with false so that it can be registered as an auto rule.
// Derived from not_suc_less_equal_zero via eq_false after unfolding lit.
theorem lit_suc_less_equal_zero_false: all x:Nat. (lit(suc(x)) ≤ lit(zero)) = false
proof
  arbitrary x:Nat
  expand lit
  apply eq_false to not_suc_less_equal_zero[x]
end

// Register the equation as an automatic rewrite rule.
auto lit_suc_less_equal_zero_false
  

// Comparing two successor literals with ≤ reduces to comparing their
// predecessors: (lit(suc(x)) ≤ lit(suc(y))) = (lit(x) ≤ lit(y)).
theorem le_lit_suc: all x:Nat, y:Nat. (lit(suc(x)) ≤ lit(suc(y))) = (lit(x) ≤ lit(y))
proof
  arbitrary x:Nat, y:Nat
  expand lit
  // Turn the iff lemma into an equation and rewrite with it.
  replace apply iff_equal to suc_less_equal_iff_less_equal_suc[x,y].
end
// Register the equation as an automatic rewrite rule.
auto le_lit_suc  

// Comparing two successor literals with < reduces to comparing their
// predecessors: (lit(suc(x)) < lit(suc(y))) = (lit(x) < lit(y)).
theorem less_lit_suc: all x:Nat, y:Nat. (lit(suc(x)) < lit(suc(y))) = (lit(x) < lit(y))
proof
  arbitrary x:Nat, y:Nat
  expand lit
  // Turn the iff lemma into an equation and rewrite with it.
  replace apply iff_equal to less_suc_iff_suc_less[x,y].
end
// Register the equation as an automatic rewrite rule.
auto less_lit_suc  

// The zero literal is less than every successor literal, stated as an
// equation with true. Holds by unfolding lit, < and ≤.
theorem less_lit_zero_suc: all y:Nat. (lit(zero) < lit(suc(y))) = true
proof
  arbitrary y:Nat
  expand lit | operator< | operator≤.
end
// Register the equation as an automatic rewrite rule.
auto less_lit_zero_suc

// Dividing the zero literal by a successor literal gives the zero literal.
theorem lit_zero_div: all x:Nat. lit(zero) / lit(suc(x)) = lit(zero)
proof
  arbitrary x:Nat
  // The divisor suc(x) is positive, as zero_div requires.
  have divisor_pos: zero < suc(x)
    by expand operator< | operator≤.
  have quotient_zero: lit(zero) / lit(suc(x)) = lit(zero) by {
    expand lit
    apply zero_div[suc(x)] to divisor_pos
  }
  quotient_zero
end
// Register the equation as an automatic rewrite rule.
auto lit_zero_div

// Dividing a successor literal by itself gives the literal one.
theorem lit_div_cancel: all y:Nat. lit(suc(y)) / lit(suc(y)) = lit(suc(zero))
proof
  arbitrary y:Nat
  // The divisor suc(y) is positive, as div_cancel requires.
  have divisor_pos: zero < suc(y)
    by expand operator< | operator≤.
  have quotient_one: lit(suc(y)) / lit(suc(y)) = lit(suc(zero)) by {
    expand lit
    apply div_cancel[suc(y)] to divisor_pos
  }
  quotient_one
end
// Register the equation as an automatic rewrite rule.
auto lit_div_cancel

// add_div(a, b, y) is (a + b) / y. The lit_* rules that follow use it to
// divide literals, moving successors between its first two arguments.
// invariant: a ≤ y
fun add_div(a : Nat, b : Nat, y : Nat) {
  (a + b) / y
}

// Entry point for literal division: rewrites lit(x) / lit(y) into the
// add_div form with first argument zero.
// Holds by unfolding lit and add_div.
theorem lit_div: all x:Nat, y:Nat. lit(x) / lit(y) = add_div(zero, x, y)
proof
  arbitrary x:Nat, y:Nat
  expand lit | add_div.
end
// Register the equation as an automatic rewrite rule.
auto lit_div

// When the first argument of add_div equals the (successor) divisor, one
// whole divisor is taken out: the quotient gains lit(suc(zero)) and the
// first argument resets to zero.
theorem lit_add_div: all b:Nat, y:Nat. add_div(suc(y), b, suc(y)) = lit(suc(zero)) + add_div(zero, b, suc(y))
proof
  arbitrary b:Nat, y:Nat
  have pos: zero < suc(y) by expand operator< | operator≤.
  expand lit | add_div | operator+
  replace add_commute[y,b]
  // add_div_one (defined elsewhere) gives the equation for (b + suc(y)) / suc(y);
  // it is then reshaped with add_suc and unfolding of + to match the goal.
  have X: (b + suc(y)) / suc(y) = suc(zero) + b / suc(y) by apply add_div_one[b, suc(y)] to pos
  expand operator+ in replace add_suc in X
end
// Register the equation as an automatic rewrite rule.
auto lit_add_div

// Moves one successor from the second argument of add_div to the first;
// the sum a + b, and hence the quotient, is unchanged.
theorem lit_add_div_suc: all a:Nat, b:Nat, y:Nat. add_div(a, suc(b), y) = add_div(suc(a), b, y)
proof
  arbitrary a:Nat, b:Nat, y:Nat
  expand add_div
  replace add_suc
  expand operator+.
end
// Register the equation as an automatic rewrite rule.
auto lit_add_div_suc

// WARNING: this postulate is not true as stated, so it is an unproved and
// unsound assumption. It would be true if we added the premise a < y,
// but the rewriting system does not yet handle conditional rewriting. -Jeremy
// It is the terminating case of the add_div rewriting: once the second
// argument is zero, the result is taken to be lit(zero).
postulate lit_add_div_zero: all a:Nat, y:Nat. add_div(a, zero, y) = lit(zero)
// Register the equation as an automatic rewrite rule.
auto lit_add_div_zero

// Nothing is less than the zero literal, stated as an equation with false.
// After unfolding lit the goal is closed immediately (less_zero_false is
// already registered as an auto rule earlier in this file).
theorem lit_less_zero_false: all x:Nat. (x < lit(zero)) = false
proof
  arbitrary x:Nat
  expand lit.
end

// Register the equation as an automatic rewrite rule.
auto lit_less_zero_false

// The zero literal is less than or equal to every natural number, stated
// as an equation with true. After unfolding lit the goal is closed
// immediately (zero_less_equal_true is already registered as an auto rule
// earlier in this file).
theorem lit_zero_less_equal_true: all x:Nat. (lit(zero) ≤ x) = true
proof
  arbitrary x:Nat
  expand lit.
end

// Register the equation as an automatic rewrite rule.
auto lit_zero_less_equal_true

// Squaring: n ^ ℕ2 = n * n. Literal-notation version of expt_two.
theorem lit_expt_two : all n : Nat.
  n ^ ℕ2 = n * n
proof
  arbitrary n:Nat
  expand lit
  expt_two
end

// Any power of ℕ1 is ℕ1. Literal-notation version of one_expt.
theorem lit_one_expt : all n :Nat.
  ℕ1 ^ n = ℕ1
proof
  arbitrary n:Nat
  expand lit
  one_expt
end

// Register the equation as an automatic rewrite rule.
auto lit_one_expt