Created December 14, 2012 10:31
-
-
Save giuniu/4284371 to your computer and use it in GitHub Desktop.
「アルゴリズムを学ぼう」よりAVL木の実装。 (An AVL tree implementation based on the book 「アルゴリズムを学ぼう」, roughly "Let's Learn Algorithms".)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static java.lang.Math.max; | |
import static java.lang.Math.min; | |
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.Random; | |
/**
 * AVL tree: a self-balancing binary search tree.
 *
 * <p>Every node stores a balance factor, {@code height(right) - height(left)}, which is kept
 * within [-1, 1] by single or double rotations after each insertion and removal. Duplicate values
 * are allowed (at insertion time an equal value descends into the left subtree), and {@code null}
 * values are rejected with {@link IllegalArgumentException}.
 *
 * <p>All ordering and equality decisions use {@link Comparable#compareTo} only, so for element
 * types whose {@code compareTo} is inconsistent with {@code equals} (for example
 * {@code BigDecimal}), two values are treated as equal whenever {@code compareTo} returns 0.
 *
 * @param <T> element type
 */
public class AVLTree<T extends Comparable<? super T>> {

    /** Root of the tree, or {@code null} when the tree is empty. */
    private Node root;

    /** Tree node holding a value, its two children, and its balance factor. */
    private class Node {
        private T _value;
        private Node _left;
        private Node _right;
        /** height(right subtree) - height(left subtree). */
        private int _balance;

        private Node(T value) {
            this._value = value;
        }

        T value() {
            return _value;
        }

        void value(T v) {
            _value = v;
        }

        Node left() {
            return _left;
        }

        void setLeft(Node node) {
            _left = node;
        }

        Node right() {
            return _right;
        }

        void setRight(Node node) {
            _right = node;
        }

        int balance() {
            return _balance;
        }

        void addBalance(int diff) {
            _balance += diff;
        }

        /**
         * Leaves print as their value alone; inner nodes print as
         * {@code value<balance>(left,right)}, with {@code null} standing for a missing child.
         */
        @Override
        public String toString() {
            return _value + (_left == null && _right == null
                    ? ""
                    : "<" + _balance + ">(" + _left + "," + _right + ")");
        }
    }

    /**
     * Returns whether the tree holds a value that compares equal to {@code v}.
     *
     * @param v value to look up
     * @return {@code true} if some stored value has {@code compareTo(v) == 0}
     * @throws IllegalArgumentException if {@code v} is {@code null}
     */
    public boolean contains(T v) {
        if (v == null) {
            throw new IllegalArgumentException("値がnull");
        }
        return contains(root, v);
    }

    /** Recursive lookup of {@code v} in the subtree rooted at {@code node}. */
    private boolean contains(Node node, T v) {
        if (node == null) {
            return false;
        }
        int comp = node.value().compareTo(v);
        if (comp == 0) {
            return true;
        }
        return comp > 0 ? contains(node.left(), v) : contains(node.right(), v);
    }

    /**
     * Inserts {@code v}, rebalancing as needed. Duplicates are kept.
     *
     * @param v value to insert
     * @throws IllegalArgumentException if {@code v} is {@code null}
     */
    public void insert(T v) {
        if (v == null) {
            throw new IllegalArgumentException("値がnull");
        }
        if (root == null) {
            root = new Node(v);
            return;
        }
        insert(null, root, v);
    }

    /**
     * Inserts {@code v} below the non-null {@code node} and rebalances on the way back up.
     *
     * @param parent parent of {@code node}, or {@code null} if {@code node} is the root
     * @param node   subtree root to insert into
     * @param v      value to insert
     * @return change in height of the subtree rooted at {@code node} (0 or 1)
     */
    private int insert(Node parent, Node node, T v) {
        assert node != null;
        if (node.value().compareTo(v) >= 0) {
            // Equal values go left.
            int diff;
            if (node.left() == null) {
                node.setLeft(new Node(v));
                diff = 1;
            } else {
                diff = insert(node, node.left(), v);
            }
            return achieveBalance(parent, node, diff, 0);
        }
        int diff;
        if (node.right() == null) {
            node.setRight(new Node(v));
            diff = 1;
        } else {
            diff = insert(node, node.right(), v);
        }
        return achieveBalance(parent, node, 0, diff);
    }

    /**
     * Removes one value comparing equal to {@code v}; does nothing if none is present.
     *
     * @param v value to remove
     * @throws IllegalArgumentException if {@code v} is {@code null}
     */
    public void remove(T v) {
        if (v == null) {
            throw new IllegalArgumentException("値がnull");
        }
        remove(null, root, v);
    }

    /**
     * Removes one value comparing equal to {@code v} from the subtree rooted at {@code node} and
     * rebalances on the way back up.
     *
     * @param parent parent of {@code node}, or {@code null} if {@code node} is the root
     * @param node   subtree root, may be {@code null}
     * @param v      value to remove
     * @return change in height of the subtree rooted at {@code node} (0 or negative)
     */
    private int remove(Node parent, Node node, T v) {
        if (node == null) {
            return 0; // not found: nothing changes
        }
        int comp = node.value().compareTo(v);
        if (comp > 0) {
            int diff = remove(node, node.left(), v);
            return achieveBalance(parent, node, diff, 0);
        }
        if (comp < 0) {
            int diff = remove(node, node.right(), v);
            return achieveBalance(parent, node, 0, diff);
        }
        // comp == 0: this node holds the value. Do not check equals() here: compareTo may be
        // inconsistent with equals (e.g. BigDecimal 1.0 vs 1.00), and ordering is what matters.
        if (node.left() != null) {
            int diff = findMaxAndRemove(node, v);
            return achieveBalance(parent, node, diff, 0);
        }
        if (node.right() != null) {
            int diff = findMinAndRemove(node, v);
            return achieveBalance(parent, node, 0, diff);
        }
        // Leaf: just unlink it.
        replace(parent, node, null);
        return -1;
    }

    /**
     * Removes the value stored at {@code node} (which has a left child). It swaps that value with
     * the in-order predecessor (the maximum of the left subtree) and then removes {@code v} from
     * the left subtree.
     *
     * <p>If the predecessor has a left child, the predecessor is first rotated right. Under the
     * AVL invariant that child is a single node, so the predecessor becomes a leaf.
     *
     * @return change in height of {@code node}'s left subtree
     */
    private int findMaxAndRemove(Node node, T v) {
        Node maxParent = node;
        Node max = node.left();
        while (max.right() != null) {
            maxParent = max;
            max = max.right();
        }
        if (max.left() != null) {
            replace(maxParent, max, rotateRight(max));
        }
        swapValue(node, max);
        return remove(node, node.left(), v);
    }

    /**
     * Mirror of {@link #findMaxAndRemove}. It swaps the value at {@code node} with the in-order
     * successor (the minimum of the right subtree) and removes {@code v} from the right subtree.
     *
     * @return change in height of {@code node}'s right subtree
     */
    private int findMinAndRemove(Node node, T v) {
        Node minParent = node;
        Node min = node.right();
        while (min.left() != null) {
            minParent = min;
            min = min.left();
        }
        if (min.right() != null) {
            replace(minParent, min, rotateLeft(min));
        }
        swapValue(node, min);
        return remove(node, node.right(), v);
    }

    /** Exchanges the values stored in two nodes without touching the tree structure. */
    private void swapValue(Node n1, Node n2) {
        assert n1 != null && n2 != null;
        T v = n1.value();
        n1.value(n2.value());
        n2.value(v);
    }

    /**
     * Updates {@code node}'s balance after one of its subtrees changed height. If the balance
     * reaches -2 or +2, it rotates (single or double) and links the new subtree root into
     * {@code parent}.
     *
     * <p>Callers always pass 0 for one of {@code leftDiff} and {@code rightDiff}.
     *
     * @param parent    parent of {@code node}, or {@code null} if {@code node} is the root
     * @param node      node whose child subtree changed height
     * @param leftDiff  height change of the left subtree
     * @param rightDiff height change of the right subtree
     * @return height change of the subtree formerly rooted at {@code node}
     */
    private int achieveBalance(Node parent, Node node, int leftDiff, int rightDiff) {
        assert (-1 <= node.balance() && node.balance() <= 1);
        if (leftDiff == 0 && rightDiff == 0) {
            return 0;
        }
        int diff = 0;
        // Growing the taller (or an equally tall) side grows this subtree.
        if ((leftDiff > 0 && node.balance() <= 0) || (rightDiff > 0 && node.balance() >= 0)) {
            diff++;
        }
        // Shrinking the strictly taller side shrinks this subtree.
        if ((leftDiff < 0 && node.balance() < 0) || (rightDiff < 0 && node.balance() > 0)) {
            diff--;
        }
        node.addBalance(rightDiff - leftDiff);
        assert (-2 <= node.balance() && node.balance() <= 2);
        if (node.balance() == -2) {
            // Left-heavy. A rotation lowers the height unless the left child is balanced.
            if (node.left().balance() != 0) {
                diff--;
            }
            if (node.left().balance() == 1) {
                // Left-right case: rotate the child first (double rotation).
                replace(node, node.left(), rotateLeft(node.left()));
            }
            replace(parent, node, rotateRight(node));
        } else if (node.balance() == 2) {
            // Right-heavy: mirror of the case above.
            if (node.right().balance() != 0) {
                diff--;
            }
            if (node.right().balance() == -1) {
                replace(node, node.right(), rotateRight(node.right()));
            }
            replace(parent, node, rotateLeft(node));
        }
        return diff;
    }

    /**
     * Rotates the subtree rooted at {@code node} to the left.
     *
     * <p>With {@code a} and {@code b} being the old balances of {@code node} and its right child,
     * the new balances are {@code a' = a - 1 - max(b, 0)} for {@code node} and
     * {@code b' = b - 1 + min(a', 0)} for the new top.
     *
     * @param node top node of the subtree to rotate; must have a right child
     * @return the new top node of the rotated subtree
     */
    private Node rotateLeft(Node node) {
        assert node != null;
        Node result = node.right();
        assert result != null;
        node.setRight(result.left());
        result.setLeft(node);
        int a = node.balance();
        int b = result.balance();
        int aa = a - max(b, 0) - 1;
        node.addBalance(aa - a);
        result.addBalance(min(aa, 0) - 1);
        return result;
    }

    /**
     * Rotates the subtree rooted at {@code node} to the right.
     *
     * <p>With {@code a} and {@code b} being the old balances of {@code node} and its left child,
     * the new balances are {@code a' = a + 1 - min(b, 0)} for {@code node} and
     * {@code b' = b + 1 + max(a', 0)} for the new top.
     *
     * @param node top node of the subtree to rotate; must have a left child
     * @return the new top node of the rotated subtree
     */
    private Node rotateRight(Node node) {
        assert node != null;
        Node result = node.left();
        assert result != null;
        node.setLeft(result.right());
        result.setRight(node);
        int a = node.balance();
        int b = result.balance();
        int aa = a - min(b, 0) + 1;
        node.addBalance(aa - a);
        result.addBalance(max(aa, 0) + 1);
        return result;
    }

    /**
     * Replaces {@code parent}'s child {@code oldNode} with {@code newNode}. If {@code parent} is
     * {@code null}, the root is replaced instead.
     *
     * @param parent  parent node, or {@code null} for the root
     * @param oldNode current child of {@code parent}
     * @param newNode replacement, may be {@code null}
     * @throws IllegalArgumentException if {@code oldNode} is not a child of {@code parent}
     */
    private void replace(Node parent, Node oldNode, Node newNode) {
        if (parent == null) {
            root = newNode;
            return;
        }
        if (parent.left() == oldNode) {
            parent.setLeft(newNode);
        } else if (parent.right() == oldNode) {
            parent.setRight(newNode);
        } else {
            throw new IllegalArgumentException("子供が見つからなかった");
        }
    }

    /**
     * Verifies that every stored balance factor matches the actual subtree heights.
     *
     * @throws IllegalStateException on the first mismatch found
     */
    void selfCheck() {
        check(root);
    }

    /** Returns the height of the subtree at {@code node}, validating balances along the way. */
    private int check(Node node) {
        if (node == null) {
            return 0;
        }
        int lRank = check(node.left());
        int rRank = check(node.right());
        if (node.balance() != rRank - lRank) {
            // Stored balance disagrees with the real heights: the tree is corrupt.
            throw new IllegalStateException("rRank:" + rRank + " lRank:" + lRank + " node:" + node);
        }
        return max(lRank, rRank) + 1;
    }

    /**
     * Returns the tree in {@link Node#toString} form, or the string {@code "null"} when empty.
     * (This never returns a {@code null} reference, per the {@code Object.toString} contract.)
     */
    @Override
    public String toString() {
        return String.valueOf(root);
    }

    /**
     * Demo and stress test. It inserts 1..30000 in random order, then removes them in random
     * order, validating the tree after every operation.
     */
    public static void main(String[] args) {
        int count = 30000;
        List<Integer> list = new ArrayList<>();
        for (int i = 1; i <= count; i++) {
            list.add(i);
        }
        AVLTree<Integer> tree = new AVLTree<>();
        Random rand = new Random();
        List<Integer> tmp = new ArrayList<>(list);
        for (int i = count; i >= 1; i--) {
            tree.insert(tmp.remove(rand.nextInt(i)));
            tree.selfCheck();
        }
        System.out.println(tree);
        System.out.println(tree.contains(18));
        System.out.println(tree.contains(0));
        tmp = new ArrayList<>(list);
        for (int i = count; i >= 1; i--) {
            tree.remove(tmp.remove(rand.nextInt(i)));
            tree.selfCheck();
        }
        System.out.println(tree);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.