module Int

import UInt
import Base

public import IntDefs
public import IntAddSub

// Properties of Multiplication

// Product of two `pos` integers: the result is `pos` of the product of
// the underlying UInt values.
lemma mult_pos_pos: all x:UInt, y:UInt. pos(x) * pos(y) = pos(x * y)
proof
  arbitrary m:UInt, n:UInt
  // The equation is discharged purely by computation.
  evaluate
end

// Product of a `pos` integer with a `negsuc` integer, stated as the
// negation of a `pos` value: pos(x) * negsuc(y) = - pos(x + x * y).
lemma mult_pos_negsuc: all x:UInt, y:UInt.
  pos(x) * negsuc(y) = - pos(x + x * y)
proof
  arbitrary x:UInt, y:UInt
  // The #...# marks restrict the following steps to the left-hand side.
  show #pos(x) * negsuc(y)# = - pos(x + x * y)
  // Unfold Int multiplication, then the sign/abs helpers it is built from.
  expand operator*
  expand sign | abs
  // Finish by rewriting with the sign-multiplication lemmas and UInt
  // distributivity.
  replace mult_pos_neg | mult_neg_uint | uint_dist_mult_add.
end

// +1 is a left identity for multiplication on Int.
// Proved by splitting on the constructor of the argument and rewriting
// with the matching product lemma (mult_pos_pos / mult_pos_negsuc).
theorem int_one_mult: all x:Int. +1 * x = x
proof
  arbitrary i:Int
  switch i {
    case pos(n) {
      replace mult_pos_pos.
    }
    case negsuc(n) {
      replace mult_pos_negsuc | neg_pos.
    }
  }
end

// Register the theorem with `auto`.
auto int_one_mult

// +0 is a left annihilator for multiplication on Int.
// Both constructor cases are closed by unfolding the definitions involved.
theorem int_zero_mult: all x:Int. +0 * x = +0
proof
  arbitrary i:Int
  switch i {
    case pos(n) {
      expand operator* | sign | abs | 2*operator*.
    }
    case negsuc(n) {
      expand operator* | sign | abs | 2*operator* | operator-.
    }
  }
end

// Register the theorem with `auto`.
auto int_zero_mult

// For a non-zero integer x, negating x flips its sign:
// sign(- x) = flip(sign(x)).
lemma sign_neg_flip: all x:Int.
  if not (x = +0)
  then sign(- x) = flip(sign(x))
proof
  arbitrary x:Int
  switch x {
    case pos(x') {
      // prem is the hypothesis that pos(x') is not +0.
      assume prem
      expand operator-
      // Derive x' != 0 from prem: assuming x' = 0 and rewriting prem with it
      // yields a contradiction.
      have xnz: not (x' = 0) by {
        assume xz
        conclude false by replace xz in prem
      }
      // Rewrite the test (x' = 0) to false, then unfold sign and flip.
      replace (apply eq_false to xnz)
      expand sign | flip.
    }
    case negsuc(x') {
      // No use of the premise is needed here; unfolding suffices.
      expand operator- | sign | flip.
    }
  }
end

// Negation distributes into the left factor of a product:
// -(x * y) = (-x) * y.
// The proof switches on the constructors (pos / negsuc) of x and y; inside
// the pos cases it additionally splits on whether the UInt payload is zero
// or positive via uint_zero_or_positive.
theorem dist_neg_mult: all x:Int,y:Int.
  -(x * y) = (-x) * y
proof
  arbitrary x:Int, y:Int
  switch x {
    case pos(x') {
      switch y {
        case pos(y') {
          expand operator*
          replace sign_pos
          replace mult_pos_any | abs_pos | mult_pos_uint
          cases uint_zero_or_positive[x']
          case xz {
            // x' = 0
            replace xz | neg_zero
            expand sign | abs
            replace mult_pos_any | mult_pos_uint.
          }
          case xp_pos {
            // x' is positive, hence non-zero.
            have xnz: not (x' = 0) by apply uint_pos_not_zero to xp_pos
            expand operator-
            replace (apply eq_false to xnz)
            expand abs | sign
            replace mult_neg_pos | mult_neg_uint
            cases uint_zero_or_positive[y']
            case yz {
              // y' = 0
              replace yz
              expand operator-.
            }
            case y_pos {
              have ynz: not (y' = 0) by apply uint_pos_not_zero to y_pos
              // x' * y' is non-zero because neither factor is zero.
              have nxyz: not (x' * y' = 0) by {
                assume xyz
                cases apply uint_mult_to_zero[x',y'] to xyz
                case xz {
                  conclude false by apply xnz to xz
                }
                case yz {
                  conclude false by apply ynz to yz
                }
              }
              replace (apply eq_false to nxyz)
              // 1 ≤ x' lets uint_monus_add_identity simplify the monus-by-one
              // term involving x'.
              have one_x: 1 ≤ x' by apply uint_pos_implies_one_le to xp_pos
              replace (apply uint_monus_add_identity to one_x)
              replace (symmetric neg_pos[(x' * y' ∸ 1)])
              // 1 ≤ x' * y': write y' = 1 + y'', distribute, and chain
              // 1 ≤ x' with x' ≤ x' + x' * y''.
              have one_xy: 1 ≤ x' * y' by {
                obtain y'' where yp: y' = 1 + y'' from apply uint_positive_add_one to y_pos
                replace yp | uint_dist_mult_add
                have x_xy: x' ≤ x' + x' * y'' by uint_less_equal_add
                apply uint_less_equal_trans to one_x, x_xy
              }
              replace (apply uint_monus_add_identity to one_xy).
            }
          }
        }
        case negsuc(y') {
          expand operator*
          replace sign_pos | sign_negsuc | mult_pos_any | mult_neg_uint | abs_negsuc | abs_pos
          | neg_involutive | abs_neg | abs_pos
          cases uint_zero_or_positive[x']
          case xz {
            // x' = 0
            replace xz
            evaluate
          }
          case xp_pos {
            have xnz: not (x' = 0) by apply uint_pos_not_zero to xp_pos
            expand operator-
            replace (apply eq_false to xnz) | sign_negsuc | mult_neg_neg | mult_pos_uint.
          }
        }
      }
    }
    case negsuc(x') {
      switch y {
        case pos(y') {
          expand operator*
          replace sign_negsuc | sign_pos | abs_negsuc | abs_pos | mult_neg_pos | mult_neg_uint
          replace sign_neg_negsuc | mult_pos_any | abs_neg | abs_negsuc | mult_pos_uint | neg_involutive.
        }
        case negsuc(y') {
          expand operator*
          replace sign_negsuc | abs_negsuc | sign_neg_negsuc | mult_pos_neg | abs_neg | abs_negsuc | mult_neg_uint
            | mult_neg_neg | mult_pos_uint.
        }
      }
    }
  }
end

// Multiplication on Int is commutative.
// Each of the four constructor combinations is first reduced by `evaluate`
// to an equation between UInt products, which is then closed with
// uint_mult_commute at the appropriate instance.
theorem int_mult_commute: all x:Int, y:Int. x * y = y * x
proof
  arbitrary a:Int, b:Int
  switch a {
    case pos(m) {
      switch b {
        case pos(n) {
          suffices pos(m * n) = pos(n * m)  by evaluate
          replace uint_mult_commute[m,n].
        }
        case negsuc(n) {
          suffices - (m * (1 + n)) = - ((1 + n) * m)  by evaluate
          replace uint_mult_commute[m, 1 + n].
        }
      }
    }
    case negsuc(m) {
      switch b {
        case pos(n) {
          suffices - ((1 + m) * n) = - (n * (1 + m))  by evaluate
          replace uint_mult_commute[1 + m, n].
        }
        case negsuc(n) {
          suffices pos((1 + m) * (1 + n)) = pos((1 + n) * (1 + m)) by evaluate
          replace uint_mult_commute[(1 + m), (1 + n)].
        }
      }
    }
  }
end

// +1 is a right identity for multiplication on Int.
// Obtained from commutativity instantiated at the argument and +1.
theorem int_mult_one: all x:Int. x * +1 = x
proof
  arbitrary i:Int
  conclude i * +1 = i    by int_mult_commute[i,+1]
end

// Register the theorem with `auto`.
auto int_mult_one

// +0 is a right annihilator for multiplication on Int.
// Obtained from commutativity instantiated at the argument and +0.
theorem int_mult_zero: all x:Int. x * +0 = +0
proof
  arbitrary i:Int
  conclude i * +0 = +0  by int_mult_commute[i,+0]
end

// Register the theorem with `auto`.
auto int_mult_zero

// UInt identity used by int_mult_assoc in its negsuc/negsuc/negsuc case.
// Both sides are shown equal by an equational chain: distribute the
// products, reorder summands with uint_add_commute, then fold the result
// back into the right-hand side.
lemma mult_assoc_helper: all x:UInt, y:UInt, z:UInt.
  z + (y + x * (1 + y)) * (1 + z)
  = (z + y * (1 + z)) + x * (1 + z + y * (1 + z))
proof
  arbitrary x:UInt, y:UInt, z:UInt
  equations
      z + (y + x * (1 + y)) * (1 + z)
    = z + (y + x + x * y) * (1 + z)
      by replace uint_dist_mult_add[x].
... = z + y + x + x * y + (y + x + x * y) * z
      by replace uint_dist_mult_add.
... = z + y + x + x * y + y*z + x*z + x*y*z
      by replace uint_dist_mult_add_right.
... = z + y + x + y*z + x*y + x*z + x*y*z
      by replace uint_add_commute[x*y, y*z].
... = z + y + x + y*z + x*z + x*y + x*y*z
      by replace uint_add_commute[x*y].
... = z + y + y*z + x + x*z + x*y + x*y*z
      by replace uint_add_commute[x].
// The #...# marks restrict the final rewrite to the right-hand side.
... = #(z + y * (1 + z)) + x * (1 + z + y * (1 + z))#
      by replace uint_dist_mult_add.
end

// Multiplication on Int is associative.
// The proof switches on the constructor (pos / negsuc) of x, then y, then z.
// In several pos branches the UInt payload is split into zero / positive via
// uint_zero_or_positive before rewriting. The all-negsuc case is reduced by
// `evaluate` to the UInt identity mult_assoc_helper.
theorem int_mult_assoc: all x:Int, y:Int, z:Int. (x * y) * z = x * (y * z)
proof
  arbitrary x:Int, y:Int, z:Int
  switch x {
    case pos(x') {
      cases uint_zero_or_positive[x']
      case xpz: x' = 0 {
        // After substituting x' = 0 the goal closes with no further steps
        // (presumably via the auto rules for multiplying by +0 — confirm).
        replace xpz.
      }
      case xp_pos: 0 < x' {
        // NOTE(review): xnz is not referenced by name later in this proof.
        have xnz: not (x' = 0) by apply uint_pos_not_zero to xp_pos
        // Write x' as a successor; xp_eq is used in the negsuc(y') branches.
        obtain x'' where xp_eq: x' = 1 + x'' from apply uint_positive_add_one to xp_pos
        switch y {
          case pos(y') {
            cases uint_zero_or_positive[y']
            case ypz: y' = 0 {
              replace ypz.
            }
            case yp_pos: 0 < y' {
              // NOTE(review): ynz is not referenced by name later in this proof.
              have ynz: not (y' = 0) by apply uint_pos_not_zero to yp_pos
              switch z {
                case pos(z') {
                  replace mult_pos_pos | mult_pos_pos.
                }
                case negsuc(z') {
                  replace mult_pos_pos
                  replace mult_pos_negsuc
                  // Commute the product so dist_neg_mult (read right-to-left)
                  // can pull the negation outside.
                  replace int_mult_commute[pos(x'), -pos(y'+y'*z')] | symmetric dist_neg_mult[pos(y'+y'*z'), pos(x')]
                  replace mult_pos_pos | uint_dist_mult_add_right | uint_mult_commute[y',x']
                  | uint_mult_commute[y'*z',x'].
                }
              }
            }
          }
          case negsuc(y') {
            switch z {
              case pos(z') {
                expand operator*
                replace sign_pos | sign_negsuc | abs_pos | abs_negsuc | mult_pos_neg | mult_neg_uint
                replace mult_neg_pos | mult_neg_uint | abs_neg | abs_pos | mult_any_pos | mult_pos_any
                replace uint_dist_mult_add | uint_dist_mult_add_right
                replace xp_eq | sign_neg_pos
                
                cases uint_zero_or_positive[z']
                case zpz: z' = 0 {
                  replace zpz
                  expand operator- | sign | 2*operator* | operator-.
                }
                case zp_pos: 0 < z' {
                  // NOTE(review): znz is not referenced by name later in this case.
                  have znz: not (z' = 0) by apply uint_pos_not_zero to zp_pos
                  obtain z'' where zp_eq: z' = 1 + z'' from apply uint_positive_add_one to zp_pos
                  replace zp_eq
                  replace sign_neg_pos.
                }
              }
              case negsuc(z') {
                expand operator*
                replace sign_pos | sign_negsuc | abs_pos | abs_negsuc
                replace mult_pos_neg | mult_neg_uint | mult_pos_any | mult_neg_neg | mult_pos_uint | abs_pos | sign_pos
                replace mult_pos_uint | uint_dist_mult_add | uint_dist_mult_add
                replace xp_eq | sign_neg_pos | mult_neg_neg | abs_neg | abs_pos | mult_pos_uint.
              }
            }
          }
        }
      }
    }
    case negsuc(x') {
      switch y {
        case pos(y') {
          switch z {
            case pos(z') {
              expand operator*
              replace sign_negsuc | sign_pos | abs_negsuc | abs_pos | mult_neg_pos | mult_neg_uint | mult_pos_any
              | mult_pos_uint | uint_dist_mult_add_right | sign_pos | mult_neg_pos | abs_pos
              | mult_neg_uint
              cases uint_zero_or_positive[y']
              case ypz: y' = 0 {
                replace ypz
                expand operator- | abs | sign | 2* operator*.
              }
              case yp_pos: 0 < y' {
                obtain y'' where yp_eq: y' = 1 + y'' from apply uint_positive_add_one to yp_pos
                replace yp_eq
                replace sign_neg_pos | mult_neg_pos | abs_neg | abs_pos | mult_neg_uint | uint_dist_mult_add_right
                | uint_dist_mult_add | uint_dist_mult_add_right | uint_dist_mult_add_right.
              }
            }
            case negsuc(z') {
              expand operator*
              replace sign_negsuc | sign_pos | abs_negsuc | abs_pos | mult_neg_pos | mult_neg_uint | mult_pos_neg
              | mult_neg_uint | uint_dist_mult_add_right | uint_dist_mult_add
              cases uint_zero_or_positive[y']
              case ypz: y' = 0 {
                replace ypz
                expand operator-
                expand sign | abs
                evaluate
              }
              case yp_pos: 0 < y' {
                obtain y'' where yp_eq: y' = 1 + y'' from apply uint_positive_add_one to yp_pos
                replace yp_eq
                replace sign_neg_pos | abs_neg | abs_pos | mult_neg_neg | uint_dist_mult_add
                replace uint_dist_mult_add_right | mult_pos_uint
                // Reorder summands so both sides match syntactically.
                replace uint_add_commute[x'+x'*y'',z']
                replace uint_add_commute[x'+x'*y'',y''*z'].
              }
            }
          }
            }
            case negsuc(y') {
              switch z {
                case pos(z') {
                  cases uint_zero_or_positive[z']
                  case zpz: z' = 0 {
                    replace zpz
                    evaluate
                  }
                  case zp_pos: 0 < z' {
                    // NOTE(review): znz is not referenced by name later in this case.
                    have znz: not (z' = 0) by apply uint_pos_not_zero to zp_pos
                    obtain z'' where zp_eq: z' = 1 + z'' from apply uint_positive_add_one to zp_pos
                    replace zp_eq
                    expand operator*
                    replace sign_negsuc | abs_negsuc | mult_neg_neg | sign_pos | mult_pos_uint | mult_neg_pos | abs_pos
                    | mult_neg_uint | sign_pos | mult_pos_any | uint_dist_mult_add
                    | uint_dist_mult_add_right
                    | uint_dist_mult_add_right
                    | uint_dist_mult_add
                    | sign_neg_pos | mult_neg_neg | mult_pos_uint | abs_neg | abs_pos
                    | uint_add_commute[z'' + x' + x' * z'', y']
                    | uint_add_commute[x' + x' * z'', y' * z'']
                    | uint_dist_mult_add
                    | uint_add_commute[x' * z'', x' * y'].
                  }
                }
                case negsuc(z'') {
                  // All three factors are negsuc: evaluation reduces the goal
                  // to an equation between negsuc terms whose UInt payloads
                  // are exactly the two sides of mult_assoc_helper.
                  suffices negsuc(z'' + (y' + x' * (1 + y')) * (1 + z''))
                         = negsuc((z'' + y' * (1 + z'')) + x' * (1 + z'' + y' * (1 + z'')))
                    by evaluate
                  replace mult_assoc_helper[x',y',z''].
                }
              }
            }
          }      
    }
  }
end

// Declare operator* on Int as associative to the proof checker.
// NOTE(review): presumably justified by int_mult_assoc above — confirm
// against the Deduce reference for the `associative` declaration.
associative operator* in Int