/* Pure's set and bag data types based on AVL trees. */

/* Copyright (c) 2008 by Albert Graef <Dr.Graef@t-online.de>.
   Copyright (c) 2008 by Jiri Spitz <jiri.spitz@bluetone.cz>.

   This file is part of the Pure programming language and system.

   Pure is free software: you can redistribute it and/or modify it under the
   terms of the GNU General Public License as published by the Free Software
   Foundation, either version 3 of the License, or (at your option) any later
   version.

   Pure is distributed in the hope that it will be useful, but WITHOUT ANY
   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
   FOR a PARTICULAR PURPOSE.  See the GNU General Public License for more
   details.

   You should have received a copy of the GNU General Public License along
   with this program.  If not, see <http://www.gnu.org/licenses/>. */


/* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
   The AVL tree algorithm used here has its origin in the SWI-Prolog
   implementation of association lists. The original file was created by
   R. A. O'Keefe and updated for SWI-Prolog by Jan Wielemaker. For the
   original file see http://www.swi-prolog.org.

   The port from SWI-Prolog and the deletion operations (rmfirst, rmlast,
   delete), which are missing in the original file, were provided by
   Jiri Spitz.
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */


/* Public operations: ******************************************************

emptyset, emptybag              return the empty set or bag
set xs, bag xs                  create a set or bag from list xs
setp x, bagp x                  check whether x is a set or bag

#m                              size of set or bag m

null m                          tests whether m is the empty set or bag
member m x                      tests whether m contains x
members m, list m               list members of m in ascending order

first m, last m                 return first and last member of m
rmfirst m, rmlast m             remove first and last member from m
insert m x                      insert x into m (a set replaces an equal
                                member, a bag keeps duplicates)
delete m x                      remove x from m

   *************************************************************************/


/* Tree constructors. Both symbols are private to this module. `nil' (the
   empty tree) is declared nullary so that it is treated as a constant rather
   than a variable when it occurs in a left-hand side pattern; `bin' is the
   constructor for non-empty tree nodes. */
private nullary nil;
private bin;

/*****
A tree for a set or bag is either:
-  nil  (empty tree) or
-  bin key Balance Left Right  (Left, Right: trees)

Balance: ( 1), ( 0), or (-1) denoting |L|-|R| = 1, 0, or -1, respectively,
where |L| and |R| are the heights of Left and Right
*****/

// Private rebalancing helpers used by the operations below; both are defined
// further down in this file.
private adjustd avl_geq;

// set and bag type checks
// bagp x yields 1 if x is a term built with the Bag constructor, 0 otherwise.
bagp x                          = case x of
                                    Bag _ = 1;
                                    _     = 0
                                  end;

// setp x yields 1 if x is a term built with the Set constructor, 0 otherwise.
setp x                          = case x of
                                    Set _ = 1;
                                    _     = 0
                                  end;

// create an empty set or bag
// Both are the respective constructor (Set or Bag) applied to the empty
// tree nil.
emptyset                        = Set nil;
emptybag                        = Bag nil;

// Build a set or bag from a list: the elements of xs are inserted one after
// another, from left to right, into an initially empty container. The rules
// only apply if xs is a list.
set xs                          = foldl insert (Set nil) xs if listp xs;
bag xs                          = foldl insert (Bag nil) xs if listp xs;

