/*
  Represent multisets as functions to Nat: each element is mapped to
  the number of times it occurs.
  */

import Base
import Nat
import Set

// A multiset over T. Its single constructor m_fun wraps a function from
// T to Nat.
opaque union MultiSet<T> {
  m_fun(fn(T) -> Nat)
}

// cnt(M) returns the function wrapped by the m_fun constructor of M, so
// cnt(M)(x) is the Nat that M associates with x.
opaque fun cnt<T>(M : MultiSet<T>) {
  switch M {
    case m_fun(f) { f }
  }
}

// The empty multiset: its count function is constantly 0.
opaque fun m_empty<T>() {
  m_fun(λy:T{0})
}

// The singleton multiset for x: the count is 1 at x and 0 at every
// other element.
opaque fun m_one<T>(x:T) {
  m_fun(λy:T{if x = y then 1 else 0})
}

// Multiset sum: the count of each x in P ⨄ Q is the count in P plus the
// count in Q.
opaque fun operator ⨄ <T>(P: MultiSet<T>, Q: MultiSet<T>) {
  m_fun(λx:T{ cnt(P)(x) + cnt(Q)(x) })
}

// cnt applied to m_fun(f) gives back f. Proved by unfolding cnt.
lemma cnt_m_fun: all T:type, f: fn T -> Nat. cnt(m_fun(f)) = f
proof
  arbitrary T:type, f:(fn T -> Nat)
  expand cnt.
end
         
// Every element has count 0 in the empty multiset.
// Proved by unfolding m_empty and cnt.
theorem cnt_empty: all T:type. all x:T. cnt(@m_empty<T>())(x) = 0
proof
  arbitrary T:type
  arbitrary x:T
  expand m_empty | cnt.
end

// The element x has count 1 in the singleton multiset m_one(x).
// Proved by unfolding m_one and cnt.
theorem cnt_one: all T:type. all x:T. cnt(m_one(x))(x) = 1
proof
  arbitrary T:type
  arbitrary x:T
  expand m_one | cnt.
end

// The count of x in the sum A ⨄ B is the count of x in A plus the count
// of x in B. Proved by unfolding ⨄ and cnt.
theorem cnt_sum: all T:type. all A:MultiSet<T>, B:MultiSet<T>, x:T.
  cnt(A ⨄ B)(x) = cnt(A)(x) + cnt(B)(x)
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>, x:T
  expand operator ⨄ | cnt.
end

// Two multisets with the same count function are equal.
theorem cnt_equal: all T:type, A:MultiSet<T>, B:MultiSet<T>.
  if cnt(A) = cnt(B) then A = B
proof
  arbitrary T:type, A:MultiSet<T>, B:MultiSet<T>
  // Destructure both multisets; `for cnt` unfolds cnt in the goal while
  // switching on A.
  switch A for cnt {
    case m_fun(f) {
      switch B {
        case m_fun(g) {
          // fg is the premise relating the two wrapped functions;
          // rewriting with it closes the goal.
          assume fg
          replace fg.
        }
      }
    }
  }
end

// Testing alternate ASCII notation: `.+.` is written below in place of ⨄.
// The empty multiset is a right identity for multiset sum.
theorem m_sum_empty: all T:type. all A:MultiSet<T>.
  A .+. @m_empty<T>() = A
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>
  expand m_empty
  // Destructure A, unfolding the sum operator and cnt in the goal.
  switch A for operator .+., cnt {
    case m_fun(f) {
      // Pointwise, f(x) + 0 = f(x) by add_zero, so the two functions are
      // equal by extensionality.
      have eq: (λx{(f(x) + 0)} : fn T->Nat) = f
        by extensionality arbitrary x:T
           replace add_zero[f(x)].
      replace eq.
    }
  }
end

// The empty multiset is a left identity for multiset sum.
theorem empty_m_sum: all T:type. all A:MultiSet<T>.
  @m_empty<T>() ⨄ A = A
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>
  expand m_empty
  // Destructure A, unfolding ⨄ in the goal.
  switch A for operator ⨄ {
    case m_fun(f) {
      // Pointwise, 0 + f(x) = f(x); this is discharged by unfolding +
      // (presumably + recurses on its first argument -- confirm in Nat).
      have eq: (λx:T{0 + f(x)}) = f by {
        extensionality arbitrary x:T
        expand operator +.
      }
      expand cnt
      replace eq.
    }
  }
end


// Multiset sum is commutative.
theorem m_sum_commutes: all T:type. all A:MultiSet<T>, B:MultiSet<T>.
  A ⨄ B = B ⨄ A
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>
  // Destructure both multisets, unfolding ⨄ in the goal.
  switch A for operator ⨄ {
    case m_fun(f) {
      switch B {
        case m_fun(g) {
          // After unfolding cnt, the goal is an equation between two
          // m_fun terms whose bodies differ only in the order of the
          // addends.
          suffices m_fun(λx:T{f(x) + g(x)})
                 = m_fun(λx:T{g(x) + f(x)})
              by expand cnt.
          // The wrapped functions agree pointwise by add_commute.
          have eq: (λx:T{(f(x) + g(x))}) = λx:T{(g(x) + f(x))}
            by extensionality arbitrary x:T
               replace add_commute[f(x)][g(x)].
          replace eq.
        }
      }
    }
  }
