module BigO

import UInt

// Pointwise sum of two functions: (f + g)(n) is f(n) + g(n).
fun operator + (f : fn UInt -> UInt, g : fn UInt -> UInt) {
  fun n:UInt { f(n) + g(n) }
}

// Pointwise product of two functions: (f * g)(n) is f(n) * g(n).
fun operator * (f : fn UInt -> UInt, g : fn UInt -> UInt) {
  fun n:UInt { f(n) * g(n) }
}

// Scaling of a function by a constant: (k * g)(n) is k * g(n).
fun operator * (k : UInt, g : fn UInt -> UInt) {
  fun n:UInt { k * g(n) }
}

// Asymptotic upper bound: f ≲ g holds when there is a threshold n0 and a
// constant factor c such that f(n) ≤ c * g(n) for every n with n0 ≤ n.
fun operator ≲ (f : fn UInt -> UInt, g : fn UInt -> UInt) {
  some n0 :UInt, c:UInt. all n:UInt. if n0 ≤ n then f(n) ≤ c * g(n)
}

// Asymptotic equivalence: f ≈ g holds when each function is bounded by the
// other, i.e. both f ≲ g and g ≲ f.
fun operator ≈ (f : fn UInt -> UInt, g : fn UInt -> UInt) {
  (f ≲ g) and (g ≲ f)
}

// Lifts log to functions: log(f) maps n to the log of f(n).
fun log(f : fn UInt -> UInt) {
  fun n:UInt {
    log(f(n))
  }
}


// Constant time: the cost is 1 for every input size n.
fun O_1(n:UInt) {
  1
}

// Logarithmic time: the cost for input size n is log(n).
fun O_log(n:UInt) {
  log(n)
}

// Linear time: the cost for input size n is n itself.
fun O_n(n:UInt) {
  n
}

// Quadratic time: the cost for input size n is n * n.
fun O_n2(n:UInt) {
  n * n
}

// Reflexivity of ≲: every function is bounded by itself.
// Witnesses: threshold 0 and constant factor 1.
theorem bigo_refl: all f : fn UInt -> UInt. f ≲ f
proof
  arbitrary f : fn UInt -> UInt
  expand operator≲
  choose 0, 1
  arbitrary n:UInt
  assume prem
  show f(n) ≤ 1 * f(n)
  uint_less_equal_refl
end

// Transitivity of ≲: if f ≲ g and g ≲ h then f ≲ h.
// Given witnesses (n1, c1) for f ≲ g and (n2, c2) for g ≲ h, the proof uses
// threshold n1 + n2 and constant factor c1 * c2.
theorem bigo_trans: all f : fn UInt -> UInt, g : fn UInt -> UInt, h : fn UInt -> UInt.
  if f ≲ g and g ≲ h then f ≲ h
proof
  arbitrary f : fn UInt -> UInt, g : fn UInt -> UInt, h : fn UInt -> UInt
  assume prem
  have fg: f ≲ g by prem
  obtain n1, c1 where fg_all: all n:UInt. if n1 ≤ n then f(n) ≤ c1 * g(n)
    from expand operator≲ in fg
  have gh: g ≲ h by prem
  obtain n2, c2 where gh_all: all n:UInt. if n2 ≤ n then g(n) ≤ c2 * h(n)
    from expand operator≲ in gh
  expand operator≲
  choose n1+n2, c1*c2
  arbitrary n:UInt
  assume n12_n
  show f(n) ≤ (c1 * c2) * h(n)

  // n1 ≤ n1 + n2 ≤ n, so the first bound applies at n.
  have n1_n: n1 ≤ n by {
    have n1_n12: n1 ≤ n1 + n2 by uint_less_equal_add
    apply uint_less_equal_trans to n1_n12, n12_n
  }
  have fg_n: f(n) ≤ c1 * g(n) by apply fg_all to n1_n

  // n2 ≤ n1 + n2 ≤ n, so the second bound applies at n.
  have n2_n: n2 ≤ n by {
    have n2_n21: n2 ≤ n1 + n2 by { replace uint_add_commute uint_less_equal_add }
    apply uint_less_equal_trans to n2_n21, n12_n
  }
  have gh_n: g(n) ≤ c2 * h(n) by apply gh_all to n2_n

  // Scale the second bound by c1, then chain it with the first.
  have c1g_c12_h: c1 * g(n) ≤ c1 * c2 * h(n)
    by apply uint_mult_mono_le[c1, g(n), c2*h(n)] to gh_n

  conclude f(n) ≤ c1 * c2 * h(n)
    by apply uint_less_equal_trans to fg_n, c1g_c12_h
end

// Doubling a function does not change its bound: f + f ≲ f.
// Witnesses: threshold 0 and constant factor 2.
theorem bigo_add: all f: fn UInt->UInt. f + f ≲ f
proof
  arbitrary f: fn UInt->UInt
  expand operator≲
  choose 0, 2
  arbitrary n:UInt
  assume prem
  expand operator+
  show f(n) + f(n) ≤ 2 * f(n)
  replace uint_two_mult.
end

