import Base
import Nat

// This follows the representation for binary numbers in the Agda
// standard library. The representation is not the standard one, but
// it is easier to work with.

// Binary representation of unsigned integers with three constructors.
// The comment next to each constructor gives the number it denotes,
// where x stands for the number denoted by the constructor's argument
// (see toNat for the same interpretation written as a function).
union UInt {
  bzero         // 0
  dub_inc(UInt) // 2(1 + x)
  inc_dub(UInt) // 1 + 2x
}

// Interprets a UInt as the Nat it denotes, by structural recursion on
// the UInt.
recursive toNat(UInt) -> Nat {
  // bzero denotes zero.
  toNat(bzero) = ℕ0
  // dub_inc(x) denotes 2 * (1 + x).
  toNat(dub_inc(x)) = ℕ2 * suc(toNat(x))
  // inc_dub(x) denotes 1 + 2 * x.
  toNat(inc_dub(x)) = suc(ℕ2 * toNat(x))
}

// Sanity checks for toNat on the numbers 0 through 4. The trailing
// comment on each line is the same number in ordinary binary notation.
assert toNat(bzero)                   = ℕ0    //   0
assert toNat(inc_dub(bzero))          = ℕ1    //   1
assert toNat(dub_inc(bzero))          = ℕ2    //  10
assert toNat(inc_dub(inc_dub(bzero))) = ℕ3    //  11
assert toNat(dub_inc(inc_dub(bzero))) = ℕ4    // 100

// Strict less-than on UInt, by recursion on the first argument and case
// analysis on the second. Each case compares the denoted numbers:
// bzero = 0, dub_inc(n) = 2(1 + n), inc_dub(n) = 1 + 2n.
recursive operator< (UInt, UInt) -> bool {
  // 0 < y holds exactly when y is built with a non-zero constructor.
  operator < (bzero, y) =
    switch y {
      case bzero { false }
      case dub_inc(y') { true }
      case inc_dub(y') { true }
    }
  // Left side is 2(1 + x').
  operator < (dub_inc(x'), y) =
    switch y {
      // nothing is below zero
      case bzero { false }
      // 2(1 + x') < 2(1 + y')  iff  x' < y'
      case dub_inc(y') {  x' < y' }
      // 2 + 2x' < 1 + 2y'  iff  2x' + 1 < 2y'  iff  x' < y'
      case inc_dub(y') {  x' < y' }
    }
  // Left side is 1 + 2x'.
  operator < (inc_dub(x'), y) = 
    switch y {
      // nothing is below zero
      case bzero { false }
      // 1 + 2x' < 2 + 2y'  iff  2x' < 2y' + 1  iff  x' < y' or x' = y'
      case dub_inc(y') { x' < y' or x' = y'  }
      // 1 + 2x' < 1 + 2y'  iff  x' < y'
      case inc_dub(y') {  x' < y' }
    }
}

// Less-than-or-equal on UInt: strictly less, or equal.
fun operator ≤ (x:UInt, y:UInt) {
  x < y or x = y
}

// Greater-than on UInt, defined as less-than with the arguments swapped.
fun operator > (x:UInt, y:UInt) {
  y < x
}

// Greater-than-or-equal on UInt, defined as less-than-or-equal with the
// arguments swapped (the same way > is defined in terms of <).
fun operator ≥ (x:UInt, y:UInt) {
  y ≤ x
}

// Doubles a UInt: dub(x) denotes 2 * x.
recursive dub(UInt) -> UInt {
  // 2 * 0 = 0
  dub(bzero) = bzero
  dub(dub_inc(x)) = dub_inc(inc_dub(x))
    // because 2*2*(1+x) = 4 + 4x = 2*(2+2x) = 2*(1+(1+2*x))
  dub(inc_dub(x)) = dub_inc(dub(x))
    // because 2*(1+2x) = 2*(1 + dub(x)), which is dub_inc applied to dub(x)
}

// Successor on UInt: inc(x) denotes 1 + x.
recursive inc(UInt) -> UInt {
  // 1 + 0 = 1 = 1 + 2*0
  inc(bzero) = inc_dub(bzero)
  // 1 + 2(1 + x) = 3 + 2x = 1 + 2(1 + x), i.e. inc_dub of inc(x)
  inc(dub_inc(x)) = inc_dub(inc(x))
  // 1 + (1 + 2x) = 2 + 2x = 2(1 + x)
  inc(inc_dub(x)) = dub_inc(x)
}

// Predecessor on UInt, truncated at zero: pred(bzero) is bzero.
fun pred(x:UInt) {
  switch x {
    // zero has no predecessor; stay at zero
    case bzero { bzero }
    // 2(1 + x') - 1 = 1 + 2x'
    case dub_inc(x') { inc_dub(x') }
    // (1 + 2x') - 1 = 2x'
    case inc_dub(x') { dub(x') }
  }
}

// Addition on UInt, by recursion on the first argument and case
// analysis on the second. Each recursive call is on the constructor
// arguments, and inc / dub_inc / inc_dub rebuild the sum.
recursive operator+(UInt, UInt) -> UInt {
  // 0 + y = y
  operator+(bzero, y) = y
  // Left side is 2(1 + x).
  operator+(dub_inc(x), y) =
    switch y {
      // adding zero leaves the left side unchanged
      case bzero { dub_inc(x) }
      // 2(1+x) + 2(1+y') = 4 + 2(x+y') = 2(1 + (1 + (x+y')))
      case dub_inc(y') { dub_inc(inc(x + y')) }
      // 2(1+x) + (1+2y') = 3 + 2(x+y') = 1 + 2(1 + (x+y'))
      case inc_dub(y') { inc(dub_inc(x + y')) }
    }
  // Left side is 1 + 2x.
  operator+(inc_dub(x), y) = 
    switch y {
      // adding zero leaves the left side unchanged
      case bzero { inc_dub(x) }
      // (1+2x) + 2(1+y') = 3 + 2(x+y') = 1 + 2(1 + (x+y'))
      case dub_inc(y') { inc(dub_inc(x + y')) }
      // (1+2x) + (1+2y') = 2 + 2(x+y') = 1 + (1 + 2(x+y'))
      case inc_dub(y') { inc(inc_dub(x + y')) }
    }
}

// Truncated subtraction (monus) on UInt: x ∸ y denotes x - y when
// y does not exceed x, and zero otherwise. Defined by recursion on the
// first argument and case analysis on the second.
recursive operator ∸ (UInt, UInt) -> UInt {
  // 0 - y truncates to 0
  operator∸(bzero, y) = bzero
  // Left side is 2(1 + x).
  operator∸(dub_inc(x), y) =
    switch y {
      // subtracting zero leaves the left side unchanged
      case bzero { dub_inc(x) }
      // 2(1+x) - 2(1+y') = 2 + 2x - 2 - 2y' = 2x - 2y' = 2(x - y');
      // when x < y' the recursive call is already zero, and dub(bzero) = bzero
      case dub_inc(y') { dub(x ∸ y') }
      case inc_dub(y') {
        // 2(1+x) - (1+2y') = 2 + 2x - 1 - 2y' = 1+2(x-y')
        // The guard is needed: when x < y' the true difference is
        // negative, but inc_dub(x ∸ y') would wrongly give 1.
        if x < y' then bzero
        else inc_dub(x ∸ y')
      }
    }
  // Left side is 1 + 2x.
  operator∸(inc_dub(x), y) =
    switch y {
      // subtracting zero leaves the left side unchanged
      case bzero { inc_dub(x) }
      case dub_inc(y') {
        // 1 + 2x - 2(1+y') = 1 + 2x - 2 - 2y' = - 1 + 2(x - y')
        // When x = y' this is pred(dub(bzero)) = pred(bzero) = bzero,
        // which is the correct truncated result.
        if x < y' then bzero else pred(dub(x ∸ y'))
      }
      // (1+2x) - (1+2y') = 2(x - y')
      case inc_dub(y') { dub(x ∸ y') }
    }
}

// Multiplication on UInt, by recursion on the first argument and case
// analysis on the second. The comments in each case expand the product
// of the denoted numbers and regroup it into constructor form.
recursive operator *(UInt, UInt) -> UInt {
  // 0 * y = 0
  operator*(bzero, y) = bzero
  // Left side is 2(1 + x).
  operator*(dub_inc(x), y) =
    switch y {
      // anything times zero is zero
      case bzero { bzero }
      case dub_inc(y') {
        // 2*(1 + x) * 2*(1 + y') = (2+2x)(2 + 2y')
        // = 4 + 4x + 4y' + 4xy'
        // = 2(2 + 2x + 2y' + 2xy')
        // = 2(2(1 + x + y' + xy'))
        dub(dub_inc(x + y' + x * y'))
      }
      case inc_dub(y') {
        // 2*(1 + x) * (1 + 2y') = 2*(1 + x + 2y' + 2xy')
        // = 2*(1 + (x + 2(y' + xy')))
        dub_inc(x + dub(y' + x*y'))
      }
    }
  // Left side is 1 + 2x.
  operator*(inc_dub(x), y) = 
    switch y {
      // anything times zero is zero
      case bzero { bzero }
      case dub_inc(y') {
        // (1 + 2x)(2*(1+y')) = (1 + 2x)(2 + 2y')
        // = 2 + 4x + 2y' + 4xy'
        // = 2(1 + 2x + y' + 2xy')
        dub_inc(dub(x) + y' + dub(x * y'))
      }
      case inc_dub(y') {
        // (1 + 2x)(1 + 2y') = 1 + 2x + 2y' + 4xy'
        // = 1 + 2(x + y' + 2xy')
        inc_dub(x + y' + dub(x * y'))
      }
    }
}

// Squares a UInt: multiplies the argument by itself.
fun sqr(a : UInt) { a * a }

// Exponentiation by repeated squaring. Note the argument order: the
// FIRST argument is the exponent (the one recursed on) and the second
// is the base, so expt(p, a) denotes a^p.
recursive expt(UInt, UInt) -> UInt {
  // a^0 = 1, and inc_dub(bzero) denotes 1
  expt(bzero, a) = inc_dub(bzero)
  expt(dub_inc(p), a) = sqr(a * expt(p, a))  // a^(2*(1+p)) = a^2 * (a^p)^2 = (a * a^p)^2
  expt(inc_dub(p), a) = a * sqr(expt(p, a))  // a^(1+2*p) = a * (a^p)^2
}

// Infix exponentiation: a ^ b is a raised to the power b. The arguments
// are swapped in the call because expt takes the exponent first.
fun operator ^(a : UInt, b : UInt)  {
  expt(b, a)
}

// Converts a Nat to a UInt by structural recursion on the Nat,
// applying inc once for every suc.
recursive fromNat(Nat) -> UInt {
  // zero maps to bzero
  fromNat(0) = bzero
  // suc(n) maps to the successor of the conversion of n
  fromNat(suc(n)) = inc(fromNat(n))
}

// Larger of two UInts: returns y when x < y, otherwise x
// (so x is returned when the two are equal).
fun max(x : UInt, y : UInt) {
  if x < y then y
  else x
}

// Smaller of two UInts: returns x when x < y, otherwise y
// (so y is returned when the two are equal).
fun min(x : UInt, y : UInt) {
  if x < y then x
  else y
}