// insert a new member into a set or bag
//
// insert m y yields the set or bag m with y added. In both rules t is bound
// to the constructor symbol (Set or Bag); it is used to rewrap the resulting
// tree and to select the duplicate policy:
//   - Set: a member equal to y is overwritten by y, the tree shape is kept;
//   - Bag: y is always added, a key equal to a node key goes to the right.
// The local function insert works on the bare tree and returns a two-element
// list [newTree, heightHasIncreased]; only the tree is used at the top level.
insert (t@Set m) y         |
insert (t@Bag m) y              = t ((insert m y)!0)
with
  // Empty tree: create a balanced leaf; the height has grown.
  insert nil key
        = [(bin key ( 0) nil nil), 1];

  // Sets only: the key is already present, replace it in place. Balance and
  // height are unaffected.
  insert (bin k         b::int l r) key
        = [(bin key b l r), 0] if (key == k) && (t === Set);

  // Smaller key: insert into the left subtree, then fix up this node. The
  // last argument (-1) tells adjust that the left side was touched.
  insert (bin k         b::int l r) key
        = adjust leftHasChanged (bin k b newL r) (-1)
        when [newL, leftHasChanged] = insert l key end if key < k;

  // Greater key (or, for bags, greater or equal key): insert into the right
  // subtree, then fix up this node. The last argument ( 1) marks the right
  // side.
  insert (bin k         b::int l r) key
        = adjust rightHasChanged (bin k b l newR) ( 1)
        when [newR, rightHasChanged] = insert r key end
        if ((key > k) && (t === Set)) || ((key >= k) && (t === Bag));

  // adjust changedFlag tree side: the subtree height did not change, so
  // neither balance nor height of this node change.
  adjust 0 oldTree _
        = [oldTree, 0];

  // The subtree on side LoR has grown. The node passed in already contains
  // the new subtree but still carries its old balance b0; table yields the
  // new balance, whether this node has grown too, and whether a rotation is
  // required.
  adjust 1 (bin key         b0::int l r) LoR::int
        = [rebal toBeRebalanced (bin key b0 l r) b1, whatHasChanged]
                when
                        [b1, whatHasChanged, toBeRebalanced] = table b0 LoR
                end;

  // No rotation needed: just store the new balance b.
  rebal 0 (bin k _ l r) b
        = bin k b l r;

  // Rotation needed: avl_geq returns [tree, flag]; only the tree is used
  // here, the height flag comes from table (see adjust above).
  rebal 1  oldTree _
        = (avl_geq oldTree)!0;
/*
// Balance rules for insertions
//      balance where           balance   whole tree    to be
//      before  inserted        after     increased     rebalanced
table   ( 0)    (-1)            = [( 1),          1,            0];
table   ( 0)    ( 1)            = [(-1),          1,            0];
table   ( 1)    (-1)            = [( 0),          0,            1];
table   ( 1)    ( 1)            = [( 0),          0,            0];
table   (-1)    (-1)            = [( 0),          0,            0];
table   (-1)    ( 1)            = [( 0),          0,            1];
*/

// table w/o pattern matching
// Computes the same triples as the commented-out table above with integer
// arithmetic: bb is the balance before the insertion, wi is -1 for an
// insertion on the left and 1 for an insertion on the right. The truth
// values of the comparisons (1 or 0) are used directly as the flags.
  table bb::int wi::int = [ba, wti, tbr]
        when
                ba      = if bb == 0 then -wi else 0;
                wti     = bb == 0;
                tbr     = (bb + wi) == 0;
        end
end;

// delete a member by key from the data structure
//
// delete m y yields the set or bag m with one node whose key equals y
// removed; if no such node is found on the search path, the tree is returned
// unchanged. t is bound to the constructor (Set or Bag) and is only used to
// rewrap the resulting tree. The local function delete works on the bare
// tree and returns a two-element list [newTree, heightHasChanged]; only the
// tree is used at the top level.
delete (t@Set m) y         |
delete (t@Bag m) y
= t ((delete m y)!0)
with
  // Key not found: nothing to remove, height unchanged.
  delete nil _                  = [nil, 0];

  // Matching node without a left subtree: the right subtree takes its place.
  delete (bin k         _ nil r) key
                                = [r, 1] if key == k;

  // Matching node without a right subtree: the left subtree takes its place.
  delete (bin k         _ l nil) key
                                = [l, 1] if key == k;

  // Matching node with a non-empty left subtree x (this rule is only reached
  // if the two rules above did not apply): the node key is replaced by the
  // last key lk of x, and that key is removed from x. The node is then fixed
  // up as after a deletion on the left side (-1).
  delete (bin k         b::int x@(bin kl         bl::int rl ll) r) key
                        = adjustd leftHasChanged (bin lk b newL r) (-1)
                when
                        lk                      = last x;
                        [newL, leftHasChanged]  = rmlast x
                end
                if key == k;

  // Smaller key: delete from the left subtree and fix up this node.
  delete (bin k         b::int l r) key
                = adjustd leftHasChanged (bin k b newL r) (-1)
                when
                        [newL, leftHasChanged] = delete l key
                end
                if key < k;

  // Greater key: delete from the right subtree and fix up this node.
  delete (bin k         b::int l r) key
                = adjustd rightHasChanged (bin k b l newR) ( 1)
                when
                        [newR, rightHasChanged] = delete r key
                end
                  if key > k;

  // Local helper on bare trees: remove the last (rightmost) node. Returns
  // [newTree, heightHasChanged].
  rmlast nil                    = [nil, 0];
  rmlast (bin _ _      l nil)   = [l,   1];
  rmlast (bin k b::int l r  )
                = adjustd rightHasChanged (bin k b l newR) ( 1)
                when [newR, rightHasChanged] = rmlast r end;

  // Local helper on bare trees: key of the last (rightmost) node. There is
  // no rule for nil; it is only called on a non-empty subtree above.
  last (bin x _ _ nil)          = x;
  last (bin _ _ _ m2 )          = last m2
