(Set : SetType) =
struct
open UnionFind
(** [disjoint s1 s2] is [true] exactly when the intersection of [s1] and
    [s2] is empty. *)
let disjoint s1 s2 =
  let common = Set.inter s1 s2 in
  Set.is_empty common
(** [subset s1 s2] is [true] exactly when removing the elements of [s2]
    from [s1] leaves nothing, i.e. every element of [s1] is in [s2]. *)
let subset s1 s2 =
  let leftover = Set.diff s1 s2 in
  Set.is_empty leftover
(** Sets manipulated by this module: the set type of the functor
    argument [Set]. *)
type set =
Set.t
(** A term is a [UnionFind] point carrying a [descriptor]; two terms in
    the same equivalence class share one descriptor. *)
type term =
descriptor point
and descriptor =
| Variable of set (** An unknown; the set lists the elements it is forbidden to contain (see [impose]). *)
| Constant of set (** A fixed set with no variable part. *)
| DisjointSum of set * term (** A set added to a tail term; [_sum] imposes the set on the tail as a restriction before building this node. *)
(** Short alias for [string_of_int]. *)
let i2s n = string_of_int n
(** Association list pairing terms with the display names already handed
    out for them. [new_name] pushes entries; [name] searches it using
    [UnionFind.equivalent]. *)
let hashed_terms = ref []
(** [new_name v] builds the name ["V<k>"], where [k] is the number of
    names recorded so far, records the pair [(v, name)] at the head of
    [hashed_terms], and returns the name. *)
let new_name v =
  let known = !hashed_terms in
  let name = "V" ^ string_of_int (List.length known) in
  hashed_terms := (v, name) :: known;
  name
(** [name v] returns the display name of [v]: the name stored in
    [hashed_terms] for the first recorded term that is
    [UnionFind.equivalent] to [v], or a freshly allocated one (via
    [new_name]) when no such entry exists. *)
let name v =
  let rec lookup = function
    | [] -> new_name v
    | (candidate, label) :: _ when UnionFind.equivalent v candidate -> label
    | _ :: rest -> lookup rest
  in
  lookup !hashed_terms
(** [normalize node] returns the descriptor of [node] after flattening
    nested sums. When [node] is a [DisjointSum] whose tail normalizes to
    a [Constant] or to another [DisjointSum], the two sets are united and
    the merged descriptor is stored into [node] with [change], whose
    result is returned. In every other case the current descriptor is
    returned unchanged. *)
let rec normalize node =
  let desc = find node in
  match desc with
  | Variable _ | Constant _ -> desc
  | DisjointSum (head, tail) ->
      (match normalize tail with
       | Variable _ ->
           (* The tail is still an unknown: nothing to fold. *)
           desc
       | Constant rest ->
           change node (Constant (Set.union head rest))
       | DisjointSum (more, deeper) ->
           change node (DisjointSum (Set.union head more, deeper)))
(** [print node] renders the normalized form of [node]: a variable is
    shown by its [name], a constant by [Set.print], and a sum as
    ["(<set>+<name of tail>)"]. *)
let print node =
  match normalize node with
  | Variable _ -> name node
  | Constant elements -> Set.print elements
  | DisjointSum (elements, tail) ->
      let tail_name = name tail in
      let shown = Set.print elements in
      Printf.sprintf "(%s+%s)" shown tail_name
(** Raised by [impose], [check] and [unify] when a constraint cannot be
    satisfied. *)
exception Error
(** [impose restriction node] requires [node] to contain no element of
    [restriction]. A variable has [restriction] added to its forbidden
    set (its descriptor is only rewritten when the set actually grows);
    a constant or the set part of a sum must already be disjoint from
    [restriction], and for a sum the restriction is then pushed onto the
    tail.
    @raise Error if a constant, or the set part of a sum, intersects
    [restriction]. *)
let rec impose restriction node =
  let fail_on_overlap s =
    if not (disjoint s restriction) then raise Error
  in
  match normalize node with
  | Variable forbidden ->
      let widened = Set.union forbidden restriction in
      if Set.equal forbidden widened then ()
      else ignore (change node (Variable widened))
  | Constant s ->
      fail_on_overlap s
  | DisjointSum (s, tail) ->
      fail_on_overlap s;
      impose restriction tail
(** [check x node] rejects the situation where [x] occurs in [node]:
    either [node] normalizes to a variable equivalent to [x], or to a
    sum whose tail is equivalent to [x]. Constants always pass.
    @raise Error when such an occurrence is found. *)
let check x node =
  let reject_if_same candidate =
    if equivalent x candidate then raise Error
  in
  match normalize node with
  | Constant _ -> ()
  | Variable _ -> reject_if_same node
  | DisjointSum (_, tail) -> reject_if_same tail
