import Set
import Option

/* Function composition: (g ∘ f) applies f first and then g to the result. */
fun operator ∘ <T,U,V>(g:fn U->V, f:fn T->U) {
  λa:T { g(f(a)) }
}

/* Point update of a function: the result returns v at x and agrees with f
   on every other argument. */
fun update<T, U>(f:fn T->U, x:T, v:U) {
  λk:T { if k = x then v else f(k) }
}

/* Looking up the key that was just updated yields the new value. */
theorem update_eq :
  < T, U > all f: fn (T) -> U, x:T, v:U.
  update(f, x, v)(x) = v
proof
  arbitrary T:type, U:type
  arbitrary f: fn (T) -> U, x:T, v:U
  /* Unfolding update leaves a conditional on x = x, which selects v. */
  conclude update(f, x, v)(x) = v by definition update
end

/* Looking up a key different from the updated one falls through to the
   original function. */
theorem update_not_eq :
  < T, U > all f: fn (T) -> U, x:T, v:U, y:T.
  if not (x = y)
  then update(f, x, v)(y) = f(y)
proof
  arbitrary T:type, U:type
  arbitrary f: fn (T) -> U, x:T, v:U, y:T
  suppose x_neq_y
  /* Case split on the test that update performs once it is unfolded. */
  switch y = x for update {
    case true suppose y_eq_x_true {
      /* This case contradicts the premise not (x = y). */
      have y_eq_x: y = x by rewrite y_eq_x_true
      have x_eq_y: x = y by symmetric y_eq_x
      conclude false by apply x_neq_y to x_eq_y
    }
    case false {
      .
    }
  }
end

/* A second update at the same key overrides the first one. */
theorem update_shadow : all T:type, U:type. all f:fn(T)->U, x:T, v:U, w:U.
  update(update(f, x, v), x, w) = update(f, x, w)
proof
  arbitrary T:type, U:type
  arbitrary f:fn(T)->U, x:T, v:U, w:U
  /* Compare the two functions pointwise at an arbitrary key k. */
  extensionality
  arbitrary k:T
  switch k = x for update {
    case true {
      .
    }
    case false {
      .
    }
  }
end

/* Updates at two distinct keys can be performed in either order. */
theorem update_permute :
  < T, U > all f:fn(T)->U, x:T, v:U, w:U, y:T.
  if not (x = y)
  then update(update(f, x, v), y, w) = update(update(f, y, w), x, v)
proof
  arbitrary T:type, U:type
  arbitrary f:fn(T)->U, x:T, v:U, w:U, y:T
  suppose x_neq_y
  /* Compare the two functions pointwise at an arbitrary key z. */
  extensionality
  arbitrary z:T
  switch z = y for update {
    case true suppose z_eq_y_true {
      have z_eq_y: z = y by rewrite z_eq_y_true
      switch z = x {
        case true suppose z_eq_x_true {
          /* z equal to both x and y would force x = y, contradicting the premise. */
          have z_eq_x: z = x by rewrite z_eq_x_true
          have x_eq_y: x = y by transitive (symmetric z_eq_x) z_eq_y
          conclude false by apply x_neq_y to x_eq_y
        }
        case false {
          .
        }
      }
    }
    case false {
      .
    }
  }
end

/* Partial Maps */

/* The partial map that sends every key to none. */
define empty_map = generic T, U { fun key:T { @none<U> } }

/* The set of keys at which the partial map f returns just(_) rather than none.
   The set is built with char_fun from the predicate below. */
define domain : fn<T,U>  (fn(T)->Option<U>) -> Set<T>
  = generic T, U {
      fun f: fn(T)->Option<U> {
        char_fun(fun x:T {
          switch f(x) {
            case none { false }
            case just(y) { true }
          }
        })
      }
    }

/* Restrict the partial map f to the set P: a key that is a member of P keeps
   its value f(x); every other key is mapped to none. */
define restrict = generic T, U { λf:fn(T)->Option<U>, P:Set<T> { λx:T {
                      if x ∈ P then f(x) else @none<U> }}}

/* Left-biased combination of two partial maps: where f returns none the
   result is g(x), otherwise it is f(x). */
define combine = generic T, U {
  λf:fn(T)->Option<U>, g:fn(T)->Option<U> {
    λx:T { if f(x) = none then g(x) else f(x) }
  }
}

/* If g is undefined (none) at x, then combine(f, g) agrees with f at x. */
theorem combine_left: all T:type, U:type, f:fn(T)->Option<U>, g:fn(T)->Option<U>, x:T.
  if g(x) = none
  then combine(f, g)(x) = f(x)
proof
  arbitrary T:type, U:type
  arbitrary f:fn(T)->Option<U>, g:fn(T)->Option<U>, x:T
  assume: g(x) = none
  /* Case split on f(x), the value that combine tests after unfolding. */
  switch f(x) {
    case none assume: f(x) = none {
      /* combine takes the g(x) branch; the premise rewrites g(x) to none,
         and f(x) has already been rewritten to none. */
      conclude combine(f, g)(x) = f(x)
          by definition combine and rewrite (recall f(x) = none) | (recall g(x) = none)
    }
    case just(v) assume: f(x) = just(v) {
      /* f(x) is just(v), so combine takes the f(x) branch. */
      conclude combine(f, g)(x) = f(x)
          by definition combine and rewrite (recall f(x) = just(v))
    }
  }
end

/* If f is undefined (none) at x, then combine(f, g) agrees with g at x. */
theorem combine_right: all T:type, U:type, f:fn(T)->Option<U>, g:fn(T)->Option<U>, x:T.
  if f(x) = none
  then combine(f, g)(x) = g(x)
proof
  arbitrary T:type, U:type
  arbitrary f:fn(T)->Option<U>, g:fn(T)->Option<U>, x:T
  assume: f(x) = none
  /* After unfolding combine, rewriting f(x) to none selects the g(x) branch. */
  conclude combine(f, g)(x) = g(x)
      by definition combine and rewrite (recall f(x) = none)
end

/* The domain of a partial map restricted to P is a subset of P. */
theorem restrict_domain:
  < T, U > all f:fn(T)->Option<U>, P:Set<T>.
  domain(restrict(f, P)) ⊆ P
proof
  arbitrary T:type, U:type
  arbitrary f:fn(T)->Option<U>, P:Set<T>
  /* Unfold restrict, domain and ⊆ so the goal is stated directly in terms of
     membership in the char_fun set. */
  suffices (all x:T. (if x ∈ char_fun(λx'{switch (if x' ∈ P then f(x')
                                                 else @none<U>)
            { case none{false} case just(y){true} }}) then x ∈ P))
      by definition {restrict, domain, operator ⊆}
  arbitrary x:T
  suppose prem
  switch x ∈ P {
    case true { . }
    case false suppose Px_false {
      /* x is not in P, so the restricted map takes the none branch at x;
         the membership premise then reduces to false. */
      have eq: rep(P)(x) = false by definition operator∈ in Px_false
      conclude false by {
        rewrite rep_char_fun<T> | eq in
        definition operator∈ in prem
      }
    }
  }
end

/* Swap the order of the two arguments of a binary function. */
define flip : fn<T,U,V> (fn T,U->V) ->(fn U,T->V)
  = generic T,U,V {
      λh{
        λa,b{ h(b,a) }
      }
    }

/* Flipping the arguments twice gives back the original function. */
theorem flip_flip:
  all T:type. all f:fn T,T->T. flip(flip(f)) = f
proof
  arbitrary T:type
  arbitrary f:fn T,T->T
  /* Compare both sides on an arbitrary pair of arguments. */
  extensionality
  arbitrary a:T, b:T
  definition flip
end