end;

// Emptiness test: a set or bag is empty exactly when the tree it wraps is
// the empty tree nil. Yields 1 for an empty container and 0 otherwise.
null (Set m) |
null (Bag m)            = m === nil;

// Size of a set or bag: the number of nodes of the underlying tree, counted
// by a full traversal.
#(Set m) |
#(Bag m)                        = count m
with
  // an empty tree holds no members
  count nil                     = 0;
  // one member in this node plus the members of both subtrees
  count (bin _ _ lsub rsub)     = count lsub + count rsub + 1
end;

// Membership test: member m k yields 1 if a node with a key equal to k is
// found in the set or bag m, and 0 if the search runs off the tree.
member (Set m) k |
member (Bag m) k
= search m k
with
  // reached an empty subtree: not present
  search nil _                  = 0;

  // binary search: compare the node key with the value looked for and
  // descend to the left or right accordingly
  search (bin key _ lsub rsub) v
                                = search lsub v if key > v;
                                = search rsub v if key < v;
                                = 1 if key == v
end;

// get all members of set or bag as a list
//
// members m yields the keys of m in tree order (left subtree, node key,
// right subtree), i.e. the same list as before. The traversal carries an
// accumulator holding the part of the result to the right of the current
// subtree, so every key is consed onto the result exactly once. The former
// definition concatenated sublists with +, which copied the list of each
// left subtree again at every level of the tree.
members (Set m) |
members (Bag m)
= collect m []
with
  // nothing to add for an empty subtree
  collect nil acc               = acc;

  // first build the tail (node key followed by everything to its right),
  // then prepend the keys of the left subtree to it
  collect (bin x _ m1 m2) acc
                                = collect m1 (x : collect m2 acc)
end;

// list is a synonym of members for sets and bags: it yields the member list
// of its argument.
list s@(Set _)                  = members s;
list b@(Bag _)                  = members b;

// First member of a set or bag: the key of the leftmost node of the tree.
// There is no rule for the empty tree.
first (Set m) |
first (Bag m)
= leftmost m
with
  // no left subtree: this node holds the first key
  leftmost (bin key _ nil _)    = key;
  // otherwise keep descending to the left
  leftmost (bin _ _ lsub _)     = leftmost lsub
end;

// Last member of a set or bag: the key of the rightmost node of the tree.
// There is no rule for the empty tree.
last (Set m) |
last (Bag m)
= rightmost m
with
  // no right subtree: this node holds the last key
  rightmost (bin key _ _ nil)   = key;
  // otherwise keep descending to the right
  rightmost (bin _ _ _ rsub)    = rightmost rsub
end;

// Remove the first member: yields the set or bag without its leftmost node.
// An empty container is returned unchanged. t is the constructor (Set or
// Bag) used to rewrap the resulting tree.
rmfirst (t@Set m) |
rmfirst (t@Bag m)
= t ((dropmin m)!0)
with
  // dropmin works on the bare tree and returns [newTree, heightHasChanged]
  dropmin nil                           = [nil, 0];
  // leftmost node found: its right subtree takes its place
  dropmin (bin _ _ nil rsub)            = [rsub, 1];
  // descend to the left, then fix up this node for a deletion on the
  // left side (-1)
  dropmin (bin key bal::int lsub rsub)
        = adjustd changed (bin key bal lsub1 rsub) (-1)
        when [lsub1, changed] = dropmin lsub end
end;

// Remove the last member: yields the set or bag without its rightmost node.
// An empty container is returned unchanged. t is the constructor (Set or
// Bag) used to rewrap the resulting tree.
rmlast (t@Set m) |
rmlast (t@Bag m)
= t ((dropmax m)!0)
with
  // dropmax works on the bare tree and returns [newTree, heightHasChanged]
  dropmax nil                           = [nil, 0];
  // rightmost node found: its left subtree takes its place
  dropmax (bin _ _ lsub nil)            = [lsub, 1];
  // descend to the right, then fix up this node for a deletion on the
  // right side ( 1)
  dropmax (bin key bal::int lsub rsub)
        = adjustd changed (bin key bal lsub rsub1) ( 1)
        when [rsub1, changed] = dropmax rsub end
