module UInt

import Base
import Nat
import UIntDefs
import UIntToFrom
import UIntLess
import UIntAdd
import UIntMult
import UIntMonus
import UIntPowLog

/*
  Finite summation over UInt ranges.

  `uint_summation(k, s, f)` sums `f(s), ..., f(s + k - 1)`. It is
  defined by transporting through `toNat`/`fromNat`, so the theorems
  below mirror those of the underlying Nat `summation`.
*/

// Sum `f(begin), f(begin+1), ..., f(begin+k-1)` over UInt.
//
// Defined by transport to Nat. The length `k` and start `begin` are
// converted with `toNat`. The Nat `summation` runs over a summand that
// maps each Nat index back with `fromNat`, applies `f`, and converts the
// result with `toNat`. The Nat total is then converted back with `fromNat`.
// The theorems in this file rely on this exact shape (see
// `toNat_uint_summation`).
fun uint_summation(k:UInt, begin:UInt, f:fn UInt->UInt) {
  fromNat(summation(toNat(k), toNat(begin), λ i { toNat(f(fromNat(i))) }))
}

// `fromNat` commutes with adding a Nat onto a UInt: converting
// `toNat(x) + i` back to UInt gives the same value as `x + fromNat(i)`.
lemma fromNat_toNat_add:
  all x:UInt, i:Nat. fromNat(toNat(x) + i) = x + fromNat(i)
proof
  arbitrary x:UInt, i:Nat
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat(fromNat(toNat(x) + i)) = toNat(x + fromNat(i))
    by uint_toNat_injective
  // Both sides normalize to `toNat(x) + i`:
  // - cancel `toNat(fromNat(..))` on the left;
  // - split `toNat` over `+` on the right;
  // - cancel the resulting `toNat(fromNat(i))`.
  replace uint_toNat_fromNat | toNat_add | uint_toNat_fromNat.
end

// Specification: a UInt summation is the corresponding Nat summation.
// The other proofs in this file use it (via `replace`) to move a
// `uint_summation` goal into Nat.
theorem toNat_uint_summation:
  all k:UInt, begin:UInt, f:fn UInt->UInt.
  toNat(uint_summation(k, begin, f))
    = summation(toNat(k), toNat(begin), λ i { toNat(f(fromNat(i))) })
proof
  arbitrary k:UInt, begin:UInt, f:fn UInt->UInt
  // Unfolding the definition leaves `toNat(fromNat(S)) = S` for the Nat
  // summation `S`, which is an instance of `uint_toNat_fromNat`.
  expand uint_summation
  uint_toNat_fromNat
end

// The empty summation is zero: a summation of length `bzero` is `bzero`,
// for any start and summand.
lemma uint_summation_bzero:
  all begin:UInt, f:fn UInt->UInt.
  uint_summation(bzero, begin, f) = bzero
proof
  arbitrary begin:UInt, f:fn UInt->UInt
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat(uint_summation(bzero, begin, f)) = toNat(bzero)
    by uint_toNat_injective
  // Unfold `toNat` in the goal before moving the summation into Nat.
  expand toNat
  replace toNat_uint_summation
  // The rewrite introduces `toNat(bzero)` as the summation length. Unfold
  // it as well so the zero-length Nat summation can be evaluated.
  expand toNat
  evaluate
end

// UInt analogue of `summation_cong`: pointwise-equal summands over
// shifted ranges yield equal sums. For every Nat offset `i` below
// `toNat(k)`, the hypothesis equates the `i`-th term of the `f` range
// (starting at `s`) with the `i`-th term of the `g` range (starting at `t`).
theorem uint_summation_cong:
  all k:UInt. all f:fn UInt->UInt, g:fn UInt->UInt, s:UInt, t:UInt.
  if (all i:Nat. if i < toNat(k) then f(s + fromNat(i)) = g(t + fromNat(i)))
  then uint_summation(k, s, f) = uint_summation(k, t, g)
