module Int

import Base
import Nat
import UInt
import IntDefs
import IntAddSub

/*
  Finite summation of Int-valued terms over UInt ranges.

  `int_summation(k, s, f)` sums `f(s), ..., f(s + k - 1)` for an
  `Int`-valued summand. Unlike `uint_summation`, the result cannot be
  transported through Nat summation because Int values can be
  negative; `int_summation_nat` is therefore defined by direct Nat
  recursion on the length, with a Nat-typed start index that is
  embedded back into UInt for each summand call.

  The lemmas mirror `lib/UIntSum.pf` (congruence, next-step,
  additivity). The genuinely new content over the unsigned story is
  the negation and subtraction interaction at the bottom of the file.
*/

// Internal Nat-counted recursion: `f(fromNat(begin)) + ... + f(fromNat(begin + k - 1))`.
// Both the count and the start index are Nat-typed so that proofs by
// `induction Nat` work cleanly with the auto identities `n + 0 = n` and
// `n + suc(m) = suc(n + m)` from Nat addition.
//
// Recursion is on the first argument (the number of remaining terms):
//   - zero terms sum to `+0`;
//   - `suc(k)` terms peel off the front summand `f(fromNat(begin))` and
//     recurse on the remaining `k` terms starting at `suc(begin)`.
// The start index is converted back to UInt with `fromNat` only at the
// point where the summand `f` is applied.
recursive int_summation_nat(Nat, Nat, fn UInt -> Int) -> Int {
  int_summation_nat(zero, begin, f) = +0
  int_summation_nat(suc(k), begin, f) = f(fromNat(begin)) + int_summation_nat(k, suc(begin), f)
}

// Sum `f(s), f(s+1), ..., f(s+k-1)` of Int-valued terms over a UInt-indexed range.
// Public entry point. The UInt length `k` and start `begin` are converted
// to Nat with `toNat`, and the work is delegated to `int_summation_nat`.
// The theorems below begin with `expand int_summation` to expose that
// Nat-level call.
fun int_summation(k:UInt, begin:UInt, f:fn UInt -> Int) {
  int_summation_nat(toNat(k), toNat(begin), f)
}

// `(a + b) + (c + d) = (a + c) + (b + d)`: inner swap used to combine two parallel
// summations of pointwise sums into a sum of two separate summations.
// Used by `int_summation_nat_add_pointwise` in its inductive step.
lemma int_add_swap_inner: all a:Int, b:Int, c:Int, d:Int.
  (a + b) + (c + d) = (a + c) + (b + d)
proof
  arbitrary a:Int, b:Int, c:Int, d:Int
  // The only explicit step is commuting `b` and `c`. The regrouping around
  // it is closed by the prover after the rewrite (presumably via automatic
  // associativity rules from IntAddSub — confirm if that module changes).
  replace int_add_commute[b, c].
end

// `fromNat` commutes with adding a Nat onto a UInt (same helper as in UIntSum).
// It bridges the summand arguments `fromNat(toNat(s) + i)` produced by
// `int_summation_nat` and the `s + fromNat(i)` form used in the UInt-level
// theorem statements below.
lemma fromNat_toNat_add_uint:
  all x:UInt, i:Nat. fromNat(toNat(x) + i) = x + fromNat(i)
proof
  arbitrary x:UInt, i:Nat
  // Both sides are UInts, so it suffices to compare their Nat images.
  suffices toNat(fromNat(toNat(x) + i)) = toNat(x + fromNat(i))
    by uint_toNat_injective
  // Cancel toNat∘fromNat on the left, distribute toNat over `+` on the
  // right, then cancel toNat∘fromNat on `i`.
  replace uint_toNat_fromNat | toNat_add | uint_toNat_fromNat.
end

// The empty sum is zero: a length-0 summation is `+0` for any start and summand.
theorem int_summation_zero:
  all begin:UInt, f:fn UInt -> Int.
  int_summation(0, begin, f) = +0