end;

// set and bag relations
// Equality: two sets (or two bags) are equal if their member lists are
// equal.
a@(Set _) == b@(Set _)          = (members a == members b);
a@(Bag _) == b@(Bag _)          = (members a == members b);

// Inequality: two sets (or two bags) differ if their member lists differ.
a@(Set _) != b@(Set _)          = (members a != members b);
a@(Bag _) != b@(Bag _)          = (members a != members b);

// Subset test for sets: every member of m1 must be a member of m2.
m1@(Set _) <= m2@(Set _)        = all (\x -> member m2 x) (members m1);
// Sub-bag test: deleting every member of m2 from m1 (one occurrence per
// member of m2) must leave nothing behind.
m1@(Bag _) <= m2@(Bag _)        = null (foldl delete m1 (members m2));

// Superset / super-bag test: m1 >= m2 is m2 <= m1 with the operands
// swapped.
m1@(Set _) >= m2@(Set _) |
m1@(Bag _) >= m2@(Bag _)        = m2 <= m1;

// Proper subset / proper sub-bag: contained in, but not equal to, the right
// operand.
s1@(Set _) < s2@(Set _)         = if s1 <= s2 then s1 != s2 else 0;
b1@(Bag _) < b2@(Bag _)         = if b1 <= b2 then b1 != b2 else 0;

// Proper superset / proper super-bag: contains, but is not equal to, the
// right operand.
s1@(Set _) > s2@(Set _)         = if s1 >= s2 then s1 != s2 else 0;
b1@(Bag _) > b2@(Bag _)         = if b1 >= b2 then b1 != b2 else 0;

// Union: every member of the right operand is inserted into the left one.
s1@(Set _) + s2@(Set _)         = foldl insert s1 (members s2);
b1@(Bag _) + b2@(Bag _)         = foldl insert b1 (members b2);

// Difference: every member of the right operand is deleted from the left
// one.
s1@(Set _) - s2@(Set _)         = foldl delete s1 (members s2);
b1@(Bag _) - b2@(Bag _)         = foldl delete b1 (members b2);

// Intersection, expressed through difference: first determine what is left
// of m1 after taking m2 away, then take that remainder away from m1.
m1@(Set _) * m2@(Set _) |
m1@(Bag _) * m2@(Bag _)
                                = m1 - onlyInFirst
                                  when onlyInFirst = m1 - m2 end;


/* Private functions. */

// adjustd ToF tree LoR fixes up a node after a deletion in one of its
// subtrees:
//   ToF   1 if the height of the modified subtree has changed, 0 otherwise;
//   tree  the node, already containing the modified subtree but still
//         carrying its old balance;
//   LoR   (-1) if the deletion took place in the left subtree, ( 1) if it
//         took place in the right subtree.
// The result is a two-element list [newTree, heightHasChanged].
adjustd ToF::int tree LoR::int
                                = adjust ToF tree LoR
with
  // The subtree height did not change: nothing to do for this node.
  adjust 0 oldTree _            = [oldTree, 0];

  // The subtree has shrunk: table yields the new balance b1, the height flag
  // and whether a rotation is needed.
  adjust 1 (bin key         b0::int l r) LoR::int
        = rebal toBeRebalanced (bin key b0 l r) b1 whatHasChanged
        when
                [b1, whatHasChanged, toBeRebalanced] = table b0 LoR;
        end;
/*
   Note that the rebal function used for insertions (local to insert) and the
   rebal function used here for deletions are not symmetrical. With insertions
   it is sufficient to know the original balance and insertion side in order
   to decide whether the whole tree increases. With deletions it is sometimes
   not sufficient and we need to know which kind of tree rotation took place.
*/
  // No rotation: store the new balance and pass on the flag from table.
  rebal 0 (bin k         _ l r) b::int whatHasChanged
                                = [bin k b l r, whatHasChanged];

  // Rotation: the result of avl_geq, [tree, flag], is returned as is, i.e.
  // the height flag comes from the rotation pattern, not from table.
  rebal 1 oldTree _ _ = avl_geq oldTree;

// Balance rules for deletions
/*
//      balance where           balance   whole tree    to be
//      before  deleted         after     decreased     rebalanced
table   ( 0)    ( 1)            = [( 1),        0,              0];
table   ( 0)    (-1)            = [(-1),        0,              0];
table   ( 1)    ( 1)            = [( 0),        1,              1];
//                                              ^^^^
// It depends on the tree pattern in avl_geq whether it really decreases

table   ( 1)    (-1)            = [( 0),        1,              0];
table   (-1)    ( 1)            = [( 0),        1,              0];
table   (-1)    (-1)            = [( 0),        1,              1]
//                                             ^^^^
// It depends on the tree pattern in avl_geq whether it really decreases
*/

// table w/o pattern matching
// Computes the same triples as the commented-out table above: bb is the
// balance before the deletion, wd the side of the deletion (-1 left,
// 1 right). The truth value of bb == wd (1 or 0) is used directly as the
// rebalancing flag.
  table bb wd   = [ba, wtd, tbr]
        when
                ba      = if bb == 0 then wd else 0;
                wtd     = abs bb;
                tbr     = bb == wd;
        end
end;

// Single and double tree rotations - these are common for insert and delete
/*
  The patterns (-1)-(-1), (-1)-( 1), ( 1)-( 1) and ( 1)-(-1) on the LHS always
  change the tree height and these are the only patterns which can happen
  after an insertion. That's the reason why, for insertions, the insertion
  table (the local table of insert) alone is sufficient to decide the needed
  changes.
  The patterns (-1)-( 0) and ( 1)-( 0) do not change the tree height. After a
  deletion any pattern can occur and so we return 1 or 0 as a flag of
  a height change.
*/

// avl_geq x rotates the node x, whose balance is (-1) or ( 1), and yields a
// two-element list [rotatedTree, heightHasChanged]. The top-level function
// just delegates to the local function of the same name, which brings the
// local table into scope.
avl_geq x = avl_geq x
with
  // Balance (-1) with a (-1) right child: single rotation, b becomes the
  // root.
  avl_geq (bin a         (-1) alpha (bin b         (-1) beta gamma))
        = [bin b ( 0) (bin a ( 0) alpha beta) gamma, 1];

  // Balance (-1) with a ( 0) right child: single rotation, b becomes the
  // root.
  avl_geq (bin a         (-1) alpha (bin b         ( 0) beta gamma))
        = [bin b ( 1) (bin a (-1) alpha beta) gamma, 0];
                // the tree doesn't decrease with this pattern

  // Balance (-1) with a ( 1) right child: double rotation, the right child's
  // left child x becomes the root. The new balances of a and b depend on the
  // old balance b1 of x (see table below).
  avl_geq (bin a         (-1) alpha
        (bin b         ( 1) (bin x         b1 beta gamma) delta))
        = [bin x ( 0) (bin a b2 alpha beta)
           (bin b b3 gamma delta), 1]
        when
                [b2, b3] = table b1
        end;

  // Balance ( 1) with a ( 1) left child: single rotation, a becomes the
  // root.
  avl_geq (bin b         ( 1) (bin a         ( 1) alpha beta) gamma)
        = [bin a ( 0) alpha (bin b ( 0) beta  gamma), 1];

  // Balance ( 1) with a ( 0) left child: single rotation, a becomes the
  // root.
  avl_geq (bin b         ( 1) (bin a         ( 0) alpha beta) gamma)
        = [bin a (-1) alpha (bin b ( 1) beta  gamma), 0];
                // the tree doesn't decrease with this pattern

  // Balance ( 1) with a (-1) left child: double rotation, the left child's
  // right child x becomes the root. The new balances of a and b depend on
  // the old balance b1 of x (see table below).
  avl_geq (bin b         ( 1)
        (bin a         (-1) alpha (bin x         b1 beta gamma)) delta)
        = [bin x ( 0) (bin a b2 alpha beta)
           (bin b b3 gamma delta), 1]
        when
                [b2, b3] = table b1
        end;
/*
  table ( 1)                    = [( 0), (-1)];
  table (-1)                    = [( 1), ( 0)];
  table ( 0)                    = [( 0), ( 0)]
*/

// table w/o pattern matching
// Computes the same pairs as the commented-out table above: given the old
// balance of the node that becomes the new root in a double rotation, it
// yields the new balances of the new left and right child. The truth values
// of the comparisons (1 or 0) are used directly as numbers.
  table bal = [b1, b2]
        when
                b1 =   bal == (-1);
                b2 = -(bal ==   1);
        end
end;