proof
  arbitrary k:UInt
  arbitrary f:fn UInt->UInt, g:fn UInt->UInt, s:UInt, t:UInt
  suppose f_g
  // Compare via `toNat`, then expose both sides as Nat summations.
  suffices toNat(uint_summation(k, s, f)) = toNat(uint_summation(k, t, g))
    by uint_toNat_injective
  replace toNat_uint_summation
  // Reduce to the Nat `summation_cong`. What remains is pointwise equality
  // of the transported summands at offsets below `toNat(k)`.
  apply summation_cong[toNat(k),
      λ i { toNat(f(fromNat(i))) },
      λ i { toNat(g(fromNat(i))) },
      toNat(s), toNat(t)]
  to arbitrary i:Nat
     suppose i_k
     have fg: f(s + fromNat(i)) = g(t + fromNat(i)) by apply f_g[i] to i_k
     // Fold `fromNat(toNat(s) + i)` and `fromNat(toNat(t) + i)` into
     // `s + fromNat(i)` and `t + fromNat(i)`, so that `fg` can rewrite.
     replace fromNat_toNat_add[s, i] | fromNat_toNat_add[t, i]
     replace fg.
end

// Append a final term: summing `1 + n` terms from `s` equals summing the
// first `n` terms plus `f(s + n)`, i.e.
// `Σ_{i=0}^{n} f(s+i) = (Σ_{i=0}^{n-1} f(s+i)) + f(s+n)`.
theorem uint_summation_next:
  all n:UInt, s:UInt, f:fn UInt->UInt.
  uint_summation(1 + n, s, f) = uint_summation(n, s, f) + f(s + n)
proof
  arbitrary n:UInt, s:UInt, f:fn UInt->UInt
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat(uint_summation(1 + n, s, f))
           = toNat(uint_summation(n, s, f) + f(s + n))
    by uint_toNat_injective
  // Move the left summation into Nat and split `toNat` over `+`.
  replace toNat_uint_summation | toNat_add
  // Move the remaining `toNat(uint_summation(n, s, f))` into Nat.
  replace toNat_uint_summation
  // Cancel `toNat(fromNat(..))` occurrences (presumably the one coming
  // from the UInt literal `1`), so the length has the Nat form
  // `summation_next` expects.
  replace uint_toNat_fromNat
  // Peel the last term off the Nat summation.
  replace summation_next[toNat(n), toNat(s), λ i { toNat(f(fromNat(i))) }]
  // The peeled term is `f(fromNat(toNat(s) + toNat(n)))`; fold it back
  // to `f(s + n)`.
  replace fromNat_toNat_add[s, toNat(n)] | uint_fromNat_toNat.
end

// Split a length `a + b` summation into prefix `g` (length `a`) and
// suffix `h` (length `b`), each agreeing pointwise with `f` on its
// segment. The suffix range for `h` may start at an arbitrary `t`; its
// `i`-th term must equal `f(s + a + i)`.
theorem uint_summation_add:
  all a:UInt. all b:UInt, s:UInt, t:UInt,
      f:fn UInt->UInt, g:fn UInt->UInt, h:fn UInt->UInt.
  if (all i:Nat. if i < toNat(a) then g(s + fromNat(i)) = f(s + fromNat(i)))
  and (all i:Nat. if i < toNat(b) then h(t + fromNat(i)) = f(s + a + fromNat(i)))
  then uint_summation(a + b, s, f) = uint_summation(a, s, g) + uint_summation(b, t, h)
