/* avltree.pure: Basic AVL tree example (after Bird/Wadler). */

/* AVL trees are represented using the constructor symbols nil (empty tree)
   and bin h x t1 t2 (AVL tree with height h, top element x and left and right
   subtrees t1 and t2). Example:

   > let t1 = avltree [17,5,26,5]; let t2 = avltree [8,17];
   > members (t1+t2); members (t1-t2); t1-t2;
   [5,5,8,17,17,26]
   [5,5,26]
   bin 2 5 (bin 1 5 nil nil) (bin 1 26 nil nil) */

/* Declare nil as a constant (nullary) symbol, so that it denotes the empty
   tree when it appears in patterns rather than acting as a pattern variable. */
nullary nil;

/* Public functions. */

/* avltree elems: build an AVL tree by inserting the members of the list
   elems, from left to right, into the empty tree. */
avltree elems           = foldl insert nil elems;

/* avltreep t: type predicate. Yields 1 if t is the empty tree or a bin node,
   0 for any other value. Only the outermost constructor is inspected. */
avltreep nil            = 1;
avltreep (bin _ _ _ _)  = 1;
avltreep _              = 0;

/* null t: emptiness test for trees; 1 for the empty tree, 0 for a bin node.
   The two patterns are disjoint, so their order is immaterial. */
null (bin _ _ _ _)      = 0;
null nil                = 1;

/* #t: number of elements stored in the tree t (every node counts once, so
   repeated elements are counted repeatedly). */
#nil                    = 0;
#(bin _ _ l r)          = #l+#r+1;

/* members t: list of the elements of t obtained by an in-order traversal
   (left subtree, node element, right subtree). */
members nil             = [];
members (bin _ v l r)   = members l + (v:members r);

/* member t y: search t for the element y. Descends to the left when the node
   element is greater than y, to the right when it is less than y, and yields
   1 as soon as neither comparison holds; yields 0 on reaching the empty
   tree. */
member nil _            = 0;
member (bin _ v l r) y  = if v>y then member l y
                          else if v<y then member r y
                          else 1;

/* insert t y: add the element y to the tree t and rebalance on the way back
   up. An element that is not less than the node element goes into the right
   subtree, so inserting an already present element adds another copy. */
insert nil y            = bin 1 y nil nil;
insert (bin _ v l r) y  = rebal (if v>y then mknode v (insert l y) r
                                 else mknode v l (insert r y));

/* delete t y: remove one occurrence of y from t, rebalancing on the way back
   up. The empty tree is returned unchanged. When the node element is neither
   greater nor less than y, that node is dropped and its two subtrees are
   joined. */
delete nil _            = nil;
delete (bin _ v l r) y  = if v>y then rebal (mknode v (delete l y) r)
                          else if v<y then rebal (mknode v l (delete r y))
                          else join l r;

/* Implement the usual set operations on AVL trees. */

/* a + b inserts every member of b into a; a - b deletes every member of b
   from a (one deletion per member); a * b keeps what a has in common with b,
   computed as a-(a-b). The rules only apply when the left operand is a
   tree. */
a + b                   = foldl insert a (members b) if avltreep a;
a - b                   = foldl delete a (members b) if avltreep a;
a * b                   = a-(a-b) if avltreep a;

/* Inclusion tests: a <= b holds if every member of a is found in b, a >= b if
   every member of b is found in a. Only the left operand is checked to be a
   tree. */
a <= b                  = all (member b) (members a) if avltreep a;
a >= b                  = all (member a) (members b) if avltreep a;

/* Strict inclusion: inclusion in one direction that does not also hold in
   the opposite direction. */
a < b                   = a<=b && not b<=a if avltreep a;
a > b                   = a>=b && not b>=a if avltreep a;

/* Equality of trees is mutual inclusion (so it compares the elements, not
   the shapes of the trees); inequality is its negation. */
a == b                  = a<=b && b<=a if avltreep a;
a != b                  = not a==b if avltreep a;

/* Private functions. */

/* join l r: combine the two subtrees of a deleted node. If l is empty the
   result is r; otherwise the rightmost element of l becomes the new top
   element, above the rest of l and r. */
join nil r              = r;
join l@(bin _ _ _ _) r  = rebal (mknode top rest r)
                          when top = last l; rest = init l end;

/* init t: the nonempty tree t with its rightmost node removed, rebalanced on
   the way back up. The first rule (no right subtree) must come first. */
init (bin _ _ l nil)    = l;
init (bin _ v l r)      = rebal (mknode v l (init r));

/* last t: the element of the rightmost node of the nonempty tree t. The
   first rule (no right subtree) must come first. */
last (bin _ v _ nil)    = v;
last (bin _ _ _ r)      = last r;

/* mknode x l r: build the node with top element x and subtrees l and r,
   storing as its height one more than the height of the taller subtree. */

mknode x l r            = bin h x l r
                          when h = max (height l) (height r) + 1 end;

/* height and slope compute the height and slope (difference between heights
   of the left and the right subtree), respectively. */

/* height t: 0 for the empty tree, otherwise the height value stored in the
   node (it is read off, not recomputed). */
height nil              = 0;
height (bin h _ _ _)    = h;

/* slope t: height of the left subtree minus height of the right subtree;
   0 for the empty tree. Positive means left-heavy, negative right-heavy. */
slope nil               = 0;
slope (bin _ _ l r)     = height l - height r;

/* rebal t restores the balance of t after a single insertion or deletion:
   a slope of -2 (right-heavy) is repaired by shl, a slope of 2 (left-heavy)
   by shr; any other tree is returned as is. */

rebal t                 = if s == -2 then shl t
                          else if s == 2 then shr t
                          else t
                          when s = slope t end;

/* Rotation operations. */

/* rol: single left rotation. The right child becomes the new top node; the
   old top node becomes its left child and takes over the right child's
   former left subtree. Heights are recomputed by mknode. */
rol (bin _ a l (bin _ b m r))
                        = mknode b (mknode a l m) r;

/* ror: single right rotation. The left child becomes the new top node; the
   old top node becomes its right child and takes over the left child's
   former right subtree. Heights are recomputed by mknode. */
ror (bin _ a (bin _ b l m) r)
                        = mknode b l (mknode a m r);

/* shl t: shift a right-heavy node to the left. If the right subtree leans to
   the left (slope 1) it is first rotated right, so that the final left
   rotation yields a balanced tree; otherwise t itself is rotated left. */
shl t@(bin _ x l r)     = rol (if slope r == 1 then mknode x l (ror r)
                               else t);

/* shr t: shift a left-heavy node to the right (mirror image of shl). If the
   left subtree t1 leans to the right (slope -1), it is first rotated left so
   that the final right rotation yields a balanced tree (the left-right
   double rotation); otherwise a single right rotation suffices. The test and
   the preliminary rotation must be applied to the LEFT subtree: that is the
   subtree ror lifts to the top. */
shr (bin h x t1 t2)     = ror (mknode x (rol t1) t2) if slope t1 == -1;
                        = ror (bin h x t1 t2);