module UInt

import Nat
import UIntDefs
import UIntToFrom
import UIntAdd
import UIntMult
import UIntMonus
import UIntLess

/*
  Unfolding lemma for an exponent built with dub_inc:
    n ^ dub_inc(p) = sqr(n * n^p).
  The proof just unfolds operator^ and expt.
*/
lemma expt_dub_inc: all n:UInt, p:UInt. n ^ dub_inc(p) = sqr(n * (n^p))
proof
  arbitrary n:UInt, p:UInt
  expand operator^ | expt.
end

/*
  Unfolding lemma for an exponent built with inc_dub:
    n ^ inc_dub(p) = n * sqr(n^p).
  The proof just unfolds operator^ and expt.
*/
lemma expt_inc_dub: all n:UInt, p:UInt. n ^ inc_dub(p) = n * sqr(n^p)
proof
  arbitrary n:UInt, p:UInt
  expand operator^ | expt.
end

/*
  toNat commutes with exponentiation:
    toNat(n^p) = toNat(n) ^ toNat(p).
  Proof by induction on the exponent p over the UInt constructors
  (bzero, dub_inc, inc_dub). Each inductive case does the following:
  - unfolds the power with expt_dub_inc / expt_inc_dub;
  - pushes toNat through the products (toNat_mult);
  - uses the induction hypothesis;
  - rearranges the resulting Nat powers with pow_add_r, pow_mul_r and pow_mul_l.
*/
theorem toNat_expt: all p:UInt, n:UInt.
  toNat(n^p) = toNat(n) ^ toNat(p)
proof
  induction UInt
  case bzero {
    /* base case: exponent bzero, closed by evaluation */
    arbitrary n:UInt
    evaluate
  }
  case dub_inc(p') assume IH {
    arbitrary n:UInt
    /* n ^ dub_inc(p') becomes sqr(n * n^p') */
    replace expt_dub_inc
    expand toNat
    replace pow_add_r
    expand sqr
    /* move toNat inside every product, then apply the induction hypothesis */
    replace toNat_mult | toNat_mult | toNat_mult | IH 
    /* short names for the Nat base and exponent in the remaining rewrites */
    define x = toNat(n)
    define p = toNat(p')
    replace symmetric pow_mul_r[p, x, ℕ2] | pow_mul_l
    /* the two sides now differ only in the order of one product */
    replace mult_commute[x ^ p, x].
  }
  case inc_dub(x') assume IH {
    arbitrary n:UInt
    /* n ^ inc_dub(x') becomes n * sqr(n^x') */
    replace expt_inc_dub
    expand toNat
    replace nat_suc_one_add | pow_add_r | toNat_mult
    expand sqr
    /* apply the induction hypothesis, then fold/unfold Nat powers to match */
    replace toNat_mult | IH | symmetric pow_mul_r[toNat(x'), toNat(n), ℕ2]
    replace pow_mul_l.
  }
end

/*
  fromNat commutes with exponentiation:
    fromNat(x^y) = fromNat(x) ^ fromNat(y).
  Because toNat is injective, it is enough to compare both sides after
  mapping them to Nat. There toNat_expt and uint_toNat_fromNat make them
  identical.
*/
theorem fromNat_expt: all x:Nat, y:Nat.
  fromNat(x^y) = fromNat(x)^fromNat(y)
proof
  arbitrary x:Nat, y:Nat
  suffices toNat(fromNat(x^y)) = toNat(fromNat(x)^fromNat(y))
    by uint_toNat_injective
  replace toNat_expt | uint_toNat_fromNat.
end

/*
  Literal form of fromNat_expt, read right-to-left: a power of two UInt
  literals is rewritten into the literal of the corresponding Nat power.
  Marked `auto` below so that the checker applies it automatically.
*/
theorem lit_expt_fromNat: all x:Nat, y:Nat.
  fromNat(lit(x)) ^ fromNat(lit(y)) = fromNat(lit(x) ^ lit(y))
proof
  arbitrary x:Nat, y:Nat
  symmetric fromNat_expt
end

auto lit_expt_fromNat

/*
  Any UInt raised to the power 0 is 1.
  The proof maps both sides to Nat (toNat is injective) and simplifies
  there. Marked `auto` below.
*/
theorem uint_expt_zero: all n:UInt.
  n ^ 0 = 1
proof
  arbitrary n:UInt
  suffices toNat(n^0) = toNat(1)   by uint_toNat_injective
  equations
      toNat(n^0) = toNat(n)^toNat(0)  by toNat_expt
             /* uint_toNat_fromNat turns toNat(0) into a Nat literal,
                and the Nat power then simplifies to 1 */
             ... = ℕ1 by replace uint_toNat_fromNat.
             ... = # toNat(1) #           by replace uint_toNat_fromNat.
end

auto uint_expt_zero
  
/*
  Postulated (assumed without proof in this file): a power whose exponent
  is a successor literal unfolds by one factor,
    n ^ (m + 1) = n * n ^ m   (m a Nat literal).
  Marked `auto`.
  NOTE(review): this is an unproved assumption; consider deriving it from
  toNat_expt like the other power laws here.
*/
postulate uint_expt_suc: all n:UInt, m:Nat. n ^ fromNat(lit(suc(m))) = n * n ^ fromNat(lit(m))
auto uint_expt_suc

/*
  n ^ 1 = n.
  After reducing to Nat via toNat injectivity, the goal is closed by `.`.
  That step relies on the automatic rewrite rules in scope (presumably
  uint_expt_suc and uint_expt_zero declared above — confirm).
*/
theorem uint_expt_one: all n:UInt.
  n ^ 1 = n
proof
  arbitrary n:UInt
  suffices toNat(n^1) = toNat(n)   by uint_toNat_injective
  .
end

/*
  n ^ 2 = n * n.
  After reducing to Nat via toNat injectivity, the goal is closed by `.`.
  That step relies on the automatic rewrite rules in scope (presumably
  uint_expt_suc and uint_expt_zero declared above — confirm).
*/
theorem uint_expt_two: all n:UInt.
  n ^ 2 = n * n
proof
  arbitrary n:UInt
  suffices toNat(n^2) = toNat(n * n) by uint_toNat_injective
  .
end

/*
  1 raised to any UInt power is 1.
  The proof works in Nat: toNat(1^n) = 1 ^ toNat(n) = 1.
*/
theorem uint_one_expt: all n:UInt.
  1 ^ n = 1
proof
  arbitrary n:UInt
  suffices toNat(1^n) = toNat(1) by uint_toNat_injective
  equations
      toNat(1^n) = toNat(1)^toNat(n)  by toNat_expt
             ... = ℕ1^toNat(n)        by { replace uint_toNat_fromNat. }
             ... = ℕ1                 by .
             ... = #toNat(1)#         by { replace uint_toNat_fromNat. }
end

/*
  Exponent addition law for UInt:
    m^(n + o) = m^n * m^o.
  The proof moves to Nat with toNat (toNat_expt, toNat_add), uses Nat's
  pow_add_r, and moves back with toNat_expt and toNat_mult.
*/
theorem uint_pow_add_r : all m:UInt, n:UInt, o:UInt.
  m^(n + o) = m^n * m^o
proof
  arbitrary m:UInt, n:UInt, o:UInt
  suffices toNat(m^(n + o)) = toNat(m^n * m^o)  by uint_toNat_injective
  equations
      toNat(m^(n + o))
        = toNat(m) ^ toNat(n + o)           by toNat_expt
    ... = toNat(m) ^ (toNat(n) + toNat(o))  by replace toNat_add.
    ... = toNat(m) ^ toNat(n) * toNat(m) ^ toNat(o)  by replace pow_add_r.
    ... = #toNat(m^n) * toNat(m^o)#         by replace toNat_expt.
    ... = #toNat(m^n * m^o)#                by replace toNat_mult.
end

/*
  A power of a product is the product of the powers:
    (m * n)^o = m^o * n^o.
  The proof moves to Nat with toNat, uses Nat's pow_mul_l, and moves back.
*/
theorem uint_pow_mul_l : all m:UInt, n:UInt, o:UInt.
  (m * n)^o = m^o * n^o
proof
  arbitrary m:UInt, n:UInt, o:UInt
  suffices toNat((m * n)^o) = toNat(m^o * n^o)  by uint_toNat_injective
  equations
      toNat((m * n)^o)
        = toNat(m*n)^toNat(o)                   by replace toNat_expt.
    ... = (toNat(m)*toNat(n))^toNat(o)          by replace toNat_mult.
    ... = toNat(m)^toNat(o) * toNat(n)^toNat(o) by replace pow_mul_l.
    ... = #toNat(m^o) * toNat(n^o)#             by replace toNat_expt.
    ... = #toNat(m^o * n^o)#                    by replace toNat_mult.
end

/*
  Power of a power:
    (m ^ n) ^ o = m ^ (n * o).
  The proof moves to Nat with toNat, uses Nat's pow_mul_r, and moves back.
*/
theorem uint_pow_mul_r : all m : UInt, n : UInt, o : UInt.
  (m ^ n) ^ o = m ^ (n * o)
proof
  arbitrary m:UInt, n:UInt, o:UInt
  suffices toNat((m ^ n) ^ o) = toNat(m ^ (n * o))  by uint_toNat_injective
  equations
      toNat((m ^ n) ^ o)
        = (toNat(m) ^ toNat(n)) ^ toNat(o)   by replace toNat_expt | toNat_expt.
    ... = toNat(m) ^ (toNat(n) * toNat(o))   by pow_mul_r
    ... = #toNat(m ^ (n * o))#               by replace toNat_expt | toNat_mult.
end

/*
  Postulated (assumed without proof in this file): a literal base raised
  to (literal * o) is rewritten as (literal base ^ literal) ^ o.
  This is uint_pow_mul_r read right-to-left, restricted to literals.
  Marked `auto`.
  NOTE(review): looks derivable from uint_pow_mul_r; consider proving it
  instead of postulating.
*/
postulate lit_pow_mul_r: all m:Nat, n:Nat, o:UInt.
  fromNat(lit(m)) ^ (fromNat(lit(n)) * o) = (fromNat(lit(m)) ^ fromNat(lit(n))) ^ o
auto lit_pow_mul_r

/*
  A power of two bounded by inc(n):
    2 ^ cnt_dubs(n) ≤ inc(n).
  Proof by induction on n. In both inductive cases cnt_dubs grows by one,
  so the left-hand side becomes 2 * 2 ^ cnt_dubs(n'). The proof then:
  - scales the induction hypothesis by 2 (uint_mult_mono_le);
  - shows that 2 * (1 + n') fits under the new right-hand side;
  - chains the two bounds with uint_less_equal_trans.
*/
lemma pow_cnt_dubs_le_inc: all n:UInt. 2 ^ cnt_dubs(n) ≤ inc(n)
proof
  induction UInt
  case bzero {
    evaluate
  }
  case dub_inc(n') assume IH {
    expand cnt_dubs | inc
    replace inc_add_one | uint_pow_add_r | inc_dub_add_mult2 | uint_dist_mult_add
    show 2 * 2 ^ cnt_dubs(n') ≤ 3 + 2 * n'
    have IH': 2 ^ cnt_dubs(n') ≤ 1 + n' by replace inc_add_one in IH
    have A: 2 * 2 ^ cnt_dubs(n') ≤ 2 * (1 + n') by apply uint_mult_mono_le[2] to IH'
    /* 2 * (1 + n') = 2 + 2 * n', which is at most 3 + 2 * n' */
    have B: 2 * (1 + n') ≤ 3 + 2 * n' by {
      replace uint_dist_mult_add
      apply uint_add_mono_less_equal[2,2*n',3,2*n'] to .
    }
    apply uint_less_equal_trans to A, B
  }
  case inc_dub(n') assume IH {
    expand cnt_dubs | inc
    replace inc_add_one | uint_pow_add_r | dub_inc_mult2_add | uint_dist_mult_add
    show 2 * 2 ^ cnt_dubs(n') ≤ 2 + 2 * n'
    have IH': 2 ^ cnt_dubs(n') ≤ 1 + n' by replace inc_add_one in IH
    have A: 2 * 2 ^ cnt_dubs(n') ≤ 2 * (1 + n') by apply uint_mult_mono_le[2] to IH'
    /* here the scaled bound matches the goal exactly after distributing */
    have B: 2 * (1 + n') ≤ 2 + 2 * n' by replace uint_dist_mult_add.
    apply uint_less_equal_trans to A, B
  }
end

/*
  Doubling never decreases cnt_dubs:
    cnt_dubs(x) ≤ cnt_dubs(dub(x)).
  Proof by induction on x. Each case unfolds dub and cnt_dubs.
  - The dub_inc case closes by uint_less_equal_add.
  - The inc_dub case reduces to the induction hypothesis.
*/
lemma cnt_le_cnt_dub: all x:UInt.
  cnt_dubs(x) ≤ cnt_dubs(dub(x)) 
proof
  induction UInt
  case bzero {
    expand dub.
  }
  case dub_inc(x') assume IH {
    expand dub | cnt_dubs | cnt_dubs
    replace inc_add_one | inc_add_one | uint_add_commute
    uint_less_equal_add
  }
  case inc_dub(x') assume IH {
    expand dub | cnt_dubs
    replace inc_add_one
    IH
  }
end

/*
  Doubling increases cnt_dubs by at most one:
    cnt_dubs(dub(x)) ≤ 1 + cnt_dubs(x).
  Together with cnt_le_cnt_dub this brackets cnt_dubs(dub(x)).
  Proof by induction on x. The inc_dub case adds 1 to both sides of the
  induction hypothesis.
*/
lemma cnt_dub_le_cnt: all x:UInt.
  cnt_dubs(dub(x)) ≤ 1 + cnt_dubs(x)
proof
  induction UInt
  case bzero {
    expand dub | cnt_dubs
    evaluate
  }
  case dub_inc(x') assume IH {
    expand dub | cnt_dubs | cnt_dubs
    replace inc_add_one | inc_add_one.
  }
  case inc_dub(x') assume IH {
    expand dub | cnt_dubs
    replace inc_add_one
    apply uint_add_both_sides_of_less_equal[1] to IH
  }
end

/*
  Postulated (assumed without proof in this file): log undoes a power of
  two, i.e. log is a left inverse of n ↦ 2^n.
*/
postulate log_pow: all n:UInt. log(2^n) = n
  
/*
  For positive n, 2 raised to log(n) does not exceed n:
    0 < n  implies  2^log(n) ≤ n.
  The proof writes n = 1 + n' and unfolds log into cnt_dubs of the
  predecessor (pred, i.e. n ∸ 1 via uint_pred_monus). It then applies
  pow_cnt_dubs_le_inc to n'.
*/
theorem uint_expt_log_less_equal: all n:UInt. if 0 < n then 2^log(n) ≤ n
proof
  arbitrary n:UInt
  assume pos
  obtain n' where n_n': n = 1 + n' from apply uint_positive_add_one[n] to pos
  replace n_n'
  expand log
  replace uint_pred_monus
  conclude 2 ^ cnt_dubs(n') ≤ 1 + n'
    by replace inc_add_one in pow_cnt_dubs_le_inc[n']
end


/*
  If n < m then 1 + 2 * n < 2 * m.
  The proof takes these steps:
  - rewrites n < m as 1 + n <= m (uint_less_is_less_equal);
  - multiplies both sides by 2;
  - distributes;
  - rewrites back to a strict inequality.
*/
lemma inc_dub_less_dub : all n : UInt, m : UInt. if n < m then 1 + 2 * n < 2 * m
proof  
  arbitrary n : UInt, m : UInt
  assume prem
  have prem1 : 1 + n <= m by replace uint_less_is_less_equal[n][m] in prem
  have prem2 : 2 * (1 + n) <= 2 * m by apply uint_mult_mono_le[2] to prem1
  have prem3 : 1 + (1 + 2 * n) <= 2 * m by replace uint_dist_mult_add in prem2
  /* name the left operand so uint_less_is_less_equal can be instantiated */
  define g = 1 + 2 * n
  conclude g < 2 * m by replace symmetric uint_less_is_less_equal[g][2 * m] in prem3
end

/*
  Strict upper bound by a power of two:
    1 + n < 2 ^ (1 + cnt_dubs(n)).
  Note: despite "less_equal" in its name, this lemma states a strict
  inequality (<).
  Proof by induction on n. Each inductive case splits the exponent with
  uint_pow_add_r and scales the induction hypothesis by 2.
  - The dub_inc case uses inc_dub_less_dub.
  - The inc_dub case uses uint_pos_mult_both_sides_of_less.
*/
lemma less_equal_pow_cnt_dubs: all n:UInt. 1 + n < 2 ^ (1 + cnt_dubs(n))
proof
  induction UInt
  case bzero {
    evaluate
  }
  case dub_inc(n') assume IH {
    expand cnt_dubs
    replace inc_add_one | uint_pow_add_r | dub_inc_mult2_add | uint_dist_mult_add
    show 3 + 2 * n' < 4 * 2 ^ cnt_dubs(n')
    have IH1: 1 + n' < 2 * 2 ^ cnt_dubs(n') by replace uint_pow_add_r in IH
    have IH2: 1 + 2*(1 + n') < 4 * 2 ^ cnt_dubs(n') 
      by apply inc_dub_less_dub[1+n', 2*2^cnt_dubs(n')] to ., IH1
    conclude 3 + 2*n' < 4 * 2 ^ cnt_dubs(n')
      by replace uint_dist_mult_add in IH2
  }
  case inc_dub(n') assume IH {
    expand cnt_dubs
    replace inc_add_one | uint_pow_add_r | inc_dub_add_mult2
    show 2 + 2 * n' < 4 * 2 ^ cnt_dubs(n')
    /* multiply both sides of the induction hypothesis by 2 */
    have IH1 : 2 * (1 + n') < 2 * 2 ^ (1 + cnt_dubs(n')) by
      apply uint_pos_mult_both_sides_of_less[2, 1+n', 2^(1+ cnt_dubs(n'))] to ., IH
    conclude  2 + 2 * n' < 4 * 2 ^ cnt_dubs(n') by
      replace uint_dist_mult_add[2, 1, n'] | uint_pow_add_r in IH1
  }
end

/*
  For positive n, n is strictly below the next power of two after log(n):
    0 < n  implies  n < 2^(1 + log(n)).
  Together with uint_expt_log_less_equal this gives
    2^log(n) ≤ n < 2^(1 + log(n)).
  Proof: unfold log, write n = 1 + n', and apply less_equal_pow_cnt_dubs.
*/
theorem less_pow_log: all n:UInt. if 0 < n then n < 2^(1 + log(n))
proof
  arbitrary n:UInt
  assume npos
  expand log
  replace uint_pred_monus
  obtain n' where n_n': n = 1 + n' from apply uint_positive_add_one[n] to npos
  replace n_n'
  less_equal_pow_cnt_dubs
end

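/*
 The following rule does not hold exactly for UInt, because log rounds down:
 log(3 * 3) = 3 ≠ 2 = 1 + 1 = log(3) + log(3)

theorem log_product: all m:UInt, n:UInt. log(m * n) = log(m) + log(n)

 It does hold up to an additive error of one:
   log(m) + log(n) ≤ log(m * n) ≤ 1 + log(m) + log(n)
 These are the bounds stated by the postulates uint_log_add_le_log_mult
 and uint_log_mult_le_log_add below.
*/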

/*
  Postulated (assumed without proof in this file): log is monotone.
*/
postulate uint_log_mono: all x:UInt, y:UInt. (if x ≤ y then log(x) ≤ log(y))

/*
  For n at least 2, log(n) is at least 1:
    2 ≤ n  implies  1 ≤ log(n).
  Follows from log(2) = 1 (by evaluation) and monotonicity of log
  (uint_log_mono).
*/
theorem uint_log_greater_one: all n:UInt. if 2 ≤ n then 1 ≤ log(n)
proof
  arbitrary n:UInt
  assume two_n
  have log2_1: log(2) = 1 by evaluate
  have one_log2: 1 ≤ log(2) by replace log2_1.
  have log2_logn: log(2) ≤ log(n) by apply uint_log_mono to two_n
  conclude 1 ≤ log(n)
    by apply uint_less_equal_trans to one_log2, log2_logn
end

/*
  cnt_dubs is bounded by its argument plus one:
    cnt_dubs(n) ≤ 1 + n.
  Note: despite "less" in its name, the bound is non-strict (≤).
  Proof by induction on n. In each inductive case the proof gives an
  explicit arithmetic bound (e.g. 1 + n' ≤ 2 + 2 * n') and chains it with
  the induction hypothesis using uint_less_equal_trans.
*/
lemma cnt_dub_less_n: all n:UInt. cnt_dubs(n) ≤ 1 + n
proof
  induction UInt
  case bzero {
    expand cnt_dubs
    uint_bzero_le
  }
  case dub_inc(n') assume IH {
    expand cnt_dubs
    replace dub_inc_mult2_add | inc_add_one | uint_dist_mult_add
    have less: 1 + n' ≤ 2 + 2 * n' by {
      have A: 1 + n' ≤ (1 + n') + (1 + n') by uint_less_equal_add
      /* (1 + n') + (1 + n') rearranges to 2 + 2 * n' */
      have B: (1 + n') + (1 + n') ≤ 2 + 2*n'
        by replace uint_add_commute[n', 1] | symmetric uint_two_mult[n'].
      apply uint_less_equal_trans to A, B
    }
    apply uint_less_equal_trans to IH, less
  }
  case inc_dub(n') assume IH {
    expand cnt_dubs
    replace inc_dub_add_mult2 | inc_add_one
    have A: 1 + cnt_dubs(n') ≤ 2 + n'
      by apply uint_add_both_sides_of_less_equal[1] to IH
    have B: 2 + n' ≤ 2 + 2*n' by {
      have n_2n: n' ≤ 2*n' by {
        replace uint_two_mult
        uint_less_equal_add
      }
      apply uint_add_mono_less_equal[2, n', 2, 2*n'] to ., n_2n
    }
    apply uint_less_equal_trans to A, B
  }
end

/*
  log(n) never exceeds n:
    log(n) ≤ n.
  The proof unfolds log into cnt_dubs(pred(n)) and gets a bound from
  cnt_dub_less_n, rewriting pred as monus (n ∸ 1). It then splits on n:
  - n = 0: closed by evaluation;
  - n = 1 + n': closed by that bound.
*/
theorem uint_logn_le_n: all n:UInt. log(n) ≤ n
proof
  arbitrary n:UInt
  expand log
  have X: cnt_dubs(pred(n)) ≤ 1 + pred(n) by cnt_dub_less_n[pred(n)]
  have Y: cnt_dubs(n ∸ 1) ≤ 1 + (n ∸ 1) by replace uint_pred_monus in X
  cases uint_zero_or_add_one[n]
  case nz {
    replace nz
    evaluate
  }
  case ns {
    obtain n' where n_n1: n = 1 + n' from ns
    replace n_n1 | uint_pred_monus
    replace n_n1 in Y
  }
end

/*
  Postulated (assumed without proof in this file): lower bound for the log
  of a product, log(m) + log(n) ≤ log(m * n).
*/
postulate uint_log_add_le_log_mult: all m:UInt, n:UInt. log(m) + log(n) ≤ log(m * n)

/*
  Postulated (assumed without proof in this file): upper bound for the log
  of a product, log(m * n) ≤ 1 + log(m) + log(n).
*/
postulate uint_log_mult_le_log_add: all m:UInt, n:UInt. log(m * n) ≤ 1 + log(m) + log(n)

/*
  Postulated (assumed without proof in this file): log is positive for
  arguments greater than 1.
  NOTE(review): this looks derivable from uint_log_greater_one, since
  1 < n gives 2 ≤ n and hence 1 ≤ log(n). Consider proving it instead of
  postulating it.
*/
postulate uint_log_pos: all n:UInt. (if 1 < n then 0 < log(n))