/*
  Represent multisets as functions to Nat:
  each element is mapped to the number of times it occurs.
  */

import Base
import Nat
import Set

// A multiset over T, represented by a single constructor m_fun that
// wraps a counting function from T to Nat.
union MultiSet<T> {
  m_fun(fn T -> Nat)
}

// Unwrap a multiset to its underlying counting function, so that
// cnt(M)(x) is the Nat that M associates with x.
fun cnt<T>(M : MultiSet<T>) {
  switch M {
    case m_fun(g) {
      g
    }
  }
}

// The singleton multiset: x is counted once and every other element
// is counted zero times.
fun m_one<T>(x:T) {
  m_fun(λe:T{ if x = e then 1 else 0 })
}

// Multiset sum: the count of each element in P ⨄ Q is the sum of its
// counts in P and in Q.
fun operator ⨄ <T>(P: MultiSet<T>, Q: MultiSet<T>) {
  m_fun(λe:T{ cnt(P)(e) + cnt(Q)(e) })
}

// cnt is a left inverse of the m_fun constructor.
theorem cnt_m_fun: all T:type, f: fn T -> Nat. cnt(m_fun(f)) = f
proof
  arbitrary T:type, g:(fn T -> Nat)
  // Unfolding cnt on an explicit m_fun(g) yields g itself.
  definition cnt
end
         
// The singleton multiset m_one(x) counts x exactly once.
theorem m_one_cnt: all T:type. all x:T. cnt(m_one(x))(x) = 1
proof
  arbitrary T:type
  arbitrary a:T
  // Unfold the singleton and the count; the goal then follows by evaluation.
  definition {m_one, cnt}
end

// The multiset built from the constant-zero function counts every
// element zero times.
theorem m_empty_zero: all T:type. all x:T.
  cnt(m_fun(λy:T{0}))(x) = 0
proof
  arbitrary T:type
  arbitrary a:T
  // Unfolding cnt exposes the constant-zero function applied to a.
  definition {cnt}
end

// The count of x in a multiset sum is the sum of its counts in the
// two operands.
theorem cnt_sum: all T:type. all A:MultiSet<T>, B:MultiSet<T>, x:T.
  cnt(A ⨄ B)(x) = cnt(A)(x) + cnt(B)(x)
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>, x:T
  // Unfold ⨄ to expose m_fun of the pointwise sum, then unfold cnt on it.
  definition {operator ⨄, cnt}
end

// Testing alternate ASCII notation: this theorem writes the multiset sum
// as `.+.` instead of ⨄.
// Statement: the multiset built from the constant-zero function is a
// right identity for multiset sum.
theorem m_sum_empty: all T:type. all A:MultiSet<T>.
  A .+. m_fun(λy:T{0}) = A
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>
  // Case-split on A, unfolding the sum operator and cnt in the goal.
  switch A for operator .+., cnt {
    case m_fun(f) {
      // The two counting functions agree pointwise, hence are equal.
      have eq: (λx{(f(x) + 0)} : fn T->Nat) = f
        by extensionality arbitrary x:T
           // add_zero is not defined in this file; presumably it comes
           // from the imported Nat library and states n + 0 = n.
           rewrite add_zero[f(x)]
      rewrite eq
    }
  }
end

// The multiset built from the constant-zero function is a left identity
// for multiset sum.
theorem empty_m_sum: all T:type. all A:MultiSet<T>.
  m_fun(λx:T{0}) ⨄ A = A
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>
  // Case-split on A, unfolding ⨄ in the goal.
  switch A for operator ⨄ {
    case m_fun(f) {
      // Pointwise, 0 + f(x) is f(x); this is discharged by unfolding +.
      have eq: (λx:T{0 + f(x)}) = f
        by extensionality arbitrary x:T
           definition {operator +}
      definition cnt and rewrite eq
    }
  }
end


// Multiset sum is commutative.
theorem m_sum_commutes: all T:type. all A:MultiSet<T>, B:MultiSet<T>.
  A ⨄ B = B ⨄ A
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>
  // Expose both counting functions by case analysis, unfolding ⨄.
  switch A for operator ⨄ {
    case m_fun(f) {
      switch B {
        case m_fun(g) {
          // After unfolding cnt the goal is an equation between two
          // explicit counting functions.
          suffices m_fun(λx:T{f(x) + g(x)})
                 = m_fun(λx:T{g(x) + f(x)})
              by definition cnt
          // The functions agree pointwise by commutativity of + on Nat.
          // add_commute is not defined in this file; presumably it comes
          // from the imported Nat library.
          have eq: (λx:T{(f(x) + g(x))}) = λx:T{(g(x) + f(x))}
            by extensionality arbitrary x:T
               rewrite add_commute[f(x)][g(x)]
          rewrite eq
        }
      }
    }
  }
end

// Multiset sum is associative.
theorem m_sum_assoc: all T:type. all A:MultiSet<T>, B:MultiSet<T>, C:MultiSet<T>.
  (A ⨄ B) ⨄ C = A ⨄ (B ⨄ C)
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>, C:MultiSet<T>
  // Expose all three counting functions by case analysis, unfolding ⨄.
  switch A for operator ⨄ {
    case m_fun(f) {
      switch B {
        case m_fun(g) {
          switch C {
            case m_fun(h) {
              // After unfolding cnt the goal is an equation between two
              // explicit counting functions that differ only in how the
              // additions are grouped.
              suffices m_fun(λx:T{(f(x) + g(x)) + h(x)})
                     = m_fun(λx:T{f(x) + (g(x) + h(x))})
                  by definition cnt
              // NOTE(review): the trivial proof `.` closes this goal,
              // presumably because the checker identifies the two
              // groupings of + on Nat. Confirm against the Nat library.
              .
            }
          }
        }
      }
    }
  }
end

// Convert a multiset to a set: the resulting characteristic function is
// false for elements whose count is 0 and true for all others.
define set_of_mset : fn<T> MultiSet<T> -> Set<T>
  = generic T {
      λms{
        char_fun(λe{ if cnt(ms)(e) = 0 then false else true })
      }
    }

// The set underlying a singleton multiset is the singleton set.
// `single` is not defined in this file; presumably it comes from the
// imported Set library.
theorem som_one_single: all T:type. all x:T.
  set_of_mset(m_one(x)) = single(x)
proof
  arbitrary T:type
  arbitrary z:T
  // Key lemma: the membership test obtained from m_one(z) is the same
  // function as "equals z". It is proved pointwise via extensionality.
  have eq: λx:T{(if cnt(m_fun(λy:T{(if z = y then 1 else 0)}))(x) = 0
               then false else true)}
           = λy{z = y}
    by extensionality
       arbitrary y:T
       // Case-split on whether y = z, unfolding cnt in the goal.
       switch y = z for cnt {
         case true suppose yz_true {
           // Turn the hypothesis about y = z into z = y so that it
           // matches the orientation used in the goal.
	   have z_y: z = y
	     by symmetric rewrite yz_true
	   rewrite z_y
	 }
	 case false suppose yz_false {
           // From z = y we would get y = z, contradicting yz_false.
	   have not_zy: not (z = y)
	     by suppose zy
	        rewrite yz_false in symmetric zy
           // eq_false is not defined in this file; presumably it turns
           // a proof of `not P` into the equation `P = false`.
	   rewrite (apply eq_false to not_zy)
	 }
       }
  // Unfold the definitions on both sides, then apply the key lemma.
  definition {m_one, single, set_of_mset}
  and rewrite eq
end

// The set underlying a multiset sum is the union of the sets underlying
// the two operands.
theorem som_union: all T:type. all A:MultiSet<T>, B:MultiSet<T>.
  set_of_mset(A ⨄ B) = set_of_mset(A) ∪ set_of_mset(B)
proof
  arbitrary T:type
  arbitrary A:MultiSet<T>, B:MultiSet<T>
  // Unfold set_of_mset, ∪ and ⨄ and simplify, reducing the goal to an
  // equation between two characteristic functions.
  // rep_char_fun is not defined in this file; presumably it comes from
  // the imported Set library.
  suffices char_fun(fun x:T{(if cnt(A)(x) + cnt(B)(x) = 0 then false else true)})
         = char_fun(fun x:T{((if cnt(A)(x) = 0 then false else true) or (if cnt(B)(x) = 0 then false else true))})
    by definition set_of_mset | operator∪ | operator ⨄ and replace rep_char_fun<T> | cnt_m_fun<T>
  // The two membership tests agree pointwise.
  have eq: λx:T{(if cnt(A)(x) + cnt(B)(x) = 0 then false else true)}
         = λx:T{((if cnt(A)(x) = 0 then false else true)
             or (if cnt(B)(x) = 0 then false else true))} by
  {
    extensionality
    arbitrary x:T
    // Case-split on the count of x in A.
    switch cnt(A)(x) {
      case 0 suppose Ax_z {
        // With count 0 in A, membership is decided by B alone.
        suffices (if 0 + cnt(B)(x) = 0 then false else true)
          = (if cnt(B)(x) = 0 then false else true) by rewrite Ax_z
        definition operator +
      }
      case suc(Ax) suppose Ax_s {
        // With a successor count in A, the sum is a successor too, so
        // it is never 0 and the membership test is true.
        suffices (if suc(Ax) + cnt(B)(x) = 0 then false else true) = true
          by rewrite Ax_s
        definition operator +
      }
    }
  }
  rewrite eq
end