proof
  arbitrary begin:UInt, f:fn UInt -> Int
  expand int_summation
  show int_summation_nat(toNat((0:UInt)), toNat(begin), f) = +0
  // toNat(0) reduces to ℕ0, which is `lit(zero)`; then int_summation_nat(zero,…) = +0.
  have z: toNat((0:UInt)) = ℕ0 by {
    have h: toNat(fromNat(ℕ0)) = ℕ0 by uint_toNat_fromNat
    // from_zero presumably identifies fromNat(ℕ0) with the UInt literal 0,
    // turning h into the desired equation — confirm in UInt.
    replace from_zero in h
  }
  replace z
  // Unfold the numeral so the `zero` equation of int_summation_nat matches.
  expand lit
  expand int_summation_nat
  evaluate
end

// Internal snoc lemma: append a final term to a Nat-counted summation.
// Stated and proved entirely in Nat-land for `begin`, sidestepping the
// `fromNat(zero)` vs `fromNat(ℕ0)` mismatch that arises when `begin` is UInt.
// The definition peels terms off the front; this lemma instead exposes the
// last term `f(fromNat(s + k))`. It is proved by induction on `k`, for all
// start indices `s`.
lemma int_summation_nat_snoc:
  all k:Nat. all s:Nat, f:fn UInt -> Int.
  int_summation_nat(suc(k), s, f) = int_summation_nat(k, s, f) + f(fromNat(s + k))
proof
  induction Nat
  case zero {
    arbitrary s:Nat, f:fn UInt -> Int
    // Unfold int_summation_nat twice; the `.` closes the resulting goal.
    expand 2 * int_summation_nat.
  }
  case suc(k') assume IH {
    arbitrary s:Nat, f:fn UInt -> Int
    // IH instantiated at start index suc(s): snoc for the tail of the sum.
    have step: int_summation_nat(suc(k'), suc(s), f)
             = int_summation_nat(k', suc(s), f) + f(fromNat(suc(s) + k'))
      by IH[suc(s), f]
    // Reindex the tail's last term so it matches the goal's `s + suc(k')`.
    have shift: suc(s) + k' = s + suc(k') by {
      replace add_suc[s, k']
      expand operator+.
    }
    // Peel off the first term, apply the IH to the tail, reindex the last
    // term, regroup, and fold the front back into int_summation_nat.
    equations
          int_summation_nat(suc(suc(k')), s, f)
        = f(fromNat(s)) + int_summation_nat(suc(k'), suc(s), f)
            by expand int_summation_nat.
    ... = f(fromNat(s)) + (int_summation_nat(k', suc(s), f) + f(fromNat(suc(s) + k')))
            by replace step.
    ... = f(fromNat(s)) + (int_summation_nat(k', suc(s), f) + f(fromNat(s + suc(k'))))
            by replace shift.
    ... = (f(fromNat(s)) + int_summation_nat(k', suc(s), f)) + f(fromNat(s + suc(k')))
            by symmetric int_add_assoc[f(fromNat(s)), int_summation_nat(k', suc(s), f), f(fromNat(s + suc(k')))]
    ... = # int_summation_nat(suc(k'), s, f) # + f(fromNat(s + suc(k')))
            by expand int_summation_nat.
  }
end

// Append a final term: `Σ_{i<1+n} f(s+i) = Σ_{i<n} f(s+i) + f(s+n)`.
// The length is written `1 + n`. The proof rewrites `toNat(1 + n)` to
// `suc(toNat(n))`, applies `int_summation_nat_snoc`, and converts the last
// summand's argument back to the UInt form `s + n`.
theorem int_summation_next:
  all n:UInt, s:UInt, f:fn UInt -> Int.
  int_summation(1 + n, s, f) = int_summation(n, s, f) + f(s + n)