end

// Multiset sum is associative.
theorem m_sum_assoc:
  all T:type. all A:MultiSet<T>, B:MultiSet<T>, C:MultiSet<T>.
  (A ⨄ B) ⨄ C = A ⨄ (B ⨄ C)
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>, C:MultiSet<T>
  // Destructure all three multisets, unfolding ⨄ in the goal.
  switch A for operator ⨄ {
    case m_fun(f) {
      switch B {
        case m_fun(g) {
          switch C {
            case m_fun(h) {
              // After unfolding cnt, the goal is an equation between two
              // m_fun terms that differ only in how the additions are
              // parenthesized.
              suffices m_fun(λx:T{(f(x) + g(x)) + h(x)})
                     = m_fun(λx:T{f(x) + (g(x) + h(x))})
                  by expand cnt.
              // The remaining goal is closed by the checker with no
              // further proof steps.
              .
            }
          }
        }
      }
    }
  }
end

// set_of_mset(M) is the set built by set_of_pred from the predicate that
// is false when cnt(M)(x) = 0 and true otherwise.
define set_of_mset : fn<T> MultiSet<T> -> Set<T>
  = generic T {λM{set_of_pred(λx{ if cnt(M)(x) = 0 then false else true })}}

// The set of the empty multiset is the empty set.
theorem som_empty: all T:type.
  set_of_mset( @m_empty<T>() ) = @empty_set<T>()
proof
  arbitrary T:type
  expand set_of_mset
  // Name the left-hand set so the membership argument below stays readable.
  define A = set_of_pred(fun x:T { (if cnt(@m_empty<T>())(x) = 0 then false else true) })
  // Set equality is reduced to equivalence of membership via member_equal.
  suffices (all x:T. x ∈ A ⇔ x ∈ ∅) by member_equal<T>[A,∅]
  arbitrary x:T
  // Forward: membership in A unfolds (member_set_of_pred) to a condition
  // that evaluates to false, so the premise is contradictory.
  have fwd: if x ∈ A then x ∈ empty_set() by {
    expand A
    assume prem
    conclude false by evaluate in apply member_set_of_pred to prem
  }
  // Backward: empty_no_members turns the premise into false.
  have bkwd: if x ∈ empty_set() then x ∈ A by {
    assume prem
    conclude false by apply empty_no_members to prem
  }
  fwd, bkwd
end

  
// The set of a singleton multiset m_one(x) is the singleton set single(x).
theorem som_one_single: all T:type. all x:T.
  set_of_mset(m_one(x)) = single(x)
proof
  arbitrary T:type
  arbitrary z:T
  // Key step: the "count is nonzero" predicate of the singleton multiset
  // coincides with the predicate "equals z".
  have eq: λx:T{(if cnt(m_fun(λy:T{(if z = y then 1 else 0)}))(x) = 0
               then false else true)}
           = λy{z = y}
    by extensionality
       arbitrary y:T
       // Case split on whether y = z, unfolding cnt in the goal.
       switch y = z for cnt {
         case true assume yz_holds {
           have zy_eq: z = y
             by symmetric replace yz_holds.
           replace zy_eq.
         }
         case false assume yz_fails {
           // z = y would contradict the case assumption.
           have zy_neq: not (z = y)
             by assume zy
                replace yz_fails in symmetric zy
           replace (apply eq_false to zy_neq).
         }
       }
  expand m_one | set_of_mset
  replace eq | single_pred<T>[z].
end

// The set of a multiset sum is the union of the sets of its summands.
theorem som_union: all T:type. all A:MultiSet<T>, B:MultiSet<T>.
  set_of_mset(A ⨄ B) = set_of_mset(A) ∪ set_of_mset(B)
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>
  expand set_of_mset | operator ⨄
  // Rewrite the union of two set_of_pred sets using union_set_of_pred.
  replace union_set_of_pred<T>
  // It is enough to show the two underlying predicates are equal.
  suffices (fun x:T { (if cnt(m_fun(fun y:T { cnt(A)(y) + cnt(B)(y) }))(x) = 0 then false else true) })
    = (fun x:T { ((if cnt(A)(x) = 0 then false else true) or (if cnt(B)(x) = 0 then false else true)) })
    by set_of_pred_equal_bkwd<T>
  extensionality
  arbitrary x:T
  switch A {
    case m_fun(f) {
      switch B {
        case m_fun(g) {
          expand cnt
          // Case split on whether f(x) is zero or a successor; each case
          // is closed by rewriting with the case equation and evaluating.
          switch f(x) {
            case 0 assume fxz {
              replace fxz
              evaluate
            }
            case suc(fx') assume fxs {
              replace fxs
              evaluate
            }
          }
        }
      }
    }
  }
end