module UInt

import Base
import Nat

/*
  Core UInt definitions.

  UInt uses a private binary representation inspired by the Agda standard
  library. The representation is hidden from clients; the public API
  exposes arithmetic and order while `toNat` provides the specification
  used in many proofs.
*/

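// This follows the representation of binary numbers in the Agda
// standard library. It is not the usual list-of-bits encoding, but it is
// easier to work with: every natural number has exactly one
// representation, so there are no leading zeros to normalize away.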

// Private binary representation of natural numbers. The trailing comment
// on each constructor gives the value it denotes (see `toNat`).
// Every natural number has exactly one representation: 0 is bzero, an
// odd number 1+2x is inc_dub(x), and a nonzero even number 2(1+x) is
// dub_inc(x).
private union Binary {
  bzero         // 0
  dub_inc(Binary) // 2(1 + x)
  inc_dub(Binary) // 1 + 2x
}

// Zero/successor view of the private Binary representation.
// UIntSucc(p) stands for p + 1: `uint_view` maps a nonzero x to
// UIntSucc(pred(x)), and `uint_unview` maps UIntSucc(p) back to inc(p).
private union UIntView {
  UIntZero
  UIntSucc(Binary)
}
// The union stays private; only its constructors are exported
// (presumably so clients can match on a UInt through `view UInt`).
export UIntZero
export UIntSucc

// Interpret a Binary as the Nat it denotes. These equations are the
// specification of the representation:
//   bzero = 0,  dub_inc(x) = 2(1 + x),  inc_dub(x) = 1 + 2x.
opaque recursive toNat(Binary) -> Nat {
  toNat(bzero) = ℕ0
  toNat(dub_inc(x)) = ℕ2 * suc(toNat(x))
  toNat(inc_dub(x)) = suc(ℕ2 * toNat(x))
}

// Sanity checks documenting the first few binary representations.
// The trailing comment on each line is the number in ordinary base 2.
assert toNat(bzero)                   = ℕ0    //   0
assert toNat(inc_dub(bzero))          = ℕ1    //   1
assert toNat(dub_inc(bzero))          = ℕ2    //  10
assert toNat(inc_dub(inc_dub(bzero))) = ℕ3    //  11
assert toNat(dub_inc(inc_dub(bzero))) = ℕ4    // 100

// Strict less-than on Binary, decided structurally (no conversion to
// Nat). bzero is below every nonzero value and nothing is below bzero.
// The remaining cases compare the denoted values:
//   2(1+x) < 2(1+y)  iff  x < y
//   2(1+x) < 1+2y    iff  x < y
//   1+2x   < 2(1+y)  iff  x ≤ y   (hence `x' < y' or x' = y'`)
//   1+2x   < 1+2y    iff  x < y
opaque recursive operator< (Binary, Binary) -> bool {
  operator < (bzero, y) =
    switch y {
      case bzero { false }
      case dub_inc(y') { true }
      case inc_dub(y') { true }
    }
  operator < (dub_inc(x'), y) =
    switch y {
      case bzero { false }
      case dub_inc(y') {  x' < y' }
      case inc_dub(y') {  x' < y' }
    }
  operator < (inc_dub(x'), y) = 
    switch y {
      case bzero { false }
      case dub_inc(y') { x' < y' or x' = y'  }
      case inc_dub(y') {  x' < y' }
    }
}

// Non-strict order: x ≤ y when x < y or x and y are the same
// representation (representations are unique, so this is equality of
// the denoted values).
fun operator ≤ (x:Binary, y:Binary) {
  x < y or x = y
}

// x > y holds exactly when y < x.
fun operator > (x:Binary, y:Binary) {
  y < x
}

// x ≥ y holds exactly when y ≤ x, mirroring how `>` is defined from `<`.
fun operator ≥ (x:Binary, y:Binary) {
  y ≤ x
}

// Doubling: dub(x) denotes 2x.
private recursive dub(Binary) -> Binary {
  dub(bzero) = bzero
  dub(dub_inc(x)) = dub_inc(inc_dub(x))
    // because 2*2*(1+x) = 4 + 4x = 2*(2+2x) = 2*(1+(1+2*x))
  dub(inc_dub(x)) = dub_inc(dub(x))
    // because 2*(1+2x) = 2*(1+dub(x))
}

// Successor: inc(x) denotes x + 1.
private recursive inc(Binary) -> Binary {
  inc(bzero) = inc_dub(bzero)
    // because 0 + 1 = 1 + 2*0
  inc(dub_inc(x)) = inc_dub(inc(x))
    // because 2(1+x) + 1 = 1 + 2(x+1)
  inc(inc_dub(x)) = dub_inc(x)
    // because (1+2x) + 1 = 2(1+x)
}

// Predecessor, truncated at zero: pred(bzero) = bzero, and otherwise
// pred(x) denotes x - 1.
private fun pred(x:Binary) {
  switch x {
    case bzero { bzero }
    case dub_inc(x') { inc_dub(x') }  // 2(1+x') - 1 = 1 + 2x'
    case inc_dub(x') { dub(x') }      // (1+2x') - 1 = 2x'
  }
}

// Convert a Binary to the zero/successor view: bzero becomes UIntZero,
// and any other x becomes UIntSucc(pred(x)), i.e. the successor of x - 1.
private fun uint_view(x:Binary) {
  if x = bzero then UIntZero else UIntSucc(pred(x))
}

// Inverse of `uint_view`: UIntZero becomes bzero, and UIntSucc(p) becomes
// inc(p), the Binary denoting p + 1.
private fun uint_unview(v:UIntView) {
  switch v {
    case UIntZero { bzero }
    case UIntSucc(p) { inc(p) }
  }
}

// Doubling a successor, 2(x + 1), is exactly the dub_inc constructor.
// Used in the dub_inc case of `uint_view_inc`.
lemma dub_inc_is_dub_inc: all x:Binary. dub(inc(x)) = dub_inc(x)
proof
  induction Binary
  case bzero {
    evaluate
  }
  case dub_inc(x) assume IH {
    // dub(inc(dub_inc(x))) unfolds to dub_inc(dub(inc(x))); IH closes it
    expand inc | dub
    replace IH
    .
  }
  case inc_dub(x) assume IH {
    evaluate
  }
end

// Viewing a successor: uint_view(inc(x)) = UIntSucc(x). This is the key
// step of the `uint_view_unview` roundtrip proof.
lemma uint_view_inc: all x:Binary. uint_view(inc(x)) = UIntSucc(x)
proof
  induction Binary
  case bzero {
    evaluate
  }
  case dub_inc(x) assume IH {
    // inc(dub_inc(x)) = inc_dub(inc(x)), whose pred is dub(inc(x))
    suffices uint_view(inc(dub_inc(x))) = UIntSucc(dub_inc(x)) by .
    expand uint_view | inc | pred
    replace dub_inc_is_dub_inc[x]
    .
  }
  case inc_dub(x) assume IH {
    evaluate
  }
end

// Roundtrip law: converting a view to Binary and back returns the same
// view. Supplied as the `roundtrip` proof of `view UInt`.
lemma uint_view_unview: all v:UIntView. uint_view(uint_unview(v)) = v
proof
  arbitrary v:UIntView
  switch v {
    case UIntZero { evaluate }
    case UIntSucc(p) {
      expand uint_unview
      uint_view_inc[p]
    }
  }
end

// UInt is a view over the private Binary representation with the
// zero/successor shape UIntView. `into`/`out` convert between the two,
// and `roundtrip` is the proof that uint_view(uint_unview(v)) = v.
// NOTE(review): presumably this is what lets clients match on a UInt
// with UIntZero/UIntSucc; confirm against the Deduce `view` documentation.
view UInt {
  source Binary
  target UIntView
  into uint_view
  out uint_unview
  roundtrip uint_view_unview
}

// Halve b, rounding down.
opaque fun div2(b : UInt) {
  switch b {
    case bzero { bzero }
    case dub_inc(x) { inc(x) }     // 2*(1 + x) /2 = 1+x
    case inc_dub(x) { x }          // (1 + 2*x) / 2 = 1/2 + x = x
  }
}

// Count the constructors wrapped around bzero in x's representation,
// i.e. its nesting depth. For the number n this is ⌊log₂(n + 1)⌋.
private recursive cnt_dubs(Binary) -> Binary {
  cnt_dubs(bzero) = bzero
  cnt_dubs(dub_inc(x)) = inc(cnt_dubs(x))
  cnt_dubs(inc_dub(x)) = inc(cnt_dubs(x))
}

// ⌊log₂ b⌋ for b ≥ 1, computed as the representation depth of b - 1.
// log(0) = 0, because pred(bzero) = bzero.
opaque fun log(b : UInt) {
  cnt_dubs(pred(b))
}

// Addition, by cases on both representations. Each result follows from
// the denoted values, e.g.
//   2(1+x) + 2(1+y) = 2(1 + (1 + x + y)) = dub_inc(inc(x + y))
//   2(1+x) + (1+2y) = 1 + 2(1 + x + y)   = inc(dub_inc(x + y))
//   (1+2x) + (1+2y) = 1 + (1 + 2(x + y)) = inc(inc_dub(x + y))
opaque recursive operator+(Binary, Binary) -> Binary {
  operator+(bzero, y) = y
  operator+(dub_inc(x), y) =
    switch y {
      case bzero { dub_inc(x) }
      case dub_inc(y') { dub_inc(inc(x + y')) }
      case inc_dub(y') { inc(dub_inc(x + y')) }
    }
  operator+(inc_dub(x), y) = 
    switch y {
      case bzero { inc_dub(x) }
      case dub_inc(y') { inc(dub_inc(x + y')) }
      case inc_dub(y') { inc(inc_dub(x + y')) }
    }
}

// Truncated subtraction (monus): x ∸ y denotes x - y when y ≤ x, and 0
// otherwise. Each case works on the denoted values. Whenever the exact
// difference would be negative, the result is bzero.
opaque recursive operator ∸ (Binary, Binary) -> Binary {
  operator∸(bzero, y) = bzero
  operator∸(dub_inc(x), y) =
    switch y {
      case bzero { dub_inc(x) }
      case dub_inc(y') {  dub(x ∸ y') } // 2(1+x) - 2(1+y') = 2 + 2x - 2 - 2y' = 2x - 2y'
      case inc_dub(y') {
        // 2(1+x) - (1+2y') = 2 + 2x - 1 - 2y' = 1+2(x-y'),
        // which is negative (so truncated to 0) exactly when x < y'
        if x < y' then bzero
        else inc_dub(x ∸ y')
      }
    }
  operator∸(inc_dub(x), y) =
    switch y {
      case bzero { inc_dub(x) }
      case dub_inc(y') {
        // 1 + 2x - 2(1+y') = 1 + 2x - 2 - 2y' = - 1 + 2(x - y')
        // When x = y' this is -1; pred(dub(bzero)) = bzero truncates it to 0.
        if x < y' then bzero else pred(dub(x ∸ y'))
      }
      case inc_dub(y') { dub(x ∸ y') } // (1+2x) - (1+2y') = 2(x - y')
    }
}

// Multiplication, by cases on both representations. The comments in each
// case derive the result from the denoted values.
opaque recursive operator *(Binary, Binary) -> Binary {
  operator*(bzero, y) = bzero
  operator*(dub_inc(x), y) =
    switch y {
      case bzero { bzero }
      case dub_inc(y') {
        // 2*(1 + x) * 2*(1 + y') = (2+2x)(2 + 2y')
        // = 4 + 4x + 4y' + 4xy'
        // = 2(2 + 2x + 2y' + 2xy')
        // = 2(2(1 + x + y' + xy'))
        dub(dub_inc(x + y' + x * y'))
      }
      case inc_dub(y') {
        // 2*(1 + x) * (1 + 2y') = 2*(1 + x + 2y' + 2xy')
        // = 2*(1 + x + 2(y' + xy'))
        dub_inc(x + dub(y' + x*y'))
      }
    }
  operator*(inc_dub(x), y) = 
    switch y {
      case bzero { bzero }
      case dub_inc(y') {
        // (1 + 2x)(2*(1+y')) = (1 + 2x)(2 + 2y')
        // = 2 + 4x + 2y' + 4xy'
        // = 2(1 + 2x + y' + 2xy')
        dub_inc(dub(x) + y' + dub(x * y'))
      }
      case inc_dub(y') {
        // (1 + 2x)(1 + 2y') = 1 + 2x + 2y' + 4xy'
        // = 1 + 2(x + y' + 2xy')
        inc_dub(x + y' + dub(x * y'))
      }
    }
}

// Square of a; used by `expt` for repeated squaring.
fun sqr(a : UInt) { a * a }

// expt(p, a) computes a^p by recursion on the exponent p. Note that the
// exponent is the FIRST argument. Each recursive call is on roughly half
// the exponent (exponentiation by squaring).
private recursive expt(Binary, Binary) -> Binary {
  expt(bzero, a) = inc_dub(bzero)
  expt(dub_inc(p), a) = sqr(a * expt(p, a))  // a^(2*(1+p)) = a^2 * (a^p)^2 = (a * a^p)^2
  expt(inc_dub(p), a) = a * sqr(expt(p, a))  // a^(1+2*p) = a * (a^p)^2
}

// a ^ b: a raised to the power b. It delegates to `expt`, which takes the
// exponent first. a ^ 0 = 1 for every a (including 0), since
// expt(bzero, a) = inc_dub(bzero).
opaque fun operator ^(a : UInt, b : UInt)  {
  expt(b, a)
}

// Convert a Nat to a UInt by applying `inc` once per `suc`.
// The base case matches on the Nat constructor `zero` rather than a bare
// numeral, because this file writes Nat literals as ℕ0, ℕ1, ... and a
// bare numeral is ambiguous between Nat and UInt.
opaque recursive fromNat(Nat) -> UInt {
  fromNat(zero) = bzero
  fromNat(suc(n)) = inc(fromNat(n))
}

// The larger of x and y. Returns x when the two are equal.
fun max(x : UInt, y : UInt) {
  if x < y then y
  else x
}

// The smaller of x and y. Returns y when the two are equal.
fun min(x : UInt, y : UInt) {
  if x < y then x
  else y
}