package scala
package collection.mutable

import scala.annotation.tailrec
import collection.Iterator

import java.lang.String

/**
 * An object containing the red-black tree implementation used by mutable `TreeMaps`.
 *
 * The trees implemented in this object are *not* thread safe.
 *
 * @author Rui Gonçalves
 * @version 2.12
 * @since 2.12
 */
private[collection] object RedBlackTree {

  // ---- class structure ----

  // For performance reasons, this implementation uses `null` references to represent leaves instead of a sentinel node.
  // Currently, the internal nodes do not store their subtree size - only the tree object keeps track of their size.
  // Therefore, while obtaining the size of the whole tree is O(1), knowing the number of entries inside a range is O(n)
  // on the size of the range.

  /**
   * Mutable handle to a red-black tree: it holds the reference to the root node and the number of entries, both of
   * which are updated in place by the operations of this object.
   *
   * @param root the root node, or `null` when the tree has no entries
   * @param size the number of entries currently stored in the tree
   */
  @SerialVersionUID(3L)
  final class Tree[A, B](
    var root: Node[A, B],
    var size: Int
  ) extends Serializable

  /**
   * A single entry of the tree together with its colour and its links to the neighbouring nodes. Missing children
   * (and the missing parent of the root) are represented by `null`.
   *
   * @param key    the key of the entry
   * @param value  the value associated with `key`
   * @param red    `true` if the node is red, `false` if it is black
   * @param left   the left child, or `null`
   * @param right  the right child, or `null`
   * @param parent the parent node, or `null`
   */
  @SerialVersionUID(3L)
  final class Node[A, B](
    var key: A,
    var value: B,
    var red: Boolean,
    var left: Node[A, B],
    var right: Node[A, B],
    var parent: Node[A, B]
  ) extends Serializable {

    /** Renders the key, the value, the colour flag and both subtrees; the parent link is not part of the output. */
    override def toString: String = s"Node($key, $value, $red, $left, $right)"
  }

  object Tree {
    /** Creates a tree without a root node and with a size of zero. */
    def empty[A, B]: Tree[A, B] = new Tree[A, B](null, 0)
  }

  object Node {

    /** Builds a node with every field given explicitly. */
    @`inline` def apply[A, B](key: A, value: B, red: Boolean,
                            left: Node[A, B], right: Node[A, B], parent: Node[A, B]): Node[A, B] =
      new Node[A, B](key, value, red, left, right, parent)

    /** Builds a node without children whose parent reference is `parent`. */
    @`inline` def leaf[A, B](key: A, value: B, red: Boolean, parent: Node[A, B]): Node[A, B] =
      new Node[A, B](key, value, red, left = null, right = null, parent = parent)

    /** Extractor exposing the key, the value, both children and the parent of `t` (the colour is not included). */
    def unapply[A, B](t: Node[A, B]): Some[(A, B, Node[A, B], Node[A, B], Node[A, B])] = {
      val fields = (t.key, t.value, t.left, t.right, t.parent)
      Some(fields)
    }
  }

  // ---- getters ----

  /** Tells whether `node` is a red node; a `null` reference is never red. */
  def isRed(node: Node[_, _]): Boolean = if (node eq null) false else node.red
  /** Tells whether `node` is black; a `null` reference counts as black. */
  def isBlack(node: Node[_, _]): Boolean = if (node eq null) true else !node.red

  // ---- size ----

  /** Counts the nodes of the subtree rooted at `node` by visiting every one of them; `null` yields 0. */
  def size(node: Node[_, _]): Int =
    if (node ne null) 1 + size(node.left) + size(node.right)
    else 0
  /** Returns the entry count stored in `tree`, without traversing its nodes. */
  def size(tree: Tree[_, _]): Int = {
    tree.size
  }
  /** Tells whether `tree` has no root node. */
  def isEmpty(tree: Tree[_, _]): Boolean = tree.root match {
    case null => true
    case _ => false
  }
  /** Empties `tree` by dropping its root reference and resetting its size to zero. */
  def clear(tree: Tree[_, _]): Unit = {
    tree.size = 0
    tree.root = null
  }

  // ---- search ----

  /** Looks up `key` and returns its value in a `Some`, or `None` when no node holds that key. */
  def get[A: Ordering, B](tree: Tree[A, B], key: A): Option[B] = {
    val found = getNode(tree.root, key)
    if (found eq null) None else Some(found.value)
  }

  /** Returns the key stored in the tree that compares equal to `key`, or `None` when there is none. */
  def getKey[A : Ordering](tree: Tree[A, _], key: A): Option[A] = {
    val found = getNode(tree.root, key)
    if (found eq null) None else Some(found.key)
  }

  /**
   * Searches the subtree rooted at `node` for a node whose key compares equal to `key` under `ord`.
   *
   * @return the matching node, or `null` if the search runs off the tree
   */
  @tailrec private[this] def getNode[A, B](node: Node[A, B], key: A)(implicit ord: Ordering[A]): Node[A, B] =
    if (node ne null) {
      val order = ord.compare(key, node.key)
      if (order == 0) node
      else getNode(if (order < 0) node.left else node.right, key)
    } else null

  /** Tells whether some node of `tree` holds a key that compares equal to `key`. */
  def contains[A: Ordering](tree: Tree[A, _], key: A): Boolean =
    !(getNode(tree.root, key) eq null)

  /** Returns the entry of the leftmost node as `Some((key, value))`, or `None` for an empty tree. */
  def min[A, B](tree: Tree[A, B]): Option[(A, B)] = {
    val first = minNode(tree.root)
    if (first eq null) None else Some((first.key, first.value))
  }

  /** Returns the key of the leftmost node in a `Some`, or `None` for an empty tree. */
  def minKey[A](tree: Tree[A, _]): Option[A] = {
    val first = minNode(tree.root)
    if (first eq null) None else Some(first.key)
  }

  /** Null-tolerant variant of `minNodeNonNull`: returns `null` when `node` is `null`. */
  private def minNode[A, B](node: Node[A, B]): Node[A, B] = node match {
    case null => null
    case present => minNodeNonNull(present)
  }

  /** Follows `left` links from `node` until a node without a left child is reached. `node` must not be `null`. */
  def minNodeNonNull[A, B](node: Node[A, B]): Node[A, B] = {
    var leftmost = node
    while (leftmost.left ne null) leftmost = leftmost.left
    leftmost
  }

  /** Returns the entry of the rightmost node as `Some((key, value))`, or `None` for an empty tree. */
  def max[A, B](tree: Tree[A, B]): Option[(A, B)] = {
    val last = maxNode(tree.root)
    if (last eq null) None else Some((last.key, last.value))
  }

  /** Returns the key of the rightmost node in a `Some`, or `None` for an empty tree. */
  def maxKey[A](tree: Tree[A, _]): Option[A] = {
    val last = maxNode(tree.root)
    if (last eq null) None else Some(last.key)
  }

  /** Null-tolerant variant of `maxNodeNonNull`: returns `null` when `node` is `null`. */
  private def maxNode[A, B](node: Node[A, B]): Node[A, B] = node match {
    case null => null
    case present => maxNodeNonNull(present)
  }

  /** Follows `right` links from `node` until a node without a right child is reached. `node` must not be `null`. */
  def maxNodeNonNull[A, B](node: Node[A, B]): Node[A, B] = {
    var rightmost = node
    while (rightmost.right ne null) rightmost = rightmost.right
    rightmost
  }

  /**
   * Returns the first (lowest) map entry with a key equal or greater than `key`. Returns `None` if there is no such
   * node.
   */
  def minAfter[A, B](tree: Tree[A, B], key: A)(implicit ord: Ordering[A]): Option[(A, B)] = {
    // the node lookup signals "no such entry" with null
    val hit = minNodeAfter(tree.root, key)
    if (hit eq null) None else Some((hit.key, hit.value))
  }

  /** Returns the lowest key of the tree that is equal to or greater than `key`, or `None` if there is no such key. */
  def minKeyAfter[A](tree: Tree[A, _], key: A)(implicit ord: Ordering[A]): Option[A] = {
    val hit = minNodeAfter(tree.root, key)
    if (hit eq null) None else Some(hit.key)
  }

  /**
   * Finds, in the subtree rooted at `node`, the node with the lowest key that is equal to or greater than `key`.
   *
   * @return that node, or `null` when `node` is `null` or no key qualifies
   */
  private[this] def minNodeAfter[A, B](node: Node[A, B], key: A)(implicit ord: Ordering[A]): Node[A, B] = {
    // Walks down as an ordinary search would; `current` is never null.
    @tailrec def descend(current: Node[A, B]): Node[A, B] = {
      val order = ord.compare(key, current.key)
      if (order == 0) current
      else {
        val next = if (order < 0) current.left else current.right
        if (next ne null) descend(next)
        // fell off the tree: the last visited node is the answer if it is above `key`, otherwise its successor is
        else if (order < 0) current
        else successor(current)
      }
    }

    if (node eq null) null else descend(node)
  }

  /**
   * Returns the last (highest) map entry with a key smaller than `key`. Returns `None` if there is no such node.
   */
  def maxBefore[A, B](tree: Tree[A, B], key: A)(implicit ord: Ordering[A]): Option[(A, B)] = {
    // the node lookup signals "no such entry" with null
    val hit = maxNodeBefore(tree.root, key)
    if (hit eq null) None else Some((hit.key, hit.value))
  }

  /** Returns the highest key of the tree that is smaller than `key`, or `None` if there is no such key. */
  def maxKeyBefore[A](tree: Tree[A, _], key: A)(implicit ord: Ordering[A]): Option[A] = {
    val hit = maxNodeBefore(tree.root, key)
    if (hit eq null) None else Some(hit.key)
  }

  /**
   * Finds, in the subtree rooted at `node`, the node with the highest key that is strictly smaller than `key`.
   *
   * @return that node, or `null` when `node` is `null` or no key qualifies
   */
  private[this] def maxNodeBefore[A, B](node: Node[A, B], key: A)(implicit ord: Ordering[A]): Node[A, B] = {
    // Walks down as an ordinary search would; `current` is never null.
    @tailrec def descend(current: Node[A, B]): Node[A, B] = {
      val order = ord.compare(key, current.key)
      val next = if (order < 0) current.left else current.right
      if (order != 0 && (next ne null)) descend(next)
      // search ended: the last visited node qualifies only if its key is below `key`
      else if (order > 0) current
      else predecessor(current)
    }

    if (node eq null) null else descend(node)
  }

  // ---- insertion ----

  /**
   * Associates `value` with `key` in `tree`. If a node with an equal key already exists only its value is replaced;
   * otherwise a new red leaf is linked in, the tree is rebalanced and its size is incremented.
   */
  def insert[A, B](tree: Tree[A, B], key: A, value: B)(implicit ord: Ordering[A]): Unit = {
    var parent: Node[A, B] = null
    var cursor = tree.root
    var order = 1 // stays non-zero when the tree is empty, which selects the "new node" branch below
    while ((cursor ne null) && order != 0) {
      parent = cursor
      order = ord.compare(key, cursor.key)
      cursor = if (order < 0) cursor.left else cursor.right
    }

    if (order != 0) {
      val fresh = Node.leaf(key, value, red = true, parent)

      if (parent eq null) tree.root = fresh
      else if (order < 0) parent.left = fresh
      else parent.right = fresh

      fixAfterInsert(tree, fresh)
      tree.size += 1
    } else {
      // the search stopped on a node with an equal key
      parent.value = value
    }
  }

  /**
   * Restores the colouring of `tree` after `node` has been linked in. Moves upwards while the parent of the current
   * node is red, recolouring and rotating as it goes, and finally paints the root black.
   *
   * @param tree the tree that was modified
   * @param node the node from which the repair starts (`insert` passes the newly created red leaf)
   */
  private[this] def fixAfterInsert[A, B](tree: Tree[A, B], node: Node[A, B]): Unit = {
    var z = node
    while (isRed(z.parent)) {
      if (z.parent eq z.parent.parent.left) {
        // `y` is the sibling of z's parent
        val y = z.parent.parent.right
        if (isRed(y)) {
          // parent and its sibling are both red: paint them black, the grandparent red, and continue from there
          z.parent.red = false
          y.red = false
          z.parent.parent.red = true
          z = z.parent.parent
        } else {
          if (z eq z.parent.right) {
            // z is an inner grandchild: rotate so that it becomes an outer one
            z = z.parent
            rotateLeft(tree, z)
          }
          // swap the colours of parent and grandparent, then rotate the grandparent down
          z.parent.red = false
          z.parent.parent.red = true
          rotateRight(tree, z.parent.parent)
        }
      } else { // symmetric cases
        val y = z.parent.parent.left
        if (isRed(y)) {
          z.parent.red = false
          y.red = false
          z.parent.parent.red = true
          z = z.parent.parent
        } else {
          if (z eq z.parent.left) {
            z = z.parent
            rotateRight(tree, z)
          }
          z.parent.red = false
          z.parent.parent.red = true
          rotateLeft(tree, z.parent.parent)
        }
      }
    }
    // the root is always left black
    tree.root.red = false
  }

  // ---- deletion ----

  /**
   * Removes the node whose key compares equal to `key` from `tree`, rebalances the tree if needed and decrements its
   * size. Does nothing when no such node exists.
   */
  def delete[A, B](tree: Tree[A, B], key: A)(implicit ord: Ordering[A]): Unit = {
    val z = getNode(tree.root, key)
    if (z ne null) {
      // `y` is the node that is unlinked or moved, `x` the node that takes y's original place (possibly null).
      // Because `x` may be null, its parent is tracked separately in `xParent`.
      var y = z
      var yIsRed = y.red
      var x: Node[A, B] = null
      var xParent: Node[A, B] = null

      if (z.left eq null) {
        // at most a right child: it replaces z directly
        x = z.right
        transplant(tree, z, z.right)
        xParent = z.parent
      }
      else if (z.right eq null) {
        // only a left child: it replaces z directly
        x = z.left
        transplant(tree, z, z.left)
        xParent = z.parent
      }
      else {
        // two children: the leftmost node of the right subtree takes z's place
        y = minNodeNonNull(z.right)
        yIsRed = y.red
        x = y.right

        if (y.parent eq z) xParent = y
        else {
          // detach y from its current position and hand it z's right subtree
          xParent = y.parent
          transplant(tree, y, y.right)
          y.right = z.right
          y.right.parent = y
        }
        transplant(tree, z, y)
        y.left = z.left
        y.left.parent = y
        // y inherits z's colour
        y.red = z.red
      }

      // the repair runs only when the colour recorded in yIsRed was black
      if (!yIsRed) fixAfterDelete(tree, x, xParent)
      tree.size -= 1
    }
  }

  /**
   * Restores the colouring of `tree` after a deletion. Loops while the current node is black (a `null` node counts
   * as black) and is not the root, recolouring and rotating around its parent, and finally paints the current node
   * black if it is not `null`.
   *
   * @param tree   the tree that was modified
   * @param node   the node from which the repair starts; may be `null`
   * @param parent the parent of `node`, passed separately because `node` may be `null`
   */
  private[this] def fixAfterDelete[A, B](tree: Tree[A, B], node: Node[A, B], parent: Node[A, B]): Unit = {
    var x = node
    var xParent = parent
    while ((x ne tree.root) && isBlack(x)) {
      if (x eq xParent.left) {
        // `w` is the sibling of x
        var w = xParent.right
        // assert(w ne null)

        if (w.red) {
          // red sibling: rotate it above the parent so that the new sibling is read after the rotation
          w.red = false
          xParent.red = true
          rotateLeft(tree, xParent)
          w = xParent.right
        }
        if (isBlack(w.left) && isBlack(w.right)) {
          // sibling with two black children: paint it red and continue from the parent
          w.red = true
          x = xParent
        } else {
          if (isBlack(w.right)) {
            // only the near child of the sibling is red: rotate it into the far position
            w.left.red = false
            w.red = true
            rotateRight(tree, w)
            w = xParent.right
          }
          // far child of the sibling is red: recolour, rotate the parent down and finish
          w.red = xParent.red
          xParent.red = false
          w.right.red = false
          rotateLeft(tree, xParent)
          x = tree.root // makes the loop condition fail
        }
      } else { // symmetric cases
        var w = xParent.left
        // assert(w ne null)

        if (w.red) {
          w.red = false
          xParent.red = true
          rotateRight(tree, xParent)
          w = xParent.left
        }
        if (isBlack(w.right) && isBlack(w.left)) {
          w.red = true
          x = xParent
        } else {
          if (isBlack(w.left)) {
            w.right.red = false
            w.red = true
            rotateLeft(tree, w)
            w = xParent.left
          }
          w.red = xParent.red
          xParent.red = false
          w.left.red = false
          rotateRight(tree, xParent)
          x = tree.root
        }
      }
      xParent = x.parent
    }
    if (x ne null) x.red = false
  }

  // ---- helpers ----

  /**
   * Returns the node that follows `node` in an in-order tree traversal. If `node` has the maximum key (and is,
   * therefore, the last node), this method returns `null`.
   */
  private[this] def successor[A, B](node: Node[A, B]): Node[A, B] = {
    // Climbs while we arrive from a right child; stops at the first ancestor reached from its left (or at null).
    @tailrec def climb(child: Node[A, B], ancestor: Node[A, B]): Node[A, B] =
      if ((ancestor eq null) || (child ne ancestor.right)) ancestor
      else climb(ancestor, ancestor.parent)

    node.right match {
      case null => climb(node, node.parent)
      case rightChild => minNodeNonNull(rightChild)
    }
  }

  /**
   * Returns the node that precedes `node` in an in-order tree traversal. If `node` has the minimum key (and is,
   * therefore, the first node), this method returns `null`.
   */
  private[this] def predecessor[A, B](node: Node[A, B]): Node[A, B] = {
    // Climbs while we arrive from a left child; stops at the first ancestor reached from its right (or at null).
    @tailrec def climb(child: Node[A, B], ancestor: Node[A, B]): Node[A, B] =
      if ((ancestor eq null) || (child ne ancestor.left)) ancestor
      else climb(ancestor, ancestor.parent)

    node.left match {
      case null => climb(node, node.parent)
      case leftChild => maxNodeNonNull(leftChild)
    }
  }

  /**
   * Rotates the subtree rooted at `x` to the left: the right child of `x` takes the place of `x`, `x` becomes its
   * left child and the former left subtree of that child becomes the right subtree of `x`. Updates `tree.root` when
   * `x` had no parent. Does nothing when `x` is `null`.
   */
  private[this] def rotateLeft[A, B](tree: Tree[A, B], x: Node[A, B]): Unit = if (x ne null) {
    // assert(x.right ne null)
    val pivot = x.right
    val inner = pivot.left
    val above = x.parent

    // the inner subtree changes sides
    x.right = inner
    if (inner ne null) inner.parent = x

    // the pivot is hooked below the former parent of x
    pivot.parent = above
    if (above eq null) tree.root = pivot
    else if (above.left eq x) above.left = pivot
    else above.right = pivot

    // x ends up below the pivot
    pivot.left = x
    x.parent = pivot
  }

  /**
   * Rotates the subtree rooted at `x` to the right: the left child of `x` takes the place of `x`, `x` becomes its
   * right child and the former right subtree of that child becomes the left subtree of `x`. Updates `tree.root` when
   * `x` had no parent. Does nothing when `x` is `null`.
   */
  private[this] def rotateRight[A, B](tree: Tree[A, B], x: Node[A, B]): Unit = if (x ne null) {
    // assert(x.left ne null)
    val pivot = x.left
    val inner = pivot.right
    val above = x.parent

    // the inner subtree changes sides
    x.left = inner
    if (inner ne null) inner.parent = x

    // the pivot is hooked below the former parent of x
    pivot.parent = above
    if (above eq null) tree.root = pivot
    else if (above.right eq x) above.right = pivot
    else above.left = pivot

    // x ends up below the pivot
    pivot.right = x
    x.parent = pivot
  }

  /**
   * Transplant the node `from` to the place of node `to`. This is done by setting `from` as a child of `to`'s previous
   * parent and setting `from`'s parent to `to`'s previous parent. The children of `from` are left unchanged.
   */
  private[this] def transplant[A, B](tree: Tree[A, B], to: Node[A, B], from: Node[A, B]): Unit = {
    val above = to.parent

    // redirect the link that pointed at `to` (the tree root when `to` has no parent)
    if (above eq null) tree.root = from
    else if (above.left eq to) above.left = from
    else above.right = from

    // `from` may be null, in which case there is no back-reference to fix
    if (from ne null) from.parent = above
  }

  // ---- tree traversal ----

  /** Applies `f` to every `(key, value)` pair of `tree`, visiting the nodes in in-order sequence. */
  def foreach[A, B, U](tree: Tree[A, B], f: ((A, B)) => U): Unit = {
    val root = tree.root
    if (root ne null) foreachNodeNonNull(root, f)
  }

  /** Null-tolerant entry point for `foreachNodeNonNull`: does nothing when `node` is `null`. */
  private[this] def foreachNode[A, B, U](node: Node[A, B], f: ((A, B)) => U): Unit = node match {
    case null => ()
    case present => foreachNodeNonNull(present, f)
  }

  /** In-order traversal of the subtree rooted at the non-null `node`: left subtree, the node itself, right subtree. */
  private[this] def foreachNodeNonNull[A, B, U](node: Node[A, B], f: ((A, B)) => U): Unit = {
    foreachNode(node.left, f)
    val entry = (node.key, node.value)
    f(entry)
    foreachNode(node.right, f)
  }

  /** Applies `f` to every key of `tree`, visiting the nodes in in-order sequence. */
  def foreachKey[A, U](tree: Tree[A, _], f: A => U): Unit = {
    foreachNodeKey(tree.root, f)
  }

  /** Null-tolerant entry point for `foreachNodeKeyNonNull`: does nothing when `node` is `null`. */
  private[this] def foreachNodeKey[A, U](node: Node[A, _], f: A => U): Unit =
    if (node eq null) () else foreachNodeKeyNonNull(node, f)

  /** In-order traversal over the keys of the subtree rooted at the non-null `node`. */
  private[this] def foreachNodeKeyNonNull[A, U](node: Node[A, _], f: A => U): Unit = {
    foreachNodeKey(node.left, f)
    f(node.key)
    foreachNodeKey(node.right, f)
  }

  /** Replaces the value of every node of `tree` with `f(key, value)`, visiting the nodes in in-order sequence. */
  def transform[A, B](tree: Tree[A, B], f: (A, B) => B): Unit = {
    transformNode(tree.root, f)
  }

  /** Null-tolerant entry point for `transformNodeNonNull`: does nothing when `node` is `null`. */
  private[this] def transformNode[A, B, U](node: Node[A, B], f: (A, B) => B): Unit = node match {
    case null => ()
    case present => transformNodeNonNull(present, f)
  }

  /** In-order traversal of the subtree rooted at the non-null `node` that overwrites each value with `f(key, value)`. */
  private[this] def transformNodeNonNull[A, B, U](node: Node[A, B], f: (A, B) => B): Unit = {
    val leftChild = node.left
    if (leftChild ne null) transformNodeNonNull(leftChild, f)

    val updated = f(node.key, node.value)
    node.value = updated

    val rightChild = node.right
    if (rightChild ne null) transformNodeNonNull(rightChild, f)
  }

  /**
   * Creates an iterator over the `(key, value)` pairs of `tree` in key order.
   *
   * @param start if defined, iteration begins at the first key equal to or greater than this value
   * @param end   if defined, iteration stops before the first key equal to or greater than this value
   */
  def iterator[A: Ordering, B](tree: Tree[A, B], start: Option[A] = None, end: Option[A] = None): Iterator[(A, B)] = {
    val entries: Iterator[(A, B)] = new EntriesIterator(tree, start, end)
    entries
  }

  /**
   * Creates an iterator over the keys of `tree` in key order.
   *
   * @param start if defined, iteration begins at the first key equal to or greater than this value
   * @param end   if defined, iteration stops before the first key equal to or greater than this value
   */
  def keysIterator[A: Ordering](tree: Tree[A, _], start: Option[A] = None, end: Option[A] = None): Iterator[A] = {
    val keys: Iterator[A] = new KeysIterator(tree, start, end)
    keys
  }

  /**
   * Creates an iterator over the values of `tree`, ordered by their keys.
   *
   * @param start if defined, iteration begins at the first key equal to or greater than this value
   * @param end   if defined, iteration stops before the first key equal to or greater than this value
   */
  def valuesIterator[A: Ordering, B](tree: Tree[A, B], start: Option[A] = None, end: Option[A] = None): Iterator[B] = {
    val values: Iterator[B] = new ValuesIterator(tree, start, end)
    values
  }

  /**
   * Base class of the tree iterators. It walks the nodes in key order starting at the first node whose key is equal
   * to or greater than `start` (or at the leftmost node when `start` is empty) and stops before the first node whose
   * key is equal to or greater than `end` (or at the end of the tree when `end` is empty). Subclasses decide what is
   * produced for each visited node.
   */
  private[this] abstract class TreeIterator[A, B, R](tree: Tree[A, B], start: Option[A], end: Option[A])
                                                    (implicit ord: Ordering[A]) extends Iterator[R] {

    /** Converts the node being returned into the element type of this iterator. */
    protected def nextResult(node: Node[A, B]): R

    // Node handed out by the following call to `next()`, or null once the iteration is over.
    private[this] var nextNode: Node[A, B] =
      start.fold(minNode(tree.root))(from => minNodeAfter(tree.root, from))

    // Ends the iteration when the pending node has reached the exclusive upper bound.
    private[this] def setNullIfAfterEnd(): Unit = end match {
      case Some(limit) if (nextNode ne null) && ord.compare(nextNode.key, limit) >= 0 => nextNode = null
      case _ =>
    }

    // the very first node may already be out of range
    setNullIfAfterEnd()

    def hasNext: Boolean = !(nextNode eq null)

    @throws[NoSuchElementException]
    def next(): R = {
      val current = nextNode
      if (current eq null) throw new NoSuchElementException("next on empty iterator")
      nextNode = successor(current)
      setNullIfAfterEnd()
      nextResult(current)
    }
  }

  /** Tree iterator that yields the `(key, value)` pair of each visited node. */
  private[this] final class EntriesIterator[A: Ordering, B](tree: Tree[A, B], start: Option[A], end: Option[A])
    extends TreeIterator[A, B, (A, B)](tree, start, end) {

    def nextResult(node: Node[A, B]): (A, B) = Tuple2(node.key, node.value)
  }

  /** Tree iterator that yields the key of each visited node. */
  private[this] final class KeysIterator[A: Ordering, B](tree: Tree[A, B], start: Option[A], end: Option[A])
    extends TreeIterator[A, B, A](tree, start, end) {

    def nextResult(node: Node[A, B]): A = {
      node.key
    }
  }

  /** Tree iterator that yields the value of each visited node. */
  private[this] final class ValuesIterator[A: Ordering, B](tree: Tree[A, B], start: Option[A], end: Option[A])
    extends TreeIterator[A, B, B](tree, start, end) {

    def nextResult(node: Node[A, B]): B = {
      node.value
    }
  }

  // ---- debugging ----

  /**
   * Checks if the tree is in a valid state. That happens if:
   * - It is a valid binary search tree;
   * - All red-black properties are satisfied;
   * - All non-null nodes have their `parent` reference correct;
   * - The size variable in `tree` corresponds to the actual size of the tree.
   */
  def isValid[A: Ordering, B](tree: Tree[A, B]): Boolean = {
    // Declared as defs so that, as with a plain `&&` chain, later checks only run when the earlier ones passed.
    def ordered = isValidBST(tree.root)
    def linked = hasProperParentRefs(tree)
    def coloured = isValidRedBlackTree(tree)
    def counted = size(tree.root) == tree.size

    ordered && linked && coloured && counted
  }

  /**
   * Returns true if all non-null nodes have their `parent` reference correct.
   */
  private[this] def hasProperParentRefs[A, B](tree: Tree[A, B]): Boolean = {

    // A child link is consistent when it is null or when the child points back to `owner`.
    def linkedTo(child: Node[A, B], owner: Node[A, B]): Boolean =
      (child eq null) || (child.parent eq owner)

    // Checks both child links of `node` first and only then descends into the subtrees.
    def subtreeOk(node: Node[A, B]): Boolean =
      (node eq null) ||
        (linkedTo(node.left, node) && linkedTo(node.right, node) && subtreeOk(node.left) && subtreeOk(node.right))

    val root = tree.root
    // the root itself must not have a parent
    (root eq null) || ((root.parent eq null) && subtreeOk(root))
  }

  /**
   * Returns true if this node follows the properties of a binary search tree.
   */
  private[this] def isValidBST[A, B](node: Node[A, B])(implicit ord: Ordering[A]): Boolean = {
    // Comparing a node only with its direct children is not sufficient: a key can be correctly ordered with respect
    // to its parent and still violate an ancestor further up (e.g. 10 -> left 5 -> right 15). Every node is therefore
    // checked against the exclusive bounds inherited from its ancestors. The bounds are passed as nodes so that
    // `null` can stand for "unbounded" regardless of the key type.
    def withinBounds(current: Node[A, B], lower: Node[A, B], upper: Node[A, B]): Boolean =
      if (current eq null) true
      else if ((lower ne null) && ord.compare(current.key, lower.key) <= 0) false
      else if ((upper ne null) && ord.compare(current.key, upper.key) >= 0) false
      else {
        // keys on the left must stay below `current`, keys on the right above it
        withinBounds(current.left, lower, current) && withinBounds(current.right, current, upper)
      }

    withinBounds(node, null, null)
  }

  /**
   * Returns true if the tree has all the red-black tree properties: if the root node is black, if all children of red
   * nodes are black and if the path from any node to any of its null children has the same number of black nodes.
   */
  private[this] def isValidRedBlackTree[A, B](tree: Tree[A, B]): Boolean = {

    // True when no red node of the subtree has a red child.
    def noRedAfterRed(node: Node[A, B]): Boolean =
      (node eq null) || {
        val redPair = node.red && (isRed(node.left) || isRed(node.right))
        !redPair && noRedAfterRed(node.left) && noRedAfterRed(node.right)
      }

    // Number of black nodes on the paths from `node` down to a null child (the null child counts as one),
    // or -1 as soon as two paths disagree.
    def blackHeight(node: Node[A, B]): Int =
      if (node eq null) 1
      else {
        val leftHeight = blackHeight(node.left)
        val rightHeight = blackHeight(node.right)

        if (leftHeight == -1 || leftHeight != rightHeight) -1
        else leftHeight + (if (node.red) 0 else 1)
      }

    val root = tree.root
    isBlack(root) && noRedAfterRed(root) && blackHeight(root) >= 0
  }
}
