Created September 25, 2015 22:46
-
-
Save anonymous/e82b20caf5883932581f to your computer and use it in GitHub Desktop.
GADT enforced AVL tree
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* fuck license, fuck all copyrights, fuck this society *) | |
(** Result of a three-way comparison. *)
type compare = LessThan | Equal | GreaterThan

(** AVL trees whose balance invariant is enforced by the type checker.

    The second type parameter of [atree] is a type-level natural number
    built from [z] and ['n s] that records the depth of the tree, so a
    tree whose branches differ in depth by more than one cannot be
    constructed. *)
module RawAVLTree = struct
  (** Type-level zero. *)
  type z = Z : z

  (** Type-level successor. *)
  type 'n s = S : 'n -> 'n s

  (* The depths of the two branches of an AVL tree node differ by at most 1.
     [('l, 'r, 'm) diff] relates the left depth ['l], the right depth ['r]
     and the larger of the two, ['m]:
     - [Less]: the left branch is one level shallower than the right;
     - [Same]: both branches have the same depth;
     - [More]: the left branch is one level deeper than the right. *)
  type (_, _, _) diff =
    | Less : ('a, 'a s, 'a s) diff
    | Same : ('a, 'a, 'a) diff
    | More : ('a s, 'a, 'a s) diff

  (** [('a, 'd) atree] is an AVL tree of elements of type ['a] whose depth
      is ['d]. A node is one level deeper than its deeper branch. *)
  type ('a, 'd) atree =
    | Empty : ('a, z) atree
    | Tree :
        ('a, 'm) atree * 'a * ('a, 'n) atree * ('m, 'n, 'o) diff
        -> ('a, 'o s) atree

  (** [member cmp ele t] searches [t] for [ele], descending left on
      [LessThan] and right on [GreaterThan], and returns [true] as soon as
      [cmp ele k = Equal] for a visited key [k]. [t] is assumed to be a
      search tree ordered by [cmp]. *)
  let rec member : type d. ('a -> 'a -> compare) -> 'a -> ('a, d) atree -> bool
      =
   fun cmp ele t ->
    match t with
    | Empty -> false
    | Tree (l, k, r, _) -> (
        match cmp ele k with
        | LessThan -> member cmp ele l
        | Equal -> true
        | GreaterThan -> member cmp ele r)

  (** Outcome of rebuilding a tree of depth ['d]: either the depth is
      unchanged ([SameDepth]) or it grew by exactly one level ([Deeper]). *)
  type ('a, 'd) result =
    | SameDepth : ('a, 'd) atree -> ('a, 'd) result
    | Deeper : ('a, 'd s) atree -> ('a, 'd) result

  (* [rotate_left l v r] rebalances a would-be node [(l, v, r)] whose right
     branch [r] is two levels deeper than its left branch [l]. It performs a
     single rotation when [r] is not left-heavy and a double rotation
     otherwise. The result has depth [d s s] ([SameDepth]) or [d s s s]
     ([Deeper]).
     The destructuring [match]es are exhaustive: [Empty] has depth [z], which
     can never equal the successor depths required here. *)
  let rotate_left :
      type d. ('a, d) atree -> 'a -> ('a, d s s) atree -> ('a, d s s) result =
   fun l v r ->
    match r with
    | Tree (rl, rv, rr, diff) -> (
        match diff with
        | Less -> SameDepth (Tree (Tree (l, v, rl, Same), rv, rr, Same))
        | Same -> Deeper (Tree (Tree (l, v, rl, Less), rv, rr, More))
        | More -> (
            (* [r] is left-heavy: double rotation through [rl]. *)
            match rl with
            | Tree (rll, rlv, rlr, diffl) -> (
                match diffl with
                | Less ->
                    SameDepth
                      (Tree
                         ( Tree (l, v, rll, More),
                           rlv,
                           Tree (rlr, rv, rr, Same),
                           Same ))
                | Same ->
                    SameDepth
                      (Tree
                         ( Tree (l, v, rll, Same),
                           rlv,
                           Tree (rlr, rv, rr, Same),
                           Same ))
                | More ->
                    SameDepth
                      (Tree
                         ( Tree (l, v, rll, Same),
                           rlv,
                           Tree (rlr, rv, rr, Less),
                           Same )))))

  (* [rotate_right l v r] is the mirror image of [rotate_left]. It rebalances
     a would-be node [(l, v, r)] whose left branch [l] is two levels deeper
     than its right branch [r]: a single rotation when [l] is not
     right-heavy, a double rotation otherwise. *)
  let rotate_right :
      type d. ('a, d s s) atree -> 'a -> ('a, d) atree -> ('a, d s s) result =
   fun l v r ->
    match l with
    | Tree (ll, lv, lr, diff) -> (
        match diff with
        | More -> SameDepth (Tree (ll, lv, Tree (lr, v, r, Same), Same))
        | Same -> Deeper (Tree (ll, lv, Tree (lr, v, r, More), Less))
        | Less -> (
            (* [l] is right-heavy: double rotation through [lr]. *)
            match lr with
            | Tree (lrl, lrv, lrr, diffr) -> (
                match diffr with
                | Less ->
                    SameDepth
                      (Tree
                         ( Tree (ll, lv, lrl, More),
                           lrv,
                           Tree (lrr, v, r, Same),
                           Same ))
                | Same ->
                    SameDepth
                      (Tree
                         ( Tree (ll, lv, lrl, Same),
                           lrv,
                           Tree (lrr, v, r, Same),
                           Same ))
                | More ->
                    SameDepth
                      (Tree
                         ( Tree (ll, lv, lrl, Same),
                           lrv,
                           Tree (lrr, v, r, Less),
                           Same )))))

  (** [insert cmp v t] inserts [v] into the search tree [t] ordered by [cmp].
      If [cmp v k = Equal] for a key [k] on the search path, [t] is returned
      unchanged. The result records whether the depth of the tree grew by
      one level. *)
  let rec insert :
      type d. ('a -> 'a -> compare) -> 'a -> ('a, d) atree -> ('a, d) result =
   fun cmp v t ->
    match t with
    | Empty -> Deeper (Tree (Empty, v, Empty, Same))
    | Tree (l, tv, r, diff) -> (
        (* Rebuilt nodes keep their own key [tv]; only the new leaf holds [v]. *)
        match cmp v tv with
        | LessThan -> (
            match insert cmp v l with
            | SameDepth l' -> SameDepth (Tree (l', tv, r, diff))
            | Deeper l' -> (
                (* The left branch grew by one level. *)
                match diff with
                | More -> rotate_right l' tv r
                | Same -> Deeper (Tree (l', tv, r, More))
                | Less -> SameDepth (Tree (l', tv, r, Same))))
        | Equal -> SameDepth t
        | GreaterThan -> (
            match insert cmp v r with
            | SameDepth r' -> SameDepth (Tree (l, tv, r', diff))
            | Deeper r' -> (
                (* The right branch grew by one level. *)
                match diff with
                | Less -> rotate_left l tv r'
                | Same -> Deeper (Tree (l, tv, r', Less))
                | More -> SameDepth (Tree (l, tv, r', Same)))))
end
(** Interface of a set of elements of type [elem]. *)
module type Set = sig
  (** The type of sets. *)
  type t

  (** The type of set elements. *)
  type elem

  (** The set with no elements. *)
  val empty : t

  (** [insert e s] adds [e] to [s] and returns the resulting set. *)
  val insert : elem -> t -> t

  (** [member e s] tests whether [e] belongs to [s]. *)
  val member : elem -> t -> bool
end
(** [Set (X)] implements [Set] for elements of type [X.t], ordered by
    [X.compare], on top of [RawAVLTree]. *)
module Set (X : sig
  type t

  val compare : t -> t -> compare
end) : Set with type elem := X.t = struct
  (* The depth index of the underlying tree varies as elements are added,
     so it is packed away as an existential. *)
  type t = T : (X.t, _) RawAVLTree.atree -> t

  type elem = X.t

  let empty = T RawAVLTree.Empty

  let member e (T tree) = RawAVLTree.member X.compare e tree

  (* Both outcomes repack the rebuilt tree; they cannot share an or-pattern
     because the two trees have different depth types. *)
  let insert e (T tree) =
    match RawAVLTree.insert X.compare e tree with
    | RawAVLTree.SameDepth tree' -> T tree'
    | RawAVLTree.Deeper tree' -> T tree'
end
(** Three-way comparison on integers, suitable as the argument of [Set]. *)
module IntCompare = struct
  type t = int

  (* An [if] chain rather than [match () with] where every arm is guarded:
     the result is the same, but no guard-only match or catch-all arm is
     needed. [i] and [j] are left unannotated, so [compare] keeps its
     inferred type [_ -> _ -> compare]; [Set] constrains it to [t]. *)
  let compare i j =
    if i < j then LessThan
    else if i = j then Equal
    else GreaterThan
end
(** Sets of integers. *)
module IntSet = Set (IntCompare)

(* Start-up smoke test of the empty set and of a one-element set. *)
let () =
  let open IntSet in
  let t = empty in
  let t1 = insert 2 t in
  assert (not (member 2 t));
  assert (not (member 3 t));
  assert (member 2 t1);
  assert (not (member 3 t1));
  (* Inserting an element that is already present keeps it present. *)
  assert (member 2 (insert 2 t1))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment