module UInt

import Base
import Nat
import UIntDefs
import UIntToFrom
import UIntLess
import UIntAdd

/*
  Finite summation over UInt ranges.

  `uint_summation(k, s, f)` sums `f(s), ..., f(s + k - 1)`. It is
  defined by transporting through `toNat`/`fromNat`, so the theorems
  below mirror those of the underlying Nat `summation`.
*/

// Sum `f(s), f(s+1), ..., f(s+k-1)` over UInt, where `s` is `begin`.
//
// The sum is computed in Nat: `k` and `begin` are converted with `toNat`,
// each summand is transported as `i ↦ toNat(f(fromNat(i)))`, the Nat
// `summation` is taken, and the result is converted back with `fromNat`.
// NOTE: `toNat_uint_summation` proves its statement by expanding this body
// and matching it term-for-term, so keep the two in sync if this changes.
fun uint_summation(k:UInt, begin:UInt, f:fn UInt->UInt) {
  fromNat(summation(toNat(k), toNat(begin), λ i { toNat(f(fromNat(i))) }))
}

// `fromNat` commutes with adding a Nat onto a UInt: converting
// `toNat(x) + i` back to UInt gives the same value as `x + fromNat(i)`.
lemma fromNat_toNat_add:
  all x:UInt, i:Nat. fromNat(toNat(x) + i) = x + fromNat(i)
proof
  arbitrary x:UInt, i:Nat
  // Compare both sides after mapping them to Nat; `uint_toNat_injective`
  // lifts an equality of the `toNat` images back to a UInt equality.
  suffices toNat(fromNat(toNat(x) + i)) = toNat(x + fromNat(i))
    by uint_toNat_injective
  // Presumably `uint_toNat_fromNat` is `toNat(fromNat(n)) = n` and
  // `toNat_add` is `toNat(x + y) = toNat(x) + toNat(y)`; under those
  // readings both sides reduce to `toNat(x) + i`.
  replace uint_toNat_fromNat | toNat_add | uint_toNat_fromNat.
end

// Specification: a UInt summation is the corresponding Nat summation.
// The theorems below rewrite with this lemma to turn UInt summation goals
// into goals about the Nat `summation`.
theorem toNat_uint_summation:
  all k:UInt, begin:UInt, f:fn UInt->UInt.
  toNat(uint_summation(k, begin, f))
    = summation(toNat(k), toNat(begin), λ i { toNat(f(fromNat(i))) })
proof
  arbitrary k:UInt, begin:UInt, f:fn UInt->UInt
  // Unfolding the definition leaves `toNat(fromNat(S)) = S`, where `S` is
  // the Nat summation on the right-hand side. `uint_toNat_fromNat` then
  // closes the goal directly.
  expand uint_summation
  uint_toNat_fromNat
end

// The empty summation is zero: summing zero terms from any start, with any
// summand, yields `bzero`.
lemma uint_summation_bzero:
  all begin:UInt, f:fn UInt->UInt.
  uint_summation(bzero, begin, f) = bzero
proof
  arbitrary begin:UInt, f:fn UInt->UInt
  // Reduce to an equality of the `toNat` images.
  suffices toNat(uint_summation(bzero, begin, f)) = toNat(bzero)
    by uint_toNat_injective
  // NOTE(review): the first `expand toNat` appears to unfold `toNat(bzero)`
  // on the right-hand side. The rewrite with the spec then introduces a
  // second `toNat(bzero)` as the summation length, which the next
  // `expand toNat` unfolds. `evaluate` then presumably computes the
  // zero-length Nat summation. Confirm which occurrence each `expand` targets.
  expand toNat
  replace toNat_uint_summation
  expand toNat
  evaluate
end

// UInt analogue of `summation_cong`: pointwise-equal summands over
// shifted ranges yield equal sums. If `f(s + i) = g(t + i)` for every
// offset `i < toNat(k)`, then summing `k` terms of `f` from `s` equals
// summing `k` terms of `g` from `t`.
theorem uint_summation_cong:
  all k:UInt. all f:fn UInt->UInt, g:fn UInt->UInt, s:UInt, t:UInt.
  if (all i:Nat. if i < toNat(k) then f(s + fromNat(i)) = g(t + fromNat(i)))
  then uint_summation(k, s, f) = uint_summation(k, t, g)
proof
  arbitrary k:UInt
  arbitrary f:fn UInt->UInt, g:fn UInt->UInt, s:UInt, t:UInt
  suppose f_g
  // Work with the `toNat` images and rewrite both sides into Nat summations.
  suffices toNat(uint_summation(k, s, f)) = toNat(uint_summation(k, t, g))
    by uint_toNat_injective
  replace toNat_uint_summation
  // Delegate to the Nat `summation_cong`, using the transported summands
  // and the start points `toNat(s)` and `toNat(t)`.
  apply summation_cong[toNat(k),
      λ i { toNat(f(fromNat(i))) },
      λ i { toNat(g(fromNat(i))) },
      toNat(s), toNat(t)]
  // Remaining obligation: the transported summands agree at each offset
  // `i < toNat(k)`.
  to arbitrary i:Nat
     suppose i_k
     have fg: f(s + fromNat(i)) = g(t + fromNat(i)) by apply f_g[i] to i_k
     // Turn `fromNat(toNat(s) + i)` into `s + fromNat(i)` (and likewise for
     // `t`) so that the hypothesis `fg` applies.
     replace fromNat_toNat_add[s, i] | fromNat_toNat_add[t, i]
     replace fg.
end

// Append a final term: `Σ_{i < 1+n} f(s+i) = Σ_{i < n} f(s+i) + f(s+n)`.
theorem uint_summation_next:
  all n:UInt, s:UInt, f:fn UInt->UInt.
  uint_summation(1 + n, s, f) = uint_summation(n, s, f) + f(s + n)
proof
  arbitrary n:UInt, s:UInt, f:fn UInt->UInt
  // Reduce to an equality of the `toNat` images.
  suffices toNat(uint_summation(1 + n, s, f))
           = toNat(uint_summation(n, s, f) + f(s + n))
    by uint_toNat_injective
  // Push `toNat` through the summations (via the spec) and through `+`
  // (via `toNat_add`), so both sides are stated in terms of Nat `summation`.
  replace toNat_uint_summation | toNat_add
  replace toNat_uint_summation
  // NOTE(review): presumably this normalises any `toNat(fromNat(_))` left
  // by the rewrites above. Confirm which subterm it rewrites.
  replace uint_toNat_fromNat
  // Peel off the last term with the Nat `summation_next`.
  replace summation_next[toNat(n), toNat(s), λ i { toNat(f(fromNat(i))) }]
  // The peeled term is `toNat(f(fromNat(toNat(s) + toNat(n))))`. Rewrite its
  // argument to `s + fromNat(toNat(n))` and then to `s + n`, so it matches
  // `toNat(f(s + n))`.
  replace fromNat_toNat_add[s, toNat(n)] | uint_fromNat_toNat.
end

// Split a length `a + b` summation into prefix `g` (length `a`) and
// suffix `h` (length `b`), each agreeing pointwise with `f` on its
// segment. The prefix is summed from `s`. The suffix is summed from an
// arbitrary `t`, with `h(t + i)` matched against `f(s + a + i)`.
theorem uint_summation_add:
  all a:UInt. all b:UInt, s:UInt, t:UInt,
      f:fn UInt->UInt, g:fn UInt->UInt, h:fn UInt->UInt.
  if (all i:Nat. if i < toNat(a) then g(s + fromNat(i)) = f(s + fromNat(i)))
  and (all i:Nat. if i < toNat(b) then h(t + fromNat(i)) = f(s + a + fromNat(i)))
  then uint_summation(a + b, s, f) = uint_summation(a, s, g) + uint_summation(b, t, h)
proof
  arbitrary a:UInt
  arbitrary b:UInt, s:UInt, t:UInt,
      f:fn UInt->UInt, g:fn UInt->UInt, h:fn UInt->UInt
  suppose g_f_and_h_f
  // Reduce to an equality of the `toNat` images, then rewrite every
  // summation and `+` into its Nat counterpart.
  suffices toNat(uint_summation(a + b, s, f))
           = toNat(uint_summation(a, s, g) + uint_summation(b, t, h))
    by uint_toNat_injective
  replace toNat_uint_summation | toNat_add
  replace toNat_uint_summation
  // p1: the prefix hypothesis restated for the transported Nat summands,
  // starting at `toNat(s)`.
  have p1: all i:Nat. (if i < toNat(a)
      then toNat(g(fromNat(toNat(s) + i)))
         = toNat(f(fromNat(toNat(s) + i)))) by {
    arbitrary i:Nat
    suppose i_a
    have gf: g(s + fromNat(i)) = f(s + fromNat(i))
      by apply (conjunct 0 of g_f_and_h_f)[i] to i_a
    // `fromNat(toNat(s) + i)` becomes `s + fromNat(i)`, where `gf` applies.
    replace fromNat_toNat_add[s, i]
    replace gf.
  }
  // p2: the suffix hypothesis restated for the transported Nat summands.
  // The suffix of `f` starts at Nat offset `toNat(s) + toNat(a)`.
  have p2: all i:Nat. (if i < toNat(b)
      then toNat(h(fromNat(toNat(t) + i)))
         = toNat(f(fromNat(toNat(s) + toNat(a) + i)))) by {
    arbitrary i:Nat
    suppose i_b
    have hf: h(t + fromNat(i)) = f(s + a + fromNat(i))
      by apply (conjunct 1 of g_f_and_h_f)[i] to i_b
    replace fromNat_toNat_add[t, i]
    // NOTE(review): restates the current goal after the rewrite above.
    // `by .` suggests it is discharged trivially and serves as an explicit
    // checkpoint. Confirm it is still needed.
    suffices toNat(h(t + fromNat(i))) = toNat(f(fromNat(toNat(s) + toNat(a) + i)))
      by .
    // Fold `toNat(s) + toNat(a)` back into `toNat(s + a)`. Then
    // `fromNat_toNat_add` yields `s + a + fromNat(i)`, where `hf` applies.
    replace symmetric toNat_add[s, a]
    replace fromNat_toNat_add[s + a, i] | hf.
  }
  // Conclude with the Nat `summation_add`, instantiated with the
  // transported summands and discharged by p1 and p2.
  replace (apply summation_add[toNat(a), toNat(b), toNat(s), toNat(t),
      λ i { toNat(f(fromNat(i))) },
      λ i { toNat(g(fromNat(i))) },
      λ i { toNat(h(fromNat(i))) }] to p1, p2).
end