proof
  arbitrary a:UInt
  arbitrary b:UInt, s:UInt, t:UInt,
      f:fn UInt->UInt, g:fn UInt->UInt, h:fn UInt->UInt
  suppose g_f_and_h_f
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat(uint_summation(a + b, s, f))
           = toNat(uint_summation(a, s, g) + uint_summation(b, t, h))
    by uint_toNat_injective
  // Move the summations into Nat. `toNat_add` also splits the length
  // `toNat(a + b)` and the sum on the right-hand side.
  replace toNat_uint_summation | toNat_add
  replace toNat_uint_summation
  // p1: the prefix hypothesis restated for the Nat-level summands.
  have p1: all i:Nat. (if i < toNat(a)
      then toNat(g(fromNat(toNat(s) + i)))
         = toNat(f(fromNat(toNat(s) + i)))) by {
    arbitrary i:Nat
    suppose i_a
    have gf: g(s + fromNat(i)) = f(s + fromNat(i))
      by apply (conjunct 0 of g_f_and_h_f)[i] to i_a
    // Fold `fromNat(toNat(s) + i)` into `s + fromNat(i)` so `gf` applies.
    replace fromNat_toNat_add[s, i]
    replace gf.
  }
  // p2: the suffix hypothesis restated for the Nat-level summands.
  have p2: all i:Nat. (if i < toNat(b)
      then toNat(h(fromNat(toNat(t) + i)))
         = toNat(f(fromNat(toNat(s) + toNat(a) + i)))) by {
    arbitrary i:Nat
    suppose i_b
    have hf: h(t + fromNat(i)) = f(s + a + fromNat(i))
      by apply (conjunct 1 of g_f_and_h_f)[i] to i_b
    // Fold the left side back to `h(t + fromNat(i))`.
    replace fromNat_toNat_add[t, i]
    // Restate the current goal explicitly. `by .` discharges the trivial
    // reduction.
    suffices toNat(h(t + fromNat(i))) = toNat(f(fromNat(toNat(s) + toNat(a) + i)))
      by .
    // Turn `toNat(s) + toNat(a)` into `toNat(s + a)` so that
    // `fromNat_toNat_add[s + a, i]` matches, then rewrite with `hf`.
    replace symmetric toNat_add[s, a]
    replace fromNat_toNat_add[s + a, i] | hf.
  }
  // Nat `summation_add`, discharged by the two pointwise facts, closes
  // the goal.
  replace (apply summation_add[toNat(a), toNat(b), toNat(s), toNat(t),
      λ i { toNat(f(fromNat(i))) },
      λ i { toNat(g(fromNat(i))) },
      λ i { toNat(h(fromNat(i))) }] to p1, p2).
end

// UInt analogue of `summation_const_one`: summing the constant `1`
// over a UInt range of length `n` equals `n`. Transported through
// `toNat`/`fromNat` to the Nat lemma.
theorem uint_summation_const_one: all n:UInt. all s:UInt.
  uint_summation(n, s, fun i:UInt { 1 }) = n
proof
  arbitrary n:UInt
  arbitrary s:UInt
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat(uint_summation(n, s, fun i:UInt { 1 })) = toNat(n)
    by uint_toNat_injective
  replace toNat_uint_summation
  // The transported summand is `toNat` of the UInt constant `1`. Show it
  // agrees with the Nat constant `suc(zero)` at every index, so that the
  // Nat lemma `summation_const_one` applies.
  have body_eq:
    summation(toNat(n), toNat(s),
              λ j { toNat((fun u:UInt { 1 })(fromNat(j))) })
    = summation(toNat(n), toNat(s), fun j:Nat { suc(zero) }) by {
    apply summation_cong[toNat(n),
        λ j { toNat((fun u:UInt { 1 })(fromNat(j))) },
        fun j:Nat { suc(zero) },
        toNat(s), toNat(s)]
    to arbitrary i:Nat
       assume _
       // Each summand closes by evaluation.
       evaluate
  }
  replace body_eq
  summation_const_one[toNat(n), toNat(s)]
end

// UInt analogue of `summation_pow2_succ` (lib/Nat.pf): the geometric
// sum closed form `1 + Σ_{i=0}^{n-1} 2^i = 2^n`. Transported through
// `toNat`/`fromNat` to the Nat lemma. The additive form keeps the
// closed form in UInt (no monus).
theorem uint_summation_pow2_succ: all n:UInt.
  (1:UInt) + uint_summation(n, 0, fun i:UInt { 2 ^ i }) = 2 ^ n