// Constant factors are absorbed by ≲: k * f ≲ f for every constant k.
// The proof fixes f and proceeds by induction on the constant, using the
// predicate P(c) = (c * f ≲ f).
theorem bigo_const_mult: all f: fn UInt -> UInt, k:UInt.
  k * f ≲ f
proof
  arbitrary f: fn UInt -> UInt
  define P = fun c:UInt { c * f ≲ f }
  have base: P(0) by {
    expand P | operator≲ | operator*.
  }
  have ind: all m:UInt. if P(m) then P(1 + m) by {
    arbitrary m:UInt
    expand P
    assume mf_f: m * f ≲ f
    obtain n0, c where mf_f_all: all n:UInt. if n0 ≤ n then m * f(n) ≤ c * f(n)
      from expand operator≲ | operator* in mf_f
    expand operator≲
    // Reuse the threshold n0 from the induction hypothesis; the new constant
    // factor is 2 * (1 + c).
    choose n0, 2*(1 + c)
    arbitrary n:UInt
    assume n0_n
    expand operator*
    replace uint_dist_mult_add_right | uint_dist_mult_add
    show f(n) + m * f(n) ≤ 2 * f(n) + 2 * c * f(n)

    // Bound the two summands separately, then add the bounds.
    have fn_cfn: f(n) ≤ 2 * f(n) by replace uint_two_mult uint_less_equal_add
    have mfn_2cfn: m * f(n) ≤ 2 * c * f(n) by {
      have IH: m * f(n) ≤ c * f(n) by apply mf_f_all[n] to n0_n
      have cfn_2cfn: c * f(n) ≤ 2 * c * f(n) by {
        apply uint_mult_mono_le2[1,c*f(n),2,c*f(n)] to .
      }
      apply uint_less_equal_trans to IH, cfn_2cfn
    }
    conclude f(n) + m * f(n) ≤ 2 * f(n) + 2 * c * f(n)
      by apply uint_add_mono_less_equal to fn_cfn, mfn_2cfn
  }
  expand P in apply uint_induction[P] to base, ind
end

// Monotonicity of pointwise multiplication with respect to ≲:
// if f1 ≲ g1 and f2 ≲ g2 then f1 * f2 ≲ g1 * g2.
// Given witnesses (n1, c1) and (n2, c2) for the premises, the proof uses
// threshold n1 + n2 and constant factor c1 * c2.
theorem bigo_mult_mono:
  all f1: fn UInt->UInt, f2: fn UInt->UInt, g1: fn UInt->UInt, g2: fn UInt->UInt.
  if f1 ≲ g1 and f2 ≲ g2
  then f1 * f2 ≲ g1 * g2
proof
  arbitrary f1: fn UInt->UInt, f2: fn UInt->UInt, g1: fn UInt->UInt, g2: fn UInt->UInt
  assume prem
  expand operator≲ | operator*
  have f1_g1: f1 ≲ g1 by prem
  obtain n1, c1 where f1_c1g1_all: all n:UInt. if n1 ≤ n then f1(n) ≤ c1 * g1(n)
    from expand operator≲ in f1_g1
  have f2_g2: f2 ≲ g2 by prem
  obtain n2, c2 where f2_c2g2_all: all n:UInt. if n2 ≤ n then f2(n) ≤ c2 * g2(n)
    from expand operator≲ in f2_g2
  choose n1 + n2, c1 * c2
  arbitrary n:UInt
  assume n12_n: n1 + n2 ≤ n
  show f1(n) * f2(n) ≤ c1 * c2 * g1(n) * g2(n)

  // n1 ≤ n1 + n2 ≤ n, so the first bound applies at n.
  have n1_n: n1 ≤ n by {
    have n1_n12: n1 ≤ n1 + n2 by uint_less_equal_add
    apply uint_less_equal_trans to n1_n12, n12_n
  }
  have f1_c1g1: f1(n) ≤ c1 * g1(n)
    by apply f1_c1g1_all to n1_n

  // n2 ≤ n1 + n2 ≤ n, so the second bound applies at n.
  have n2_n: n2 ≤ n by {
    have n2_n12: n2 ≤ n1 + n2 by { replace uint_add_commute uint_less_equal_add }
    apply uint_less_equal_trans to n2_n12, n12_n
  }
  have f2_c2g2: f2(n) ≤ c2 * g2(n)
    by apply f2_c2g2_all to n2_n

  // Multiply the two bounds, then commute g1(n) and c2 to match the goal.
  have f12_c1g1_c2g2: f1(n) * f2(n) ≤ c1 * g1(n) * c2 * g2(n)
    by apply uint_mult_mono_le2[f1(n), f2(n)] to f1_c1g1, f2_c2g2

  conclude f1(n) * f2(n) ≤ c1 * c2 * g1(n) * g2(n)
    by replace uint_mult_commute[g1(n),c2] in f12_c1g1_c2g2
end

// Constant time is bounded by logarithmic time: O_1 ≲ O_log.
// Witnesses: threshold 2 and constant factor 1.
theorem constant_le_log: O_1 ≲ O_log
proof
  expand operator≲ | O_1 | O_log
  choose 2, 1
  arbitrary n:UInt
  assume two_n
  apply uint_log_greater_one to two_n