proof
  arbitrary n:UInt, s:UInt, f:fn UInt -> Int
  expand int_summation
  // toNat(1 + n) = ℕ1 + toNat(n) = suc(toNat(n)); then apply the snoc lemma.
  have to_one_plus: toNat(1 + n) = suc(toNat(n)) by {
    replace toNat_add
    // Goal: toNat(1) + toNat(n) = suc(toNat(n)). toNat(1) = ℕ1; ℕ1 + x = suc(x).
    have one: toNat((1:UInt)) = ℕ1 by {
      have h: toNat(fromNat(ℕ1)) = ℕ1 by uint_toNat_fromNat
      replace from_one in h
    }
    replace one
    expand lit
    symmetric nat_suc_one_add[toNat(n)]
  }
  replace to_one_plus
  have snoc: int_summation_nat(suc(toNat(n)), toNat(s), f)
           = int_summation_nat(toNat(n), toNat(s), f) + f(fromNat(toNat(s) + toNat(n)))
    by int_summation_nat_snoc[toNat(n), toNat(s), f]
  replace snoc
  // fromNat(toNat(s) + toNat(n)) = s + fromNat(toNat(n)) = s + n.
  replace fromNat_toNat_add_uint[s, toNat(n)] | uint_fromNat_toNat.
end

// Pointwise-equal summands over shifted ranges yield equal sums (Nat level).
// If `f` at `s + i` equals `g` at `t + i` for every offset `i < k`, then the
// length-`k` sums of `f` from `s` and of `g` from `t` are equal. The two ranges
// may start at different indices. Proved by induction on `k`, generalizing
// over `f`, `g`, `s`, and `t`.
lemma int_summation_nat_cong:
  all k:Nat. all f:fn UInt -> Int, g:fn UInt -> Int, s:Nat, t:Nat.
  if (all i:Nat. if i < k then f(fromNat(s + i)) = g(fromNat(t + i)))
  then int_summation_nat(k, s, f) = int_summation_nat(k, t, g)
proof
  induction Nat
  case zero {
    // Both sides are the empty sum.
    arbitrary f:fn UInt -> Int, g:fn UInt -> Int, s:Nat, t:Nat
    assume _
    expand int_summation_nat.
  }
  case suc(k') assume IH {
    arbitrary f:fn UInt -> Int, g:fn UInt -> Int, s:Nat, t:Nat
    assume hyp: all i:Nat. if i < suc(k') then f(fromNat(s + i)) = g(fromNat(t + i))
    // fs_gt: the first summands agree, i.e. f(s) = g(t) (hyp at offset 0).
    have fs_gt: f(fromNat(s)) = g(fromNat(t)) by {
      have z_lt: ℕ0 < suc(k') by {
        expand lit
        expand operator< | operator≤.
      }
      apply hyp[ℕ0] to z_lt
    }
    // rest_hyp: hyp shifted by one, for the tails starting at suc(s) / suc(t).
    have rest_hyp: all i:Nat. if i < k' then f(fromNat(suc(s) + i)) = g(fromNat(suc(t) + i)) by {
      arbitrary i:Nat
      assume i_lt: i < k'
      have si_lt: suc(i) < suc(k') by apply less_suc_iff_suc_less to i_lt
      have h: f(fromNat(s + suc(i))) = g(fromNat(t + suc(i))) by apply hyp[suc(i)] to si_lt
      have ssuc: s + suc(i) = suc(s) + i by {
        replace add_suc[s, i]
        expand operator+.
      }
      have tsuc: t + suc(i) = suc(t) + i by {
        replace add_suc[t, i]
        expand operator+.
      }
      replace ssuc | tsuc in h
    }
    have IH_step: int_summation_nat(k', suc(s), f) = int_summation_nat(k', suc(t), g)
      by apply IH[f, g, suc(s), suc(t)] to rest_hyp
    // Unfold one step on both sides, then rewrite head and tail.
    suffices f(fromNat(s)) + int_summation_nat(k', suc(s), f)
           = g(fromNat(t)) + int_summation_nat(k', suc(t), g)
      by expand int_summation_nat.
    replace fs_gt | IH_step.
  }
end

// UInt-indexed congruence: pointwise-equal summands over shifted ranges
// yield equal sums.
// The hypothesis ranges over Nat offsets `i < toNat(k)` and compares
// `f(s + fromNat(i))` with `g(t + fromNat(i))`. The proof reduces to
// `int_summation_nat_cong`.
theorem int_summation_cong:
  all k:UInt. all f:fn UInt -> Int, g:fn UInt -> Int, s:UInt, t:UInt.
  if (all i:Nat. if i < toNat(k) then f(s + fromNat(i)) = g(t + fromNat(i)))
  then int_summation(k, s, f) = int_summation(k, t, g)
proof
  arbitrary k:UInt
  arbitrary f:fn UInt -> Int, g:fn UInt -> Int, s:UInt, t:UInt
  assume hyp
  expand int_summation
  // Discharge the Nat-level hypothesis: rewrite `fromNat(toNat(s) + i)` to
  // `s + fromNat(i)` (and likewise for `t`), then use hyp.
  apply int_summation_nat_cong[toNat(k), f, g, toNat(s), toNat(t)]
  to arbitrary i:Nat
     assume i_lt: i < toNat(k)
     have h: f(s + fromNat(i)) = g(t + fromNat(i)) by apply hyp[i] to i_lt
     replace fromNat_toNat_add_uint[s, i] | fromNat_toNat_add_uint[t, i] | h.
end

// Split a length `a + b` summation into prefix `g` (length `a`) and
// suffix `h` (length `b`), each agreeing pointwise with `f` on its segment.
// Nat-level version:
//   - `g` must match `f` at `s + i` for `i < a`;
//   - `h` at `t + i` must match `f` at `s + a + i` for `i < b`.
// The suffix sum may therefore start at an arbitrary index `t`.
// Proved by induction on the prefix length `a`.
lemma int_summation_nat_add:
  all a:Nat. all b:Nat, s:Nat, t:Nat,
      f:fn UInt -> Int, g:fn UInt -> Int, h:fn UInt -> Int.
  if (all i:Nat. if i < a then g(fromNat(s + i)) = f(fromNat(s + i)))
  and (all i:Nat. if i < b then h(fromNat(t + i)) = f(fromNat(s + a + i)))
  then int_summation_nat(a + b, s, f)
     = int_summation_nat(a, s, g) + int_summation_nat(b, t, h)
proof
  induction Nat
  case zero {
    // Empty prefix: the claim reduces to congruence of the `f` and `h` sums.
    arbitrary b:Nat, s:Nat, t:Nat,
      f:fn UInt -> Int, g:fn UInt -> Int, h:fn UInt -> Int
    assume gf_and_hf
    // hh: the suffix hypothesis at a = zero.
    have hh: all i:Nat. if i < b then h(fromNat(t + i)) = f(fromNat(s + i)) by gf_and_hf
    // Flip each equation to the orientation int_summation_nat_cong expects.
    have sym: all i:Nat. if i < b then f(fromNat(s + i)) = h(fromNat(t + i)) by {
      arbitrary i:Nat
      assume i_lt
      symmetric (apply hh[i] to i_lt)
    }
    // Unfolding int_summation_nat turns the empty prefix sum into `+0`.
    suffices int_summation_nat(b, s, f) = int_summation_nat(b, t, h)
      by expand int_summation_nat.
    apply int_summation_nat_cong[b, f, h, s, t] to sym
  }
  case suc(a') assume IH {
    arbitrary b:Nat, s:Nat, t:Nat,
      f:fn UInt -> Int, g:fn UInt -> Int, h:fn UInt -> Int
    assume gf_and_hf
    have gf: all i:Nat. if i < suc(a') then g(fromNat(s + i)) = f(fromNat(s + i)) by gf_and_hf
    have hf: all i:Nat. if i < b then h(fromNat(t + i)) = f(fromNat(s + suc(a') + i)) by gf_and_hf
    // Split off the first term on both sides.
    have fs_gs: f(fromNat(s)) = g(fromNat(s)) by {
      have z_lt: ℕ0 < suc(a') by {
        expand lit
        expand operator< | operator≤.
      }
      symmetric (apply gf[ℕ0] to z_lt)
    }
    // Recursively split the rest.
    // gf_step: prefix hypothesis shifted to start at suc(s), length a'.
    have gf_step: all i:Nat. if i < a' then g(fromNat(suc(s) + i)) = f(fromNat(suc(s) + i)) by {
      arbitrary i:Nat
      assume i_lt
      have si_lt: suc(i) < suc(a') by apply less_suc_iff_suc_less to i_lt
      have e0: g(fromNat(s + suc(i))) = f(fromNat(s + suc(i))) by apply gf[suc(i)] to si_lt
      have shift: s + suc(i) = suc(s) + i by {
        replace add_suc[s, i]
        expand operator+.
      }
      conclude g(fromNat(suc(s) + i)) = f(fromNat(suc(s) + i)) by replace shift in e0
    }
    // hf_step: suffix hypothesis restated with `s + suc(a')` as `suc(s) + a'`,
    // the form the IH (at start suc(s)) requires.
    have hf_step: all i:Nat. if i < b then h(fromNat(t + i)) = f(fromNat(suc(s) + a' + i)) by {
      arbitrary i:Nat
      assume i_lt
      have e0: h(fromNat(t + i)) = f(fromNat(s + suc(a') + i)) by apply hf[i] to i_lt
      have shift: s + suc(a') = suc(s) + a' by {
        replace add_suc[s, a']
        expand operator+.
      }
      conclude h(fromNat(t + i)) = f(fromNat(suc(s) + a' + i)) by replace shift in e0
    }
    have IH_step:
      int_summation_nat(a' + b, suc(s), f)
        = int_summation_nat(a', suc(s), g) + int_summation_nat(b, t, h)
      by apply IH[b, suc(s), t, f, g, h] to gf_step, hf_step
    // LHS: int_summation_nat(suc(a') + b, s, f).
    //   suc(a') + b = suc(a' + b), so this unfolds to
    //   f(fromNat(s)) + int_summation_nat(a' + b, suc(s), f).
    have lhs_shape: suc(a') + b = suc(a' + b) by expand operator+.
    suffices f(fromNat(s)) + int_summation_nat(a' + b, suc(s), f)
           = (g(fromNat(s)) + int_summation_nat(a', suc(s), g)) + int_summation_nat(b, t, h)
      by {
        replace lhs_shape
        expand int_summation_nat.
      }
    // Rewrite the head with fs_gs and the tail with IH_step; what remains is
    // an instance of associativity.
    replace fs_gs | IH_step
    int_add_assoc[g(fromNat(s)), int_summation_nat(a', suc(s), g), int_summation_nat(b, t, h)]
  }
end

// Split a length `a + b` summation into prefix `g` (length `a`) and
// suffix `h` (length `b`), each agreeing pointwise with `f` on its segment.
// UInt-level version. The hypotheses range over Nat offsets `i < toNat(a)`
// and `i < toNat(b)`. The proof converts them into the Nat-level form
// expected by `int_summation_nat_add` using `fromNat_toNat_add_uint`.
theorem int_summation_add:
  all a:UInt. all b:UInt, s:UInt, t:UInt,
      f:fn UInt -> Int, g:fn UInt -> Int, h:fn UInt -> Int.
  if (all i:Nat. if i < toNat(a) then g(s + fromNat(i)) = f(s + fromNat(i)))
  and (all i:Nat. if i < toNat(b) then h(t + fromNat(i)) = f(s + a + fromNat(i)))
  then int_summation(a + b, s, f)
     = int_summation(a, s, g) + int_summation(b, t, h)
proof
  arbitrary a:UInt
  arbitrary b:UInt, s:UInt, t:UInt,
      f:fn UInt -> Int, g:fn UInt -> Int, h:fn UInt -> Int
  assume gf_and_hf
  expand int_summation
  // toNat(a + b) = toNat(a) + toNat(b), matching int_summation_nat_add's length.
  replace toNat_add[a, b]
  // Prefix hypothesis in Nat-level form.
  have gf_nat: all i:Nat. if i < toNat(a)
        then g(fromNat(toNat(s) + i)) = f(fromNat(toNat(s) + i)) by {
    arbitrary i:Nat
    assume i_lt
    have h0: g(s + fromNat(i)) = f(s + fromNat(i))
      by apply (conjunct 0 of gf_and_hf)[i] to i_lt
    replace fromNat_toNat_add_uint[s, i]
    h0
  }
  // Suffix hypothesis in Nat-level form.
  have hf_nat: all i:Nat. if i < toNat(b)
        then h(fromNat(toNat(t) + i)) = f(fromNat(toNat(s) + toNat(a) + i)) by {
    arbitrary i:Nat
    assume i_lt
    have h0: h(t + fromNat(i)) = f(s + a + fromNat(i))
      by apply (conjunct 1 of gf_and_hf)[i] to i_lt
    replace fromNat_toNat_add_uint[t, i]
    // Fold `toNat(s) + toNat(a)` into `toNat(s + a)`, then convert
    // `fromNat(toNat(s + a) + i)` to `s + a + fromNat(i)`.
    have rhs_eq: f(fromNat(toNat(s) + toNat(a) + i)) = f(s + a + fromNat(i)) by {
      replace symmetric toNat_add[s, a]
      replace fromNat_toNat_add_uint[s + a, i].
    }
    replace rhs_eq
    h0
  }
  apply int_summation_nat_add[toNat(a), toNat(b), toNat(s), toNat(t), f, g, h]
    to gf_nat, hf_nat
end

/*
  Negation and subtraction interaction.

  The genuinely new content over `lib/UIntSum.pf`: distributing a
  unary minus through a sum, and the pointwise-subtraction identity.
*/

// Negating each summand negates the sum (Nat level):
// `Σ (- f(i)) = - Σ f(i)` over any length `k` and start `s`.
// Proved by induction on `k`.
lemma int_summation_nat_neg:
  all k:Nat. all s:Nat, f:fn UInt -> Int.
  int_summation_nat(k, s, fun i:UInt { - f(i) })
    = - int_summation_nat(k, s, f)
proof
  induction Nat
  case zero {
    arbitrary s:Nat, f:fn UInt -> Int
    // Goal after unfolding: `+0 = - +0`.
    expand int_summation_nat
    symmetric neg_zero
  }
  case suc(k') assume IH {
    arbitrary s:Nat, f:fn UInt -> Int
    suffices (- f(fromNat(s))) + int_summation_nat(k', suc(s), fun i:UInt { - f(i) })
           = - (f(fromNat(s)) + int_summation_nat(k', suc(s), f))
      by expand int_summation_nat.
    // Rewrite the negated tail with the IH, then distribute the outer minus.
    replace IH[suc(s), f]
    symmetric neg_distr_add[f(fromNat(s)), int_summation_nat(k', suc(s), f)]
  }
end

// `Σ_i (- f(i)) = - Σ_i f(i)`.
// UInt-level wrapper: unfold `int_summation` and apply `int_summation_nat_neg`
// at the converted length and start.
theorem int_summation_neg:
  all k:UInt, s:UInt, f:fn UInt -> Int.
  int_summation(k, s, fun i:UInt { - f(i) }) = - int_summation(k, s, f)
proof
  arbitrary k:UInt, s:UInt, f:fn UInt -> Int
  expand int_summation
  int_summation_nat_neg[toNat(k), toNat(s), f]
end

// Pointwise sum: `Σ_i (f(i) + g(i)) = Σ_i f(i) + Σ_i g(i)`.
// Nat-level version, proved by induction on the length `k`.
lemma int_summation_nat_add_pointwise:
  all k:Nat. all s:Nat, f:fn UInt -> Int, g:fn UInt -> Int.
  int_summation_nat(k, s, fun i:UInt { f(i) + g(i) })
    = int_summation_nat(k, s, f) + int_summation_nat(k, s, g)
proof
  induction Nat
  case zero {
    arbitrary s:Nat, f:fn UInt -> Int, g:fn UInt -> Int
    expand int_summation_nat.
  }
  case suc(k') assume IH {
    arbitrary s:Nat, f:fn UInt -> Int, g:fn UInt -> Int
    suffices (f(fromNat(s)) + g(fromNat(s)))
              + int_summation_nat(k', suc(s), fun i:UInt { f(i) + g(i) })
           = (f(fromNat(s)) + int_summation_nat(k', suc(s), f))
              + (g(fromNat(s)) + int_summation_nat(k', suc(s), g))
      by expand int_summation_nat.
    // Split the tail with the IH; the remaining goal is the
    // `(a + b) + (c + d) = (a + c) + (b + d)` regrouping.
    replace IH[suc(s), f, g]
    int_add_swap_inner[f(fromNat(s)),
                       g(fromNat(s)),
                       int_summation_nat(k', suc(s), f),
                       int_summation_nat(k', suc(s), g)]
  }
end

// Pointwise sum of summands: `Σ_i (f(i) + g(i)) = Σ_i f(i) + Σ_i g(i)`.
// UInt-level wrapper around `int_summation_nat_add_pointwise`.
theorem int_summation_add_pointwise:
  all k:UInt, s:UInt, f:fn UInt -> Int, g:fn UInt -> Int.
  int_summation(k, s, fun i:UInt { f(i) + g(i) })
    = int_summation(k, s, f) + int_summation(k, s, g)
proof
  arbitrary k:UInt, s:UInt, f:fn UInt -> Int, g:fn UInt -> Int
  expand int_summation
  int_summation_nat_add_pointwise[toNat(k), toNat(s), f, g]
end

// Helper: binary subtraction unfolds to addition of negation. The definition
// `n - m = n + (- m)` is non-opaque, but `expand operator-` is ambiguous with
// unary negation, so this lemma packages the rewrite in a form callers can use.
// Used by `int_summation_sub` in both directions.
lemma int_sub_as_add_neg: all x:Int, y:Int. x - y = x + (- y)
proof
  arbitrary x:Int, y:Int
  // `evaluate` unfolds the subtraction definition on the left-hand side.
  evaluate
end

// Pointwise difference: `Σ_i (f(i) - g(i)) = Σ_i f(i) - Σ_i g(i)`.
// Proof plan:
//   step1: rewrite each summand `f(j) - g(j)` as `f(j) + (- g(j))`;
//   step2: write `- g(j)` as an application of `fun n { - g(n) }`, so the
//          summand has the literal shape required by
//          `int_summation_add_pointwise`;
//   step3: split the pointwise sum;
//   step4: pull the negation out of the second sum;
//   step5: fold `x + (- y)` back into `x - y`.
theorem int_summation_sub:
  all k:UInt, s:UInt, f:fn UInt -> Int, g:fn UInt -> Int.
  int_summation(k, s, fun i:UInt { f(i) - g(i) })
    = int_summation(k, s, f) - int_summation(k, s, g)
proof
  arbitrary k:UInt, s:UInt, f:fn UInt -> Int, g:fn UInt -> Int
  have cong_hyp: all i:Nat. if i < toNat(k)
        then (fun j:UInt { f(j) - g(j) })(s + fromNat(i))
           = (fun j:UInt { f(j) + (- g(j)) })(s + fromNat(i)) by {
    arbitrary i:Nat
    assume _
    show (fun j:UInt { f(j) - g(j) })(s + fromNat(i))
       = (fun j:UInt { f(j) + (- g(j)) })(s + fromNat(i))
    int_sub_as_add_neg[f(s + fromNat(i)), g(s + fromNat(i))]
  }
  have step1: int_summation(k, s, fun j:UInt { f(j) - g(j) })
            = int_summation(k, s, fun j:UInt { f(j) + (- g(j)) })
    by apply int_summation_cong[k, fun j:UInt { f(j) - g(j) },
                                fun j:UInt { f(j) + (- g(j)) }, s, s] to cong_hyp
  // cong_hyp2 closes immediately (the `.` after `assume _`), presumably
  // because both sides agree after beta-reducing the inner lambda.
  have cong_hyp2: all i:Nat. if i < toNat(k)
        then (fun j:UInt { f(j) + (- g(j)) })(s + fromNat(i))
           = (fun j:UInt { f(j) + (fun n:UInt { - g(n) })(j) })(s + fromNat(i)) by {
    arbitrary i:Nat
    assume _.
  }
  have step2: int_summation(k, s, fun j:UInt { f(j) + (- g(j)) })
            = int_summation(k, s, fun j:UInt { f(j) + (fun n:UInt { - g(n) })(j) })
    by apply int_summation_cong[k, fun j:UInt { f(j) + (- g(j)) },
                                fun j:UInt { f(j) + (fun n:UInt { - g(n) })(j) }, s, s] to cong_hyp2
  have step3: int_summation(k, s, fun j:UInt { f(j) + (fun n:UInt { - g(n) })(j) })
            = int_summation(k, s, f) + int_summation(k, s, fun n:UInt { - g(n) })
    by int_summation_add_pointwise[k, s, f, fun n:UInt { - g(n) }]
  have step4: int_summation(k, s, fun n:UInt { - g(n) }) = - int_summation(k, s, g)
    by int_summation_neg[k, s, g]
  have step5: int_summation(k, s, f) + (- int_summation(k, s, g))
            = int_summation(k, s, f) - int_summation(k, s, g) by {
    symmetric int_sub_as_add_neg[int_summation(k, s, f), int_summation(k, s, g)]
  }
  equations
        int_summation(k, s, fun j:UInt { f(j) - g(j) })
      = int_summation(k, s, fun j:UInt { f(j) + (- g(j)) })  by step1
  ... = int_summation(k, s, fun j:UInt { f(j) + (fun n:UInt { - g(n) })(j) })  by step2
  ... = int_summation(k, s, f) + int_summation(k, s, fun n:UInt { - g(n) })   by step3
  ... = int_summation(k, s, f) + (- int_summation(k, s, g))  by replace step4.
  ... = int_summation(k, s, f) - int_summation(k, s, g)      by step5
end

// Constant-zero summand: a sum of zeros is zero, for any length and start.
theorem int_summation_const_zero:
  all k:UInt, s:UInt.
  int_summation(k, s, fun i:UInt { +0 }) = +0
proof
  arbitrary k:UInt, s:UInt
  // Reduces to a Nat-level claim by induction on toNat(k).
  // `inner` is proved for every length n and start `begin`, so the IH can be
  // used at the shifted start suc(begin).
  have inner: all n:Nat. all begin:Nat.
              int_summation_nat(n, begin, fun i:UInt { +0 }) = +0 by {
    induction Nat
    case zero {
      arbitrary begin:Nat
      expand int_summation_nat.
    }
    case suc(n') assume IH {
      arbitrary begin:Nat
      // One unfolding exposes the head term `+0`; the IH zeroes the tail.
      suffices +0 + int_summation_nat(n', suc(begin), fun i:UInt { +0 }) = +0
        by expand int_summation_nat.
      replace IH[suc(begin)].
    }
  }
  expand int_summation
  inner[toNat(k), toNat(s)]
end