proof
  arbitrary n:UInt
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat((1:UInt) + uint_summation(n, 0, fun i:UInt { 2 ^ i }))
           = toNat(2 ^ n) by uint_toNat_injective
  // Split `toNat` over the outer `+` and move the summation into Nat.
  replace toNat_add | toNat_uint_summation
  // Rewrite the transported summand `toNat(2 ^ fromNat(j))` to the Nat
  // power `ℕ2 ^ j`, and the start `toNat(0:UInt)` to `ℕ0`.
  have body_eq:
    summation(toNat(n), toNat(0:UInt),
              λ j { toNat((fun u:UInt { 2 ^ u })(fromNat(j))) })
    = summation(toNat(n), ℕ0, λ j:Nat { ℕ2 ^ j }) by {
    apply summation_cong[toNat(n),
        λ j { toNat((fun u:UInt { 2 ^ u })(fromNat(j))) },
        λ j:Nat { ℕ2 ^ j },
        toNat(0:UInt), ℕ0]
    to arbitrary i:Nat
       assume _
       // `toNat(0:UInt) = ℕ0`, derived from `uint_toNat_fromNat[ℕ0]` by
       // rewriting with `from_zero`.
       have tonat_zero: toNat(0:UInt) = ℕ0
         by replace symmetric from_zero in uint_toNat_fromNat[ℕ0]
       replace tonat_zero
       show toNat((2:UInt) ^ fromNat(ℕ0 + i)) = ℕ2 ^ (ℕ0 + i)
       // Push `toNat` through `^`, then cancel `toNat(fromNat(..))`.
       replace toNat_expt | uint_toNat_fromNat.
  }
  replace body_eq
  // The UInt literals `1` and `2` convert to `ℕ1` and `ℕ2`.
  have tonat_one: toNat(1:UInt) = ℕ1 by {
    have h: toNat(fromNat(ℕ1)) = ℕ1 by uint_toNat_fromNat
    h
  }
  have tonat_two: toNat(2:UInt) = ℕ2 by {
    have h: toNat(fromNat(ℕ2)) = ℕ2 by uint_toNat_fromNat
    h
  }
  // With `toNat(2 ^ n)` pushed through `^`, the goal is the Nat lemma.
  replace toNat_expt | tonat_one | tonat_two
  summation_pow2_succ[toNat(n)]
end

// UInt analogue of `summation_pow_succ` (lib/Nat.pf): the generic
// geometric sum closed form `1 + (a ∸ 1) · Σ_{i=0}^{n-1} a^i = a^n` for
// `1 ≤ a`. Transported through `toNat`/`fromNat` to the Nat lemma.
// The additive form keeps the closed form in UInt (no rationals).
theorem uint_summation_pow_succ: all a:UInt. if 1 ≤ a then
  all n:UInt.
  (1:UInt) + (a ∸ 1) * uint_summation(n, 0, fun i:UInt { a ^ i }) = a ^ n
proof
  arbitrary a:UInt
  assume one_le_a: 1 ≤ a
  arbitrary n:UInt
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat((1:UInt) + (a ∸ 1) * uint_summation(n, 0, fun i:UInt { a ^ i }))
           = toNat(a ^ n) by uint_toNat_injective
  // Push `toNat` through `+`, `*` and `∸`, and move the summation into Nat.
  replace toNat_add | toNat_mult | toNat_uint_summation | toNat_monus
  // Rewrite the transported summand `toNat(a ^ fromNat(j))` to the Nat
  // power `toNat(a) ^ j`, and the start `toNat(0:UInt)` to `ℕ0`.
  have body_eq:
    summation(toNat(n), toNat(0:UInt),
              λ j { toNat((fun u:UInt { a ^ u })(fromNat(j))) })
    = summation(toNat(n), ℕ0, λ j:Nat { toNat(a) ^ j }) by {
    apply summation_cong[toNat(n),
        λ j { toNat((fun u:UInt { a ^ u })(fromNat(j))) },
        λ j:Nat { toNat(a) ^ j },
        toNat(0:UInt), ℕ0]
    to arbitrary i:Nat
       assume _
       // `toNat(0:UInt) = ℕ0`, derived from `uint_toNat_fromNat[ℕ0]` by
       // rewriting with `from_zero`.
       have tonat_zero: toNat(0:UInt) = ℕ0
         by replace symmetric from_zero in uint_toNat_fromNat[ℕ0]
       replace tonat_zero
       show toNat(a ^ fromNat(ℕ0 + i)) = toNat(a) ^ (ℕ0 + i)
       // Push `toNat` through `^`, then cancel `toNat(fromNat(..))`.
       replace toNat_expt | uint_toNat_fromNat.
  }
  replace body_eq
  // The UInt literal `1` converts to `ℕ1`.
  have tonat_one: toNat(1:UInt) = ℕ1 by {
    have h: toNat(fromNat(ℕ1)) = ℕ1 by uint_toNat_fromNat
    h
  }
  replace toNat_expt | tonat_one
  // Transport the side condition `1 ≤ a` to Nat for the Nat lemma.
  have one_le_toNat_a: ℕ1 ≤ toNat(a) by {
    have h: toNat(1:UInt) ≤ toNat(a) by apply toNat_less_equal to one_le_a
    replace tonat_one in h
  }
  apply summation_pow_succ[toNat(a)] to one_le_toNat_a
end

// UInt analogue of `sum_n` (lib/Nat.pf): twice the sum 0+1+...+(n-1)
// equals n*(n ∸ 1). Transported through `toNat`/`fromNat` to the Nat
// lemma. The doubled form keeps the closed form in UInt (no rationals).
theorem uint_summation_id: all n:UInt.
  2 * uint_summation(n, 0, fun i:UInt { i }) = n * (n ∸ 1)
proof
  arbitrary n:UInt
  // Reduce UInt equality to equality of the `toNat` images.
  suffices toNat(2 * uint_summation(n, 0, fun i:UInt { i }))
           = toNat(n * (n ∸ 1)) by uint_toNat_injective
  // Push `toNat` through `*` and `∸`, and move the summation into Nat.
  replace toNat_mult | toNat_uint_summation | toNat_monus
  // Rewrite the transported summand `toNat(fromNat(j))` to `j`, and the
  // start `toNat(0:UInt)` to `ℕ0`.
  have body_eq:
    summation(toNat(n), toNat(0:UInt),
              λ j { toNat((fun u:UInt { u })(fromNat(j))) })
    = summation(toNat(n), ℕ0, λ x:Nat { x }) by {
    apply summation_cong[toNat(n),
        λ j { toNat((fun u:UInt { u })(fromNat(j))) },
        λ x:Nat { x },
        toNat(0:UInt), ℕ0]
    to arbitrary i:Nat
       assume _
       // `toNat(0:UInt) = ℕ0`, derived from `uint_toNat_fromNat[ℕ0]` by
       // rewriting with `from_zero`.
       have tonat_zero: toNat(0:UInt) = ℕ0
         by replace symmetric from_zero in uint_toNat_fromNat[ℕ0]
       replace tonat_zero | uint_toNat_fromNat.
  }
  replace body_eq
  // The UInt literals `1` and `2` convert to `ℕ1` and `ℕ2`.
  have tonat_one: toNat(1:UInt) = ℕ1 by {
    have h: toNat(fromNat(ℕ1)) = ℕ1 by uint_toNat_fromNat
    h
  }
  have tonat_two: toNat(2:UInt) = ℕ2 by {
    have h: toNat(fromNat(ℕ2)) = ℕ2 by uint_toNat_fromNat
    h
  }
  replace tonat_two | tonat_one
  sum_n[toNat(n)]
end