end

// Logarithmic time is bounded by linear time: O_log ≲ O_n.
// Witnesses: threshold 0 and constant factor 1.
theorem log_le_linear: O_log ≲ O_n
proof
  expand operator≲ | O_log | O_n
  choose 0, 1
  arbitrary n:UInt
  assume _
  uint_logn_le_n
end

// Linear time is bounded by quadratic time: O_n ≲ O_n2.
// Witnesses: threshold 1 and constant factor 1.
theorem linear_le_quadratic: O_n ≲ O_n2
proof
  expand operator≲ | O_n | O_n2
  choose 1, 1
  arbitrary n:UInt
  assume pos
  // Multiply the threshold premise `pos` by n ≤ n.
  have n_n: n ≤ n by .
  conclude 1 * n ≤ n * n by apply uint_mult_mono_le2 to pos, n_n
end

// Linear time with a constant factor of 2: the cost for input size n is 2*n.
fun O_2n(n:UInt) {
  2*n
}

// O_n and O_2n are asymptotically equivalent: O_n ≈ O_2n.
// Each direction of ≲ is proved separately and the two are combined.
theorem O_n_eq_O_2n: O_n ≈ O_2n
proof
  expand operator≈
  // Direction 1: threshold 0 and constant factor 1.
  have A: O_n ≲ O_2n by {
    expand operator≲
    choose 0, 1
    arbitrary n:UInt
    assume _
    expand O_n | O_2n
    replace uint_two_mult
    uint_less_equal_add
  }
  // Direction 2: threshold 0 and constant factor 2.
  have B: O_2n ≲ O_n by {
    expand operator≲
    choose 0, 2
    arbitrary n:UInt
    assume prem
    expand O_n | O_2n.
  }
  A, B
end

// A function is positive when its value is strictly greater than 0 at every input.
fun positive(f : fn UInt -> UInt) { all n:UInt. 0 < f(n) }

// For functions whose values are always greater than 1, the log of a product
// is bounded by the sum of the logs: log(f * g) ≲ log(f) + log(g).
// Witnesses: threshold 0 and constant factor 2.
theorem log_product_less_equal_sum:
  all f:fn UInt -> UInt, g:fn UInt -> UInt.
  if (all n:UInt. 1 < f(n)) and (all n:UInt. 1 < g(n))
  then log(f * g) ≲ log(f) + log(g)
proof
  arbitrary f:fn UInt -> UInt, g:fn UInt -> UInt
  assume fg_gt_1
  expand operator≲ | operator+ | operator* | log
  choose 0, 2
  arbitrary n:UInt
  assume _
  // Step A: log of the product is at most one more than the sum of the logs.
  have A: log(f(n) * g(n)) ≤ 1 + log(f(n)) + log(g(n)) by uint_log_mult_le_log_add
  // Step B: since the sum of the logs is positive, adding 1 is at most doubling.
  have B: 1 + log(f(n)) + log(g(n)) ≤ 2 * (log(f(n)) + log(g(n))) by {
    have log_fg_pos: 0 < log(f(n)) + log(g(n)) by {
      have log_f_pos: 0 < log(f(n)) by apply uint_log_pos[f(n)] to fg_gt_1
      have log_g_pos: 0 < log(g(n)) by apply uint_log_pos[g(n)] to fg_gt_1
      apply uint_add_mono_less[0, 0, log(f(n)), log(g(n))] to log_f_pos, log_g_pos
    }
    replace uint_two_mult
    apply uint_add_one_le_double to log_fg_pos
  }
  apply uint_less_equal_trans to A, B
end

// The sum of the logs is bounded by the log of the product:
// log(f) + log(g) ≲ log(f * g).
// Witnesses: threshold 0 and constant factor 1.
theorem log_sum_less_equal_product:
  all f:fn UInt -> UInt, g:fn UInt -> UInt.
  log(f) + log(g) ≲ log(f * g)
proof
  arbitrary f:fn UInt -> UInt, g:fn UInt -> UInt
  expand operator≲ | operator+ | operator* | log
  choose 0, 1
  arbitrary n:UInt
  assume _
  uint_log_add_le_log_mult
end

// For functions whose values are always greater than 1, the log of a product
// is asymptotically equivalent to the sum of the logs:
// log(f * g) ≈ log(f) + log(g).
// Combines log_product_less_equal_sum and log_sum_less_equal_product.
theorem log_product_equiv_sum:
  all f:fn UInt -> UInt, g:fn UInt -> UInt.
  if (all n:UInt. 1 < f(n)) and (all n:UInt. 1 < g(n))
  then log(f * g) ≈ log(f) + log(g)
proof
  arbitrary f:fn UInt -> UInt, g:fn UInt -> UInt
  assume fg_gt_1
  expand operator≈
  have A: log(f * g) ≲ log(f) + log(g) by apply log_product_less_equal_sum[f,g] to fg_gt_1
  have B: log(f) + log(g) ≲ log(f * g) by log_sum_less_equal_product
  A, B
end