module Int

import Nat
import UInt
import Base

import IntDefs
import IntAddSub
import IntMult
import IntLess
import IntAbs
import IntEvenOdd

/*
  Theorems about Int exponentiation.

  Integer powers take a UInt exponent. The operator is defined by
  recursion on the Nat view of the exponent (so that `induction Nat`
  proofs can run on the underlying Nat structure). The bridge
  `int_pow_pos: pos(m)^k = pos(m^k)` lets nonnegative-base reasoning
  reuse the UInt power library. Negative bases factor through the
  parity of the exponent: `(-n)^k = n^k` when `k` is even, and
  `(-n)^k = -(n^k)` when `k` is odd.
*/

// Worker behind `^`: expt_nat(p, a) is the product of p copies of `a`,
// defined by structural recursion on the Nat exponent p:
//   expt_nat(0, a)      = +1
//   expt_nat(suc(p), a) = a * expt_nat(p, a)
// Note the argument order: exponent first, base second.
// Declared private; code outside this module should use `^` and its lemmas.
private recursive expt_nat(Nat, Int) -> Int {
  expt_nat(0, a) = +1
  expt_nat(suc(p), a) = a * expt_nat(p, a)
}

// Integer power with a UInt exponent: a ^ b converts b to Nat and
// delegates to expt_nat (exponent first, base second).
// Declared opaque, so the proofs below unfold it explicitly with
// `expand operator^`.
opaque fun operator ^(a : Int, b : UInt) {
  expt_nat(toNat(b), a)
}

// Reduction lemmas for the private worker expt_nat, together with the
// Nat values (via toNat) of the UInt literals 0 and 1.
// expt_nat at exponent zero is +1, which is the first defining equation
// of expt_nat. Closed by evaluation.
lemma expt_nat_zero: all n:Int. expt_nat(zero, n) = +1
proof
  arbitrary n:Int
  evaluate
end

// The UInt literal 0 converts to Nat zero.
// Proof: instantiate uint_toNat_fromNat at ℕ0. The goal is closed
// directly by that fact about fromNat(ℕ0), so the checker must identify
// (0:UInt) with fromNat(ℕ0). NOTE(review): that identification is not
// shown in this file.
lemma int_toNat_zero: toNat((0:UInt)) = zero
proof
  have h: toNat(fromNat(ℕ0)) = zero by uint_toNat_fromNat[ℕ0]
  h
end

// The UInt literal 1 converts to Nat ℕ1.
// Proof: instantiate uint_toNat_fromNat at ℕ1. As with int_toNat_zero,
// this relies on the checker identifying (1:UInt) with fromNat(ℕ1).
lemma int_toNat_one: toNat((1:UInt)) = ℕ1
proof
  have h: toNat(fromNat(ℕ1)) = ℕ1 by uint_toNat_fromNat[ℕ1]
  h
end

// One-step unfolding of expt_nat at a successor exponent, which is the
// second defining equation of expt_nat. Closed by evaluation.
lemma expt_nat_suc: all p:Nat, n:Int.
  expt_nat(suc(p), n) = n * expt_nat(p, n)
proof
  arbitrary p:Nat, n:Int
  evaluate
end

/*
  Unfolding `^` at UInt zero and at `1 + k`. These are the workhorse
  rules at the `n^k` level. Apart from the `expt_nat_*` lemmas for
  exponent arithmetic in the next section, the rest of the file uses
  these rules and avoids `expt_nat` directly.
*/

// n ^ 0 = +1 for every integer n, including n = +0.
theorem int_pow_zero: all n:Int. n ^ (0:UInt) = +1
proof
  arbitrary n:Int
  // Unfold ^ to expt_nat(toNat(0), n), then reduce toNat(0) to zero.
  expand operator^
  replace int_toNat_zero
  expt_nat_zero[n]
end

// Register n ^ 0 = +1 as an automatic rewrite for the rest of the file.
auto int_pow_zero

// Peels one factor off a power: n ^ (1 + k) = n * n ^ k.
// This is the step rule used by the uint_induction proofs later in the
// file (int_pow_pos, int_abs_pow).
theorem int_pow_one_add: all n:Int, k:UInt. n ^ (1 + k) = n * n ^ k
proof
  arbitrary n:Int, k:UInt
  expand operator^
  // toNat maps 1 + k to suc(toNat(k)), which matches the recursive
  // equation of expt_nat.
  have nat_eq: toNat((1:UInt) + k) = suc(toNat(k)) by {
    replace toNat_add | int_toNat_one
    show ℕ1 + toNat(k) = suc(toNat(k))
    evaluate
  }
  replace nat_eq
  expt_nat_suc[toNat(k), n]
end

/*
  Addition in the exponent, multiplication in the base, and
  multiplication in the exponent. Each law is proven on the Nat
  exponent via `expt_nat`, then transported to `^`.
*/

// Exponent addition for the Nat worker:
//   expt_nat(j + k, n) = expt_nat(j, n) * expt_nat(k, n).
// Proof by induction on j. The hypothesis IH remains quantified over
// k and n, and is instantiated as IH[k, n] in the successor case.
lemma expt_nat_add: all j:Nat, k:Nat, n:Int.
  expt_nat(j + k, n) = expt_nat(j, n) * expt_nat(k, n)
proof
  induction Nat
  case zero {
    arbitrary k:Nat, n:Int
    show expt_nat(zero + k, n) = expt_nat(zero, n) * expt_nat(k, n)
    // expt_nat(zero, n) becomes +1; the trailing `.` discharges what remains.
    replace expt_nat_zero.
  }
  case suc(j') assume IH {
    arbitrary k:Nat, n:Int
    // Move suc to the outside so that expt_nat_suc can fire on the left.
    have step: suc(j') + k = suc(j' + k) by expand operator+.
    replace step | expt_nat_suc | IH[k, n].
  }
end

// Addition in the exponent: n ^ (j + k) = n ^ j * n ^ k.
// Unfold ^, push toNat through the sum (toNat_add), then apply the Nat
// worker lemma expt_nat_add.
theorem int_pow_add: all n:Int, j:UInt, k:UInt.
  n ^ (j + k) = n ^ j * n ^ k
proof
  arbitrary n:Int, j:UInt, k:UInt
  expand operator^
  replace toNat_add
  expt_nat_add[toNat(j), toNat(k), n]
end

// Power of a product for the Nat worker:
//   expt_nat(k, a * b) = expt_nat(k, a) * expt_nat(k, b).
// Proof by induction on k. The successor case reduces to regrouping
// a * b * (A * B) as a * A * (b * B), with A = expt_nat(k', a) and
// B = expt_nat(k', b).
lemma expt_nat_mul_base: all k:Nat, a:Int, b:Int.
  expt_nat(k, a * b) = expt_nat(k, a) * expt_nat(k, b)
proof
  induction Nat
  case zero {
    arbitrary a:Int, b:Int
    replace expt_nat_zero.
  }
  case suc(k') assume IH {
    arbitrary a:Int, b:Int
    replace expt_nat_suc | IH[a, b]
    show a * b * (expt_nat(k', a) * expt_nat(k', b))
       = a * expt_nat(k', a) * (b * expt_nat(k', b))
    // Swapping b with expt_nat(k', a) makes the two products match;
    // any remaining regrouping is left to the trailing `.`.
    replace int_mult_commute[b, expt_nat(k', a)].
  }
end

// Multiplication in the base: (a * b) ^ k = a ^ k * b ^ k.
// Unfolding ^ turns this directly into expt_nat_mul_base at toNat(k).
theorem int_pow_mul_base: all k:UInt, a:Int, b:Int.
  (a * b) ^ k = a ^ k * b ^ k
proof
  arbitrary k:UInt, a:Int, b:Int
  expand operator^
  expt_nat_mul_base[toNat(k), a, b]
end

// Nested powers for the Nat worker:
//   expt_nat(j * k, n) = expt_nat(j, expt_nat(k, n)).
// Proof by induction on j. Rewrite suc(j') * k as k + j' * k, split the
// exponent with expt_nat_add, apply IH to the j' * k part, and unfold
// the right-hand side once with expt_nat_suc.
lemma expt_nat_mul_exp: all j:Nat, k:Nat, n:Int.
  expt_nat(j * k, n) = expt_nat(j, expt_nat(k, n))
proof
  induction Nat
  case zero {
    arbitrary k:Nat, n:Int
    replace expt_nat_zero.
  }
  case suc(j') assume IH {
    arbitrary k:Nat, n:Int
    have step: suc(j') * k = k + j' * k by expand operator*.
    replace step | expt_nat_add | IH[k, n] | expt_nat_suc.
  }
end

// Multiplication in the exponent: n ^ (j * k) = (n ^ k) ^ j.
// Note the order: the inner exponent is k (the second factor) and the
// outer exponent is j (the first factor).
theorem int_pow_mul_exp: all n:Int, j:UInt, k:UInt.
  n ^ (j * k) = (n ^ k) ^ j
proof
  arbitrary n:Int, j:UInt, k:UInt
  expand operator^
  replace toNat_mult
  expt_nat_mul_exp[toNat(j), toNat(k), n]
end

/*
  Basic literal-form laws.
*/

// n ^ 1 = n.
// The instance int_pow_one_add[n, 0] states n ^ (1 + 0) = n * n ^ 0.
// Matching it against the goal relies on automatic simplification:
// 1 + 0, n ^ 0 = +1 via the automatic int_pow_zero, and n * +1.
// NOTE(review): this depends on the checker's normalization.
theorem int_pow_one: all n:Int. n ^ (1:UInt) = n
proof
  arbitrary n:Int
  int_pow_one_add[n, 0]
end

// Register n ^ 1 = n as an automatic rewrite.
auto int_pow_one

// n ^ 2 = n * n.
// The instance int_pow_one_add[n, 1] states n ^ (1 + 1) = n * n ^ 1.
// The automatic int_pow_one collapses n ^ 1; the checker is presumably
// expected to evaluate 1 + 1 to 2 (TODO confirm).
theorem int_pow_two: all n:Int. n ^ (2:UInt) = n * n
proof
  arbitrary n:Int
  int_pow_one_add[n, 1]
end

/*
  Bridge to UInt powers when the base is nonnegative. Proven by
  `uint_induction` over the exponent in 0 / 1+k' form.
*/

// Bridge to UInt powers: for a nonnegative base pos(m),
// pos(m) ^ k = pos(m ^ k).
// Proof by uint_induction on k, using the predicate P(x) below.
theorem int_pow_pos: all m:UInt, k:UInt. pos(m) ^ k = pos(m ^ k)
proof
  arbitrary m:UInt
  define P = fun x:UInt { pos(m) ^ x = pos(m ^ x) }
  // Base case: after unfolding P, both sides reduce automatically
  // (int_pow_zero on the left; presumably a UInt x^0 rule on the right).
  have base_case: P(0) by {
    expand P.
  }
  // Inductive step: peel one factor off each side and combine the
  // factors with mult_pos_pos.
  have ind: all x:UInt. if P(x) then P(1 + x) by {
    arbitrary x:UInt
    assume IH
    have IH_e: pos(m) ^ x = pos(m ^ x) by expand P in IH
    expand P
    have unfold_int: pos(m) ^ (1 + x) = pos(m) * pos(m) ^ x
      by int_pow_one_add[pos(m), x]
    have unfold_uint: m ^ (1 + x) = m * m ^ x by {
      replace uint_pow_add_r[m, 1, x].
    }
    replace unfold_int | IH_e | unfold_uint
    show pos(m) * pos(m ^ x) = pos(m * m ^ x)
    replace mult_pos_pos.
  }
  arbitrary k:UInt
  // Combine the base case and step with uint_induction, then unfold P.
  expand P in apply uint_induction[P] to base_case, ind
end

// Register the bridge as an automatic rewrite: any pos(m) ^ k is turned
// into pos(m ^ k).
auto int_pow_pos

// (+1) ^ k = +1 for every exponent k.
// The proof consists only of `arbitrary k:UInt.`, so it relies entirely
// on automatic rewriting. This is presumably the automatic int_pow_pos
// (turning (+1)^k into pos(1 ^ k)) plus a UInt rule for 1 ^ k
// (TODO confirm).
theorem int_one_pow: all k:UInt. (+1) ^ k = +1
proof
  arbitrary k:UInt.
end

// Register (+1) ^ k = +1 as an automatic rewrite.
auto int_one_pow

// (+0) ^ k = +0 for a positive exponent k. The hypothesis 0 < k is
// needed because, for k = 0, int_pow_zero gives +1 instead.
theorem int_zero_pow: all k:UInt. if 0 < k then (+0) ^ k = +0
proof
  arbitrary k:UInt
  assume k_pos
  show pos(0) ^ k = +0
  // uint_zero_pow applied to 0 < k presumably yields 0 ^ k = 0 on the
  // UInt side. That side is reached through the automatic int_pow_pos
  // (pos(0) ^ k becomes pos(0 ^ k)) (TODO confirm).
  replace apply uint_zero_pow to k_pos.
end

/*
  Absolute value distributes over `^`. The bridge for magnitude-only
  facts.
*/

// Absolute value commutes with powers: abs(n ^ k) = abs(n) ^ k, where
// the right-hand side is a UInt power.
// Proof by uint_induction on k, using the predicate Q(x) below.
theorem int_abs_pow: all n:Int, k:UInt. abs(n ^ k) = abs(n) ^ k
proof
  arbitrary n:Int
  define Q = fun x:UInt { abs(n ^ x) = abs(n) ^ x }
  have base_case: Q(0) by {
    expand Q.
  }
  // Inductive step: peel one factor off each side, push abs through the
  // product (int_abs_mult), and finish with IH.
  have ind: all x:UInt. if Q(x) then Q(1 + x) by {
    arbitrary x:UInt
    assume IH
    have IH_e: abs(n ^ x) = abs(n) ^ x by expand Q in IH
    expand Q
    have unfold_int: n ^ (1 + x) = n * n ^ x by int_pow_one_add[n, x]
    have unfold_uint: abs(n) ^ (1 + x) = abs(n) * abs(n) ^ x by {
      replace uint_pow_add_r[abs(n), 1, x].
    }
    replace unfold_int | unfold_uint
    show abs(n * n ^ x) = abs(n) * abs(n) ^ x
    replace int_abs_mult
    show abs(n) * abs(n ^ x) = abs(n) * abs(n) ^ x
    replace IH_e.
  }
  arbitrary k:UInt
  // Combine the base case and step with uint_induction, then unfold Q.
  expand Q in apply uint_induction[Q] to base_case, ind
end

/*
  Zero/one characterizations.
*/

// A power is +0 only if its base is +0. For k = 0 the premise cannot
// hold, since n ^ 0 = +1.
// Proof: take abs of both sides, move the question to UInt with
// int_abs_pow, use uint_pow_eq_zero there, and return to Int with
// int_abs_eq_zero_implies_zero.
theorem int_pow_eq_zero: all n:Int, k:UInt.
  if n ^ k = +0 then n = +0
proof
  arbitrary n:Int, k:UInt
  assume prem: n ^ k = +0
  have abs_eq: abs(n) ^ k = 0 by {
    have h: abs(n ^ k) = abs((+0):Int) by replace prem.
    replace int_abs_pow in h
  }
  have abs_zero: abs(n) = 0 by apply uint_pow_eq_zero[abs(n), k] to abs_eq
  apply int_abs_eq_zero_implies_zero to abs_zero
end

// Contrapositive of int_pow_eq_zero: a nonzero base has a nonzero power.
theorem int_pow_nonzero: all n:Int, k:UInt.
  if not (n = +0) then not (n ^ k = +0)
proof
  arbitrary n:Int, k:UInt
  assume n_ne_zero: not (n = +0)
  // Suppose the power is zero; int_pow_eq_zero then forces n = +0,
  // which contradicts n_ne_zero.
  assume pow_is_zero: n ^ k = +0
  apply n_ne_zero to (apply int_pow_eq_zero[n, k] to pow_is_zero)
end

/*
  Nonnegativity of powers of nonnegative bases, and positivity of
  powers of positive bases.
*/

// A nonnegative base has a nonnegative power: if +0 ≤ n then +0 ≤ n ^ k,
// for every exponent k (including 0, where n ^ 0 = +1).
// Proof: rewrite n as pos(abs(n)). The automatic rule int_pow_pos then
// turns the power into pos(abs(n) ^ k), and the trailing `.` is
// presumably able to discharge +0 ≤ pos(...) directly.
theorem int_pow_nonneg: all n:Int, k:UInt.
  if +0 ≤ n then +0 ≤ n ^ k
proof
  arbitrary n:Int, k:UInt
  assume n_nn: +0 ≤ n
  // For nonnegative n, n equals the embedding of its absolute value.
  have abs_eq: pos(abs(n)) = n by
    apply (conjunct 1 of int_pos_abs_eq_iff_nonneg[n]) to n_nn
  replace symmetric abs_eq.
end

// A positive base has a positive power: if +0 < n then +0 < n ^ k.
// Proof by case analysis on the constructor of n. For pos(n'), the
// result follows from uint_pow_pos on the UInt side. For negsuc(n'),
// the hypothesis +0 < negsuc(n') is contradictory.
theorem int_pow_pos_of_pos: all n:Int, k:UInt.
  if +0 < n then +0 < n ^ k
proof
  arbitrary n:Int, k:UInt
  assume n_pos: +0 < n
  switch n {
    case pos(n') assume n_eq: n = pos(n') {
      have h: +0 < pos(n') by replace n_eq in n_pos
      // < between pos values presumably reduces to < on UInt.
      have n'_pos: 0 < n' by h
      have inner: 0 < n' ^ k by apply uint_pow_pos[n', k] to n'_pos
      have done: pos(0) < pos(n' ^ k) by inner
      done
    }
    case negsuc(n') assume n_eq: n = negsuc(n') {
      have h: +0 < negsuc(n') by replace n_eq in n_pos
      conclude false by h
    }
  }
end

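/*
  Sign / parity behavior for negative bases.

  Even exponent `2*m`: `int_pow_mul_exp` and `int_pow_two` rewrite
  `(-n)^(2*m)` as `((-n) * (-n))^m`. The identity `(-n) * (-n) = n * n`
  (`int_neg_mult_neg`) follows from `dist_neg_mult` applied twice and
  then `neg_involutive`. Odd exponent `1 + 2*m`: `int_pow_one_add` peels
  off one factor of `-n`, and the even case handles the rest. Both
  results are then lifted to arbitrary exponents through the UInt
  `Even`/`Odd` predicates, which unfold to `k = 2 * m` and
  `k = 1 + 2 * m` respectively.
*/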

// The square of a negation equals the square: (- n) * (- n) = n * n.
// Pull each negation out with dist_neg_mult (once per factor, commuting
// in between), then cancel the double negation with neg_involutive.
lemma int_neg_mult_neg: all n:Int. (- n) * (- n) = n * n
proof
  arbitrary n:Int
  have s1: (- n) * (- n) = - (n * (- n)) by symmetric dist_neg_mult[n, - n]
  have s2: n * (- n) = (- n) * n by int_mult_commute[n, - n]
  have s3: (- n) * n = - (n * n) by symmetric dist_neg_mult[n, n]
  equations
      (- n) * (- n)
        = - (n * (- n))  by s1
    ... = - ((- n) * n)  by replace s2.
    ... = - (- (n * n))  by replace s3.
    ... = n * n          by replace neg_involutive.
end

// Even exponents erase the sign of the base: (- n) ^ (2 * m) = n ^ (2 * m).
// Both sides are rewritten as the m-th power of a square, and the two
// squares agree by int_neg_mult_neg.
theorem int_pow_neg_base_even: all m:UInt, n:Int.
  (- n) ^ (2 * m) = n ^ (2 * m)
proof
  arbitrary m:UInt, n:Int
  // Rewrite the exponent as m * 2 so that int_pow_mul_exp splits off a square.
  have swap: 2 * m = m * 2 by uint_mult_commute[2, m]
  have neg_sq: (- n) ^ (2 * m) = ((- n) ^ (2:UInt)) ^ m by {
    replace swap
    int_pow_mul_exp[- n, m, 2]
  }
  have plain_sq: n ^ (2 * m) = (n ^ (2:UInt)) ^ m by {
    replace swap
    int_pow_mul_exp[n, m, 2]
  }
  equations
      (- n) ^ (2 * m)
        = ((- n) ^ (2:UInt)) ^ m    by neg_sq
    ... = ((- n) * (- n)) ^ m       by replace int_pow_two.
    ... = (n * n) ^ m               by replace int_neg_mult_neg.
    ... = (n ^ (2:UInt)) ^ m        by replace int_pow_two.
    ... = n ^ (2 * m)               by symmetric plain_sq
end

// Odd exponents keep the sign of the base:
//   (- n) ^ (1 + 2 * m) = - (n ^ (1 + 2 * m)).
// Peel off one factor of (- n), use the even case on the remaining
// (- n) ^ (2 * m), pull the negation outward with dist_neg_mult, and
// fold the factor back in with int_pow_one_add.
theorem int_pow_neg_base_odd: all m:UInt, n:Int.
  (- n) ^ (1 + 2 * m) = - (n ^ (1 + 2 * m))
proof
  arbitrary m:UInt, n:Int
  have a: (- n) ^ (1 + 2 * m) = (- n) * (- n) ^ (2 * m)
    by int_pow_one_add[- n, 2 * m]
  have b: (- n) ^ (2 * m) = n ^ (2 * m) by int_pow_neg_base_even[m, n]
  have c: (- n) ^ (1 + 2 * m) = (- n) * n ^ (2 * m) by replace b in a
  have d: (- n) * n ^ (2 * m) = - (n * n ^ (2 * m))
    by symmetric dist_neg_mult[n, n ^ (2 * m)]
  have e: n * n ^ (2 * m) = n ^ (1 + 2 * m)
    by symmetric int_pow_one_add[n, 2 * m]
  equations
      (- n) ^ (1 + 2 * m)
        = (- n) * n ^ (2 * m)    by c
    ... = - (n * n ^ (2 * m))    by d
    ... = - (n ^ (1 + 2 * m))    by replace e.
end

// For any even exponent k, (- n) ^ k = n ^ k.
// Unpack the Even witness and reuse the 2 * m form of the lemma.
theorem int_pow_neg_base_when_even: all n:Int, k:UInt.
  if Even(k) then (- n) ^ k = n ^ k
proof
  arbitrary n:Int, k:UInt
  assume k_even: Even(k)
  // Even(k) unfolds to the existence of a half with k = 2 * half.
  obtain half where k_def: k = 2 * half from expand Even in k_even
  replace k_def
  int_pow_neg_base_even[half, n]
end

// For any odd exponent k, (- n) ^ k = - (n ^ k).
// Unpack the Odd witness and reuse the 1 + 2 * m form of the lemma.
theorem int_pow_neg_base_when_odd: all n:Int, k:UInt.
  if Odd(k) then (- n) ^ k = - (n ^ k)
proof
  arbitrary n:Int, k:UInt
  assume k_odd: Odd(k)
  // Odd(k) unfolds to the existence of a half with k = 1 + 2 * half.
  obtain half where k_def: k = 1 + 2 * half from expand Odd in k_odd
  replace k_def
  int_pow_neg_base_odd[half, n]
end

/*
  `(-1)^k` is `+1` for even `k` and `-1` for odd `k`.
*/

// (-1) raised to an even power is +1.
theorem int_neg_one_pow_even: all k:UInt.
  if Even(k) then (- +1) ^ k = +1
proof
  arbitrary k:UInt
  assume k_even: Even(k)
  // Specialize the even-exponent law to n = +1. The resulting (+1) ^ k
  // is presumably collapsed to +1 by the automatic rule int_one_pow.
  apply int_pow_neg_base_when_even[+1, k] to k_even
end

// (-1) raised to an odd power is -1.
theorem int_neg_one_pow_odd: all k:UInt.
  if Odd(k) then (- +1) ^ k = - +1
proof
  arbitrary k:UInt
  assume k_odd: Odd(k)
  // Specialize the odd-exponent law to n = +1. The inner (+1) ^ k is
  // presumably collapsed to +1 by the automatic rule int_one_pow.
  apply int_pow_neg_base_when_odd[+1, k] to k_odd
end