module Nat

import Option
import Base
import NatDefs
import NatAdd
import NatMonus

/*
 Properties of Multiplication
*/

/*
  Zero is a left annihilator for multiplication: zero * n = zero.
  The goal is closed directly by unfolding the definition of operator*.
*/
lemma zero_mult: all n:Nat.
  zero * n = zero
proof
  arbitrary n:Nat
  expand operator*.
end

/* Register zero_mult as an automatically applied rule. */
auto zero_mult

/*
  Variant of zero_mult in which every operand is wrapped in `lit`.
  The proof only unfolds `lit`; the remaining goal is presumably
  discharged by the auto rule zero_mult.
*/
theorem nat_zero_mult: all y:Nat.
  lit(zero) * lit(y) = lit(zero)
proof
  arbitrary y:Nat
  expand lit.
end

/* Register nat_zero_mult as an automatically applied rule. */
auto nat_zero_mult

/*
  Zero is a right annihilator for multiplication: n * zero = zero.
  Proved by induction on n.
*/
lemma mult_zero: all n:Nat.
  n * zero = zero
proof
  induction Nat
  case zero {
    /* Base case: zero * zero = zero. */
    conclude zero * zero = zero   by .
  }
  case suc(n') suppose IH {
    /* Step: unfold one layer of operator*, then rewrite n' * zero
       using the induction hypothesis IH. */
    equations  suc(n') * zero = zero + n' * zero   by expand operator*.
                          ... = zero               by IH
  }
end

/* Register mult_zero as an automatically applied rule. */
auto mult_zero

/*
  Unfolding equation for a successor in the left operand:
  suc(m) * n = n + m * n.
  Follows directly from the definition of operator*.
*/
lemma suc_mult: all m:Nat, n:Nat.
  suc(m) * n = n + m * n
proof
  arbitrary m:Nat, n:Nat
  expand operator*.
end

/*
  Variant of suc_mult in which every operand is wrapped in `lit`.
  Proved by unfolding `lit` and operator*.
*/
theorem lit_suc_mult: all m:Nat, n:Nat.
  lit(suc(m)) * lit(n) = lit(n) + lit(m) * lit(n)
proof
  arbitrary m:Nat, n:Nat
  expand lit | operator*.
end

/* Register lit_suc_mult as an automatically applied rule. */
auto lit_suc_mult


/*
  Unfolding equation for a successor in the right operand:
  m * suc(n) = m + m * n.
  Proved by induction on m, with n quantified inside the induction
  so that the hypothesis IH is available for every n.
*/
lemma mult_suc: all m:Nat. all n:Nat.
  m * suc(n) = m + m * n
proof
  induction Nat
  case zero {
    /* Base case: both sides start with zero * _. */
    arbitrary n:Nat
    conclude zero * suc(n) = zero + zero * n    by .
  }
  case suc(m') suppose IH: all n:Nat. m' * suc(n) = m' + m' * n {
    arbitrary n:Nat
    /* Unfold operator* and operator+ so that both sides are headed by suc;
       it then suffices to prove the equation stated below. */
    suffices suc(n + m' * suc(n)) = suc(m' + (n + m' * n))
        by expand operator* | 2*operator+.
    /* Rewrite m' * suc(n) with IH, then swap n and m' using add_commute. */
    equations   suc(n + m' * suc(n))
              = suc(n + m' + m' * n)       by replace IH.
          ... = suc(m' + n + m' * n)       by replace add_commute[n][m'].
  }
end

/*
  Variant of mult_suc in which every operand is wrapped in `lit`.
  After unfolding `lit`, the goal is proved by mult_suc.
*/
theorem lit_mult_lit_suc: all m:Nat, n:Nat.
  lit(m) * lit(suc(n)) = lit(m) + lit(m) * lit(n)
proof
  arbitrary m:Nat, n:Nat
  expand lit
  mult_suc
end

/* Register lit_mult_lit_suc as an automatically applied rule. */
auto lit_mult_lit_suc

/*
  Variant of mult_suc in which only the left operand is wrapped in `lit`
  and the right operand is a bare suc(n).
  After unfolding `lit`, the goal is proved by mult_suc.
*/
theorem lit_mult_suc: all m:Nat, n:Nat.
  lit(m) * suc(n) = lit(m) + lit(m) * n
proof
  arbitrary m:Nat, n:Nat
  expand lit
  mult_suc
end

/* Register lit_mult_suc as an automatically applied rule. */
auto lit_mult_suc

/*
  Multiplication is commutative: m * n = n * m.
  Proved by induction on m, with n quantified inside the induction.
*/
theorem mult_commute: all m:Nat. all n:Nat.
  m * n = n * m
proof
  induction Nat
  case zero {
    /* Base case: both zero * n and n * zero equal zero. */
    arbitrary n:Nat
    equations    zero * n = zero          by .
                   ... = n * zero         by .
  }
  case suc(m') suppose IH: all n:Nat. m' * n = n * m' {
    arbitrary n:Nat
    /* Unfold the left-hand side, commute the inner product with IH,
       then fold back using mult_suc read right-to-left. */
    equations    suc(m') * n
               = n + m' * n     by expand operator*.
           ... = n + (n * m')   by replace IH.
           ... = n * suc(m')    by symmetric mult_suc[n][m']
  }
end

/*
  One (suc(zero)) is a left identity for multiplication.
  Unfolds the product with suc_mult, then simplifies zero * n and n + zero.
*/
lemma one_mult: all n:Nat.
  suc(zero) * n = n
proof
  arbitrary n:Nat
  equations     suc(zero) * n = n + zero * n    by suc_mult[zero,n]
                  ... = n + zero        by .
                  ... = n            by .
end

/* Register one_mult as an automatically applied rule. */
auto one_mult

/*
  Variant of one_mult with the left operand written as lit(suc(zero)).
  The proof only unfolds `lit`; the remaining goal is presumably
  discharged by the auto rule one_mult.
*/
theorem nat_one_mult: all n:Nat.
  lit(suc(zero)) * n = n
proof
  arbitrary n:Nat
  expand lit.
end

/* Register nat_one_mult as an automatically applied rule. */
auto nat_one_mult

/*
  One (suc(zero)) is a right identity for multiplication.
  Obtained from one_mult by commuting the product with mult_commute.
*/
lemma mult_one: all n:Nat.
  n * suc(zero) = n
proof
  arbitrary n:Nat
  equations    n * suc(zero) = suc(zero) * n    by mult_commute[n][suc(zero)]
                 ... = n        by one_mult[n]
end

/* Register mult_one as an automatically applied rule. */
auto mult_one

/*
  Variant of mult_one with the right operand written as lit(suc(zero)).
  The proof only unfolds `lit`; the remaining goal is presumably
  discharged by the auto rule mult_one.
*/
theorem nat_mult_one: all n:Nat.
  n * lit(suc(zero)) = n
proof
  arbitrary n:Nat
  expand lit.
end

/* Register nat_mult_one as an automatically applied rule. */
auto nat_mult_one

/*
  A product is zero only if one of its factors is zero.
  Proved by case analysis on n and then on m; when both are successors
  the goal is settled by evaluation.
*/
lemma mult_to_zero : all n :Nat, m : Nat.
  if m * n = zero then m = zero or n = zero
proof
  arbitrary n : Nat, m :Nat
  switch n {
    case zero { . }
    case suc(n') suppose n_is_suc {
      switch m {
        case zero { . }
        case suc(m') suppose m_is_suc {
          evaluate
        }
      }
    }
  }
end

/*
  If a product of two naturals is one, then both factors are one.

  Proof outline: case analysis on x and y. The cases where either is
  zero are closed immediately. When x = suc(x') and y = suc(y'), the
  premise is unfolded to suc(y' + x' * suc(y')) = suc(zero); stripping
  suc gives a sum equal to zero, from which add_to_zero yields that
  both summands are zero, and hence x' = zero and y' = zero.
*/
lemma one_mult_one : all x : Nat, y : Nat.
  if x * y = suc(zero) then x = suc(zero) and y = suc(zero)
proof
  arbitrary x : Nat, y : Nat
  switch x {
    case zero {
      .
    }
    /* x_is_suc is the case equation from the switch; it is not an
       induction hypothesis, since this proof uses no induction. */
    case suc(x') suppose x_is_suc {
      switch y {
        case zero {
          .
        }
        case suc(y') {
          /* Unfold operator* and operator+ so the product is headed by suc. */
          suffices (if suc(y' + x' * suc(y')) = suc(zero) then (suc(x') = suc(zero) and suc(y') = suc(zero)))
            by expand operator* | operator+ .
          suppose prem
          /* Strip suc from both sides of the premise. */
          have sum_zero : y' + x' * suc(y') = zero by { injective suc prem }
          /* A sum is zero only when both summands are zero. */
          have y'_zero : y' = zero by apply add_to_zero to sum_zero
          have x'sy'_zero : x' * suc(y') = zero by apply add_to_zero to sum_zero
          have x'_zero : x' = zero by {
            switch x' {
              case zero { . }
              /* If x' were a successor, x' * suc(y') would unfold to a
                 successor, contradicting x'sy'_zero. */
              case suc(n') suppose x'_is_suc {
                expand operator* | operator+ in replace x'_is_suc in x'sy'_zero
              }
            }
          }
          replace x'_zero | y'_zero.
        }
      }
    }
  }
end

/*
  Multiplying by two (suc(suc(zero))) on the left is doubling:
  suc(suc(zero)) * n = n + n.
  Unfolds one layer with suc_mult; the remaining suc(zero) * n
  is then simplified to n.
*/
lemma two_mult: all n:Nat.
  suc(suc(zero)) * n = n + n
proof
  arbitrary n:Nat
  equations   suc(suc(zero)) * n = n + suc(zero) * n   by suc_mult[suc(zero),n]
                ... = n + n           by .
end

/*
  Multiplication distributes over addition on the left:
  a * (x + y) = a * x + a * y.
  Proved by induction on a.
*/
theorem dist_mult_add:
  all a:Nat, x:Nat, y:Nat.
  a * (x + y) = a * x + a * y
proof
  induction Nat
  case zero {
    /* Base case: every product with zero on the left is zero. */
    arbitrary x:Nat, y:Nat
    equations   zero * (x + y)
              = zero                        by .
          ... = zero + zero                 by .
          ... =# zero * x + zero * y #      by .
  }
  case suc(a') suppose IH {
    arbitrary x:Nat, y:Nat
    /* Unfold operator* on both sides of the goal. */
    suffices (x + y) + a' * (x + y) = (x + a' * x) + (y + a' * y)
      by expand operator *.
    /* Distribute a' with IH, then swap y and a'*x using add_commute. */
    equations
            (x + y) + a' * (x + y)
          = x + y + a'*x + a'*y           by replace IH.
      ... = x + a'*x + y + a'*y           by replace add_commute[y, a'*x].
  }
end

/*
  Multiplication distributes over addition on the right:
  (x + y) * a = x * a + y * a.
  Obtained from dist_mult_add by commuting the products with mult_commute.
*/
theorem dist_mult_add_right:
  all x:Nat, y:Nat, a:Nat.
  (x + y) * a = x * a + y * a
proof
  arbitrary x:Nat, y:Nat, a:Nat
  equations
  (x + y) * a = a * (x + y)         by replace mult_commute.
          ... = a * x + a * y       by dist_mult_add[a][x,y]
          ... = x * a + y * a       by replace mult_commute.
end
  
/*
  Multiplication is associative: (m * n) * o = m * (n * o).
  Proved by induction on m, with n and o quantified inside the induction.
*/
theorem mult_assoc: all m:Nat. all n:Nat, o:Nat.
  (m * n) * o = m * (n * o)
proof
  induction Nat
  case zero {
    /* Base case: both sides reduce to zero. */
    arbitrary n:Nat, o:Nat
    equations   (zero * n) * o = zero * o         by .
                        ... = zero             by zero_mult[o]
                        ... = zero * (n * o)   by .
  }
  case suc(m') suppose IH: all n:Nat, o:Nat. (m' * n) * o = m' * (n * o) {
    arbitrary n:Nat, o:Nat
    /* Unfold suc(m') * n, distribute the outer product over the sum
       with dist_mult_add_right, reassociate the inner product with IH,
       and fold the result back into suc(m') * (n * o). */
    equations
          (suc(m') * n) * o
        = (n + m' * n) * o          by expand operator*.
    ... = n * o + (m' * n) * o      by dist_mult_add_right[n, m'*n, o]
    ... = n * o + m' * (n * o)      by replace IH.
    ... =# suc(m') * (n * o) #      by expand operator*.
  }
end

/* Declare operator* associative on Nat (justified by mult_assoc above). */
associative operator* in Nat

/*
  Right cancellation of a nonzero factor:
  if m * suc(o) = n * suc(o) then m = n.

  Proved by induction on m with a case split on n. The mixed
  zero/successor cases are refuted by evaluating the premise. In the
  successor/successor case the premise is unfolded, the common summand
  suc(o) is cancelled with left_cancel, and the induction hypothesis
  is applied to what remains.
*/
lemma mult_right_cancel : all m : Nat, n : Nat, o : Nat.
  if m * suc(o) = n * suc(o) then m = n
proof
  induction Nat
  case zero {
    arbitrary n:Nat, o:Nat
    switch n {
      case zero { . }
      case suc(n') {
        suppose zero_eq_suc: zero * suc(o) = suc(n') * suc(o)
        conclude false by evaluate in zero_eq_suc
      }
    }
  }
  case suc(m') suppose IH {
    arbitrary n:Nat, o:Nat
    switch n {
      case zero {
        suppose suc_eq_zero: suc(m') * suc(o) = zero * suc(o)
        conclude false by evaluate in suc_eq_zero
      }
      case suc(n') {
        suppose prods_eq: suc(m') * suc(o) = suc(n') * suc(o)
        /* Unfold one layer of operator* on each side. */
        have sums_eq: suc(o) + m' * suc(o) = suc(o) + n' * suc(o)
          by expand operator* in prods_eq
        /* Cancel the shared leading summand suc(o). */
        have tails_eq: m' * suc(o) = n' * suc(o)
          by apply left_cancel to sums_eq
        have preds_eq: m' = n' by apply IH to tails_eq
        conclude suc(m') = suc(n') by replace preds_eq.
      }
    }
  }
end

/*
  Left cancellation of a nonzero factor:
  if suc(m) * a = suc(m) * b then a = b.
  Commutes both products with mult_commute and then appeals to
  mult_right_cancel.
*/
lemma mult_left_cancel : all m : Nat, a : Nat, b : Nat.
  if suc(m) * a = suc(m) * b then a = b
proof
  arbitrary m:Nat, a:Nat, b:Nat
  suppose left_eq: suc(m) * a = suc(m) * b
  have right_eq: a * suc(m) = b * suc(m)  by replace mult_commute in left_eq
  conclude a = b by apply mult_right_cancel to right_eq
end