(** [variable forbidden] creates a fresh term whose descriptor is a
    variable with forbidden set [forbidden]. *)
let variable forbidden =
  let desc = Variable forbidden in
  fresh desc
(** [constant s] creates a fresh term whose descriptor is the constant
    set [s]. *)
let constant s =
  let desc = Constant s in
  fresh desc
(** A single shared term for the constant empty set, created once when
    the module is instantiated. *)
let empty =
  fresh (Constant Set.empty)
(** [_sum s node] first imposes [s] as a restriction on [node], then
    returns a fresh term for the sum of [s] and [node]. Unlike [sum], it
    builds a sum node even when [s] is empty.
    @raise Error if [impose s node] fails. *)
let _sum s node =
  let () = impose s node in
  fresh (DisjointSum (s, node))
(** [sum s node] is the term for [s] added to [node]. When [s] is empty
    [node] itself is returned and nothing is allocated; otherwise the
    work is delegated to [_sum].
    @raise Error if [s] is non-empty and cannot be imposed on [node]. *)
let sum s node =
  if not (Set.is_empty s) then _sum s node
  else node
(** [unify node1 node2] makes the two terms equal, merging their
    equivalence classes with [union]. Nothing is done when they are
    already equivalent.
    @raise Error when the terms cannot be made equal: two different
    constants, a sum whose set part is not included in the constant it
    is unified with, or a failure reported by [impose] or [check]. *)
let rec unify node1 node2 =
  if equivalent node1 node2 then ()
  else
    match normalize node1, normalize node2, node1, node2 with
    | Variable forbidden, _, var_node, other
    | _, Variable forbidden, other, var_node ->
        (* At least one side is a variable: push its forbidden set onto
           the other term, run the occurrence check, then merge. *)
        impose forbidden other;
        check var_node other;
        union var_node other
    | Constant left, Constant right, _, _ ->
        if not (Set.equal left right) then raise Error;
        union node1 node2
    | Constant whole, DisjointSum (part, rest), cst_node, sum_node
    | DisjointSum (part, rest), Constant whole, sum_node, cst_node ->
        (* The sum's set part must lie inside the constant; its tail
           must then account for exactly what is left over. *)
        if not (subset part whole) then raise Error;
        unify rest (constant (Set.diff whole part));
        union sum_node cst_node
    | DisjointSum (head1, tail1), DisjointSum (head2, tail2), _, _ ->
        (* Keep only the elements that each set part has and the other
           lacks; both differences are computed from the original sets. *)
        let only1 = Set.diff head1 head2
        and only2 = Set.diff head2 head1 in
        (match Set.is_empty only1, Set.is_empty only2 with
         | true, true ->
             unify tail1 tail2
         | false, true ->
             unify (_sum only1 tail1) tail2
         | true, false ->
             unify tail1 (_sum only2 tail2)
         | false, false ->
             (* Each tail must supply the other side's extra elements,
                plus a common fresh variable that contains neither. *)
             let shared = variable (Set.union only1 only2) in
             unify tail1 (_sum only2 shared);
             unify tail2 (_sum only1 shared));
        union node1 node2
(** [default node] closes off whatever is still unknown in [node] by
    merging it with the shared [empty] term: the node itself when it
    normalizes to a variable, or its tail when it normalizes to a sum.
    Constants are left untouched.

    NOTE(review): this relies on [union] keeping the descriptor of its
    second argument ([empty]) for the merged class - confirm against
    [UnionFind.union].

    The function never calls itself, so it is declared without [rec]
    (the flag was unused and triggers compiler warning 39). *)
let default node =
  match normalize node with
  | Constant _ -> ()
  | Variable _ -> union node empty
  | DisjointSum (_, tail) -> union tail empty
(** [svariable ()] creates a fresh variable term with an empty forbidden
    set. *)
let svariable () =
  fresh (Variable Set.empty)
(** [print_descriptor desc] renders a descriptor with its constructor
    made visible: ["V(<set>)"] for a variable, ["Cst(<set>)"] for a
    constant, and ["(<set>) + (<tail>)"] for a sum, where the tail is
    rendered by the [print] function defined earlier in this module. *)
let print_descriptor desc =
  match desc with
  | Variable forbidden ->
      let shown = Set.print forbidden in
      "V(" ^ shown ^ ")"
  | Constant elements ->
      let shown = Set.print elements in
      "Cst(" ^ shown ^ ")"
  | DisjointSum (elements, tail) ->
      let tail_text = print tail in
      let shown = Set.print elements in
      "(" ^ shown ^ ") + (" ^ tail_text ^ ")"
(** [print v] renders the descriptor currently attached to [v], as
    returned by [UnionFind.find], using [print_descriptor]. It shadows
    the earlier [print] for the rest of the module. *)
let print v =
  let desc = UnionFind.find v in
  print_descriptor desc
end