From 660e4c8eefd658e9d035aa8abefe765b335a9d63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Glondu?= Date: Sun, 11 Aug 2019 17:55:06 +0200 Subject: Import janest-base_0.12.2-1.debian.tar.xz [dgit import tarball janest-base 0.12.2-1 janest-base_0.12.2-1.debian.tar.xz] --- changelog | 5 ++++ compat | 1 + control | 56 ++++++++++++++++++++++++++++++++++++++++++++ copyright | 30 ++++++++++++++++++++++++ gbp.conf | 2 ++ libbase-ocaml-dev.docs | 1 + libbase-ocaml-dev.install.in | 20 ++++++++++++++++ libbase-ocaml.install.in | 6 +++++ rules | 19 +++++++++++++++ source/format | 1 + watch | 2 ++ 11 files changed, 143 insertions(+) create mode 100644 changelog create mode 100644 compat create mode 100644 control create mode 100644 copyright create mode 100644 gbp.conf create mode 100644 libbase-ocaml-dev.docs create mode 100644 libbase-ocaml-dev.install.in create mode 100644 libbase-ocaml.install.in create mode 100755 rules create mode 100644 source/format create mode 100644 watch diff --git a/changelog b/changelog new file mode 100644 index 0000000..8f9cfea --- /dev/null +++ b/changelog @@ -0,0 +1,5 @@ +janest-base (0.12.2-1) unstable; urgency=medium + + * Initial release (Closes: #934150) + + -- Stéphane Glondu Sun, 11 Aug 2019 17:55:06 +0200 diff --git a/compat b/compat new file mode 100644 index 0000000..48082f7 --- /dev/null +++ b/compat @@ -0,0 +1 @@ +12 diff --git a/control b/control new file mode 100644 index 0000000..6de0dce --- /dev/null +++ b/control @@ -0,0 +1,56 @@ +Source: janest-base +Priority: optional +Maintainer: Debian OCaml Maintainers +Uploaders: + Stéphane Glondu +Build-Depends: + debhelper (>= 12), + ocaml-nox, + dune, + libdune-ocaml-dev, + libsexplib0-ocaml-dev, + dh-ocaml +Standards-Version: 4.4.0 +Section: ocaml +Homepage: https://github.com/janestreet/base +Vcs-Git: https://salsa.debian.org/ocaml-team/janest-base.git +Vcs-Browser: https://salsa.debian.org/ocaml-team/janest-base + +Package: libbase-ocaml-dev +Architecture: any +Depends: + 
${ocaml:Depends}, + ${shlibs:Depends}, + ${misc:Depends} +Provides: ${ocaml:Provides} +Recommends: ocaml-findlib +Description: Jane Street's alternative standard library (development) + Base is a standard library for OCaml. It provides a standard set of + general purpose modules that are well-tested, performant, and + fully-portable across any environment that can run OCaml code. Unlike + other standard library projects, Base is meant to be used as a + wholesale replacement of the standard library distributed with the + OCaml compiler. In particular it makes different choices and doesn’t + re-export features that are not fully portable such as I/O, which are + left to other libraries. + . + This package contains development files. + +Package: libbase-ocaml +Architecture: any +Depends: + ${ocaml:Depends}, + ${shlibs:Depends}, + ${misc:Depends} +Provides: ${ocaml:Provides} +Description: Jane Street's alternative standard library (runtime) + Base is a standard library for OCaml. It provides a standard set of + general purpose modules that are well-tested, performant, and + fully-portable across any environment that can run OCaml code. Unlike + other standard library projects, Base is meant to be used as a + wholesale replacement of the standard library distributed with the + OCaml compiler. In particular it makes different choices and doesn’t + re-export features that are not fully portable such as I/O, which are + left to other libraries. + . + This package contains runtime files. 
diff --git a/copyright b/copyright new file mode 100644 index 0000000..5520dab --- /dev/null +++ b/copyright @@ -0,0 +1,30 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ + +Files: * +Copyright: (c) 2016-2019 Jane Street Group, LLC +License: MIT + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + . + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + . + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + +Files: src/map.ml src/random.mli src/set.ml +Copyright: (c) 1996 Inria +License: Apache-2.0 + On Debian systems, the full text of the Apache 2.0 license can be + found in `/usr/share/common-licenses/Apache-2.0'. 
diff --git a/gbp.conf b/gbp.conf new file mode 100644 index 0000000..cec628c --- /dev/null +++ b/gbp.conf @@ -0,0 +1,2 @@ +[DEFAULT] +pristine-tar = True diff --git a/libbase-ocaml-dev.docs b/libbase-ocaml-dev.docs new file mode 100644 index 0000000..d373737 --- /dev/null +++ b/libbase-ocaml-dev.docs @@ -0,0 +1 @@ +usr/doc/base/* diff --git a/libbase-ocaml-dev.install.in b/libbase-ocaml-dev.install.in new file mode 100644 index 0000000..a1da989 --- /dev/null +++ b/libbase-ocaml-dev.install.in @@ -0,0 +1,20 @@ +@OCamlStdlibDir@/base/*dune* +@OCamlStdlibDir@/base/*opam* +@OCamlStdlibDir@/base/*.ml +@OCamlStdlibDir@/base/*.mli +@OCamlStdlibDir@/base/*.cmi +@OCamlStdlibDir@/base/*.cmt +@OCamlStdlibDir@/base/*.cmti +@OCamlStdlibDir@/base/*.h +@OCamlStdlibDir@/base/*.js +OPT: @OCamlStdlibDir@/base/*.a +OPT: @OCamlStdlibDir@/base/*.cmx +OPT: @OCamlStdlibDir@/base/*.cmxa +@OCamlStdlibDir@/base/*/*.ml +@OCamlStdlibDir@/base/*/*.mli +@OCamlStdlibDir@/base/*/*.cmi +@OCamlStdlibDir@/base/*/*.cmt +@OCamlStdlibDir@/base/*/*.cmti +OPT: @OCamlStdlibDir@/base/*/*.a +OPT: @OCamlStdlibDir@/base/*/*.cmx +OPT: @OCamlStdlibDir@/base/*/*.cmxa diff --git a/libbase-ocaml.install.in b/libbase-ocaml.install.in new file mode 100644 index 0000000..cf9a9b6 --- /dev/null +++ b/libbase-ocaml.install.in @@ -0,0 +1,6 @@ +@OCamlDllDir@ +@OCamlStdlibDir@/base/META +@OCamlStdlibDir@/base/*.cma +OPT: @OCamlStdlibDir@/base/*.cmxs +@OCamlStdlibDir@/base/*/*.cma +OPT: @OCamlStdlibDir@/base/*/*.cmxs diff --git a/rules b/rules new file mode 100755 index 0000000..acaa169 --- /dev/null +++ b/rules @@ -0,0 +1,19 @@ +#!/usr/bin/make -f +# -*- makefile -*- + +include /usr/share/ocaml/ocamlvars.mk + +DESTDIR=$(CURDIR)/debian/tmp + +%: + dh $@ --with ocaml + +override_dh_auto_build: + dune build -p base + +override_dh_auto_install: + dune install --destdir=$(DESTDIR) --prefix=/usr --libdir=..$(OCAML_STDLIB_DIR) + rm -f $(DESTDIR)/usr/doc/base/LICENSE.md + +override_dh_missing: + dh_missing --fail-missing -Xdune 
diff --git a/source/format b/source/format new file mode 100644 index 0000000..163aaf8 --- /dev/null +++ b/source/format @@ -0,0 +1 @@ +3.0 (quilt) diff --git a/watch b/watch new file mode 100644 index 0000000..480623e --- /dev/null +++ b/watch @@ -0,0 +1,2 @@ +version=3 +https://github.com/janestreet/base/releases .*/archive/v(.*)\.tar\.gz -- cgit v1.2.3 From 5c8e8182515d6d1e12c6c7282630a8a7b11143a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Glondu?= Date: Sun, 11 Aug 2019 17:55:06 +0200 Subject: Import janest-base_0.12.2.orig.tar.gz [dgit import orig janest-base_0.12.2.orig.tar.gz] --- .gitignore | 4 + CHANGES.md | 275 ++++ CONTRIBUTING.md | 67 + LICENSE.md | 21 + Makefile | 17 + README.org | 187 +++ ROADMAP.md | 112 ++ base.opam | 35 + compiler-stdlib/gen/dune | 3 + compiler-stdlib/gen/gen.ml | 88 ++ compiler-stdlib/src/dune | 4 + dune-project | 1 + generate/dune | 2 + generate/generate_pow_overflow_bounds.ml | 185 +++ lint/dune | 3 + lint/ppx_base_lint.ml | 105 ++ md5/src/dune | 2 + md5/src/md5_lib.ml | 26 + md5/src/md5_lib.mli | 21 + shadow-stdlib/gen/dune | 4 + shadow-stdlib/gen/gen.ml | 36 + shadow-stdlib/gen/mapper.mll | 267 ++++ shadow-stdlib/src/dune | 10 + shadow-stdlib/src/shadow_stdlib.ml | 1 + src/am_testing.c | 9 + src/am_testing.h | 11 + src/applicative.ml | 159 +++ src/applicative.mli | 1 + src/applicative_intf.ml | 318 +++++ src/array.ml | 731 +++++++++++ src/array.mli | 339 +++++ src/array0.ml | 62 + src/array_permute.ml | 12 + src/avltree.ml | 411 ++++++ src/avltree.mli | 129 ++ src/backtrace.ml | 51 + src/backtrace.mli | 79 ++ src/base.ml | 471 +++++++ src/base.mld | 173 +++ src/binary_search.ml | 109 ++ src/binary_search.mli | 86 ++ src/binary_searchable.ml | 36 + src/binary_searchable.mli | 1 + src/binary_searchable_intf.ml | 76 ++ src/blit.ml | 113 ++ src/blit.mli | 1 + src/blit_intf.ml | 167 +++ src/bool.ml | 89 ++ src/bool.mli | 34 + src/buffer.ml | 28 + src/buffer.mli | 8 + src/buffer_intf.ml | 81 ++ src/bytes.ml | 125 ++ 
src/bytes.mli | 197 +++ src/bytes0.ml | 61 + src/bytes_tr.ml | 40 + src/char.ml | 105 ++ src/char.mli | 73 ++ src/char0.ml | 39 + src/comparable.ml | 217 ++++ src/comparable.mli | 1 + src/comparable_intf.ml | 217 ++++ src/comparator.ml | 127 ++ src/comparator.mli | 121 ++ src/comparisons.ml | 29 + src/container.ml | 175 +++ src/container.mli | 1 + src/container_intf.ml | 624 ++++++++++ src/discover/discover.ml | 20 + src/discover/discover.mli | 1 + src/discover/dune | 2 + src/dune | 30 + src/either.ml | 289 +++++ src/either.mli | 1 + src/either_intf.ml | 93 ++ src/equal.ml | 41 + src/error.ml | 20 + src/error.mli | 15 + src/exn.ml | 151 +++ src/exn.mli | 95 ++ src/exn_stubs.c | 8 + src/field.ml | 67 + src/field.mli | 39 + src/fieldslib.ml | 3 + src/float.ml | 1099 ++++++++++++++++ src/float.mli | 587 +++++++++ src/float0.ml | 125 ++ src/floatable.ml | 10 + src/fn.ml | 30 + src/fn.mli | 32 + src/formatter.ml | 1 + src/formatter.mli | 9 + src/hash.ml | 231 ++++ src/hash.mli | 1 + src/hash_intf.ml | 196 +++ src/hash_set.ml | 220 ++++ src/hash_set.mli | 1 + src/hash_set_intf.ml | 204 +++ src/hash_stubs.c | 26 + src/hashable.ml | 4 + src/hashable.mli | 5 + src/hashable_intf.ml | 93 ++ src/hasher.ml | 56 + src/hashtbl.ml | 880 +++++++++++++ src/hashtbl.mli | 1 + src/hashtbl_intf.ml | 769 ++++++++++++ src/hex_lexer.mll | 15 + src/identifiable.ml | 58 + src/identifiable.mli | 77 ++ src/import.ml | 8 + src/import0.ml | 392 ++++++ src/indexed_container.ml | 65 + src/indexed_container.mli | 1 + src/indexed_container_intf.ml | 66 + src/info.ml | 240 ++++ src/info.mli | 1 + src/info_intf.ml | 148 +++ src/int.ml | 310 +++++ src/int.mli | 1 + src/int0.ml | 23 + src/int32.ml | 274 ++++ src/int32.mli | 48 + src/int63.ml | 77 ++ src/int63.mli | 86 ++ src/int63_backends.ml | 50 + src/int63_emul.ml | 399 ++++++ src/int63_emul.mli | 34 + src/int64.ml | 260 ++++ src/int64.mli | 32 + src/int_conversions.ml | 359 ++++++ src/int_conversions.mli | 134 ++ src/int_intf.ml | 363 ++++++ 
src/int_math.ml | 144 +++ src/int_math.mli | 39 + src/int_math_stubs.c | 92 ++ src/intable.ml | 11 + src/internalhash.h | 3 + src/internalhash_stubs.c | 101 ++ src/invariant.ml | 26 + src/invariant.mli | 1 + src/invariant_intf.ml | 98 ++ src/lazy.ml | 44 + src/lazy.mli | 81 ++ src/linked_queue.ml | 159 +++ src/linked_queue.mli | 19 + src/linked_queue0.ml | 16 + src/list.ml | 1065 ++++++++++++++++ src/list.mli | 514 ++++++++ src/list0.ml | 42 + src/list1.ml | 15 + src/map.ml | 1770 ++++++++++++++++++++++++++ src/map.mli | 1 + src/map_intf.ml | 1936 +++++++++++++++++++++++++++++ src/maybe_bound.ml | 164 +++ src/maybe_bound.mli | 63 + src/monad.ml | 107 ++ src/monad.mli | 1 + src/monad_intf.ml | 365 ++++++ src/nativeint.ml | 258 ++++ src/nativeint.mli | 27 + src/obj_array.ml | 179 +++ src/obj_array.mli | 72 ++ src/option.ml | 200 +++ src/option.mli | 75 ++ src/option_array.ml | 171 +++ src/option_array.mli | 93 ++ src/or_error.ml | 129 ++ src/or_error.mli | 126 ++ src/ordered_collection_common.ml | 46 + src/ordered_collection_common.mli | 39 + src/ordering.ml | 70 ++ src/ordering.mli | 74 ++ src/poly0.ml | 18 + src/poly0.mli | 22 + src/popcount.ml | 39 + src/popcount.mli | 11 + src/pow_overflow_bounds.ml | 425 +++++++ src/ppx_compare_lib.ml | 113 ++ src/ppx_compare_lib.mli | 50 + src/ppx_enumerate_lib.ml | 1 + src/ppx_hash_lib.ml | 6 + src/ppx_sexp_conv_lib.ml | 1 + src/pretty_printer.ml | 29 + src/pretty_printer.mli | 44 + src/printf.ml | 8 + src/printf.mli | 132 ++ src/queue.ml | 503 ++++++++ src/queue.mli | 1 + src/queue_intf.ml | 135 ++ src/random.ml | 242 ++++ src/random.mli | 129 ++ src/ref.ml | 47 + src/ref.mli | 30 + src/result.ml | 188 +++ src/result.mli | 108 ++ src/runtime.js | 118 ++ src/select-bytes-set-primitives/select.ml | 20 + src/select-int63-backend/select.ml | 29 + src/sequence.ml | 1097 ++++++++++++++++ src/sequence.mli | 487 ++++++++ src/set.ml | 1314 ++++++++++++++++++++ src/set.mli | 1 + src/set_intf.ml | 1278 +++++++++++++++++++ src/sexp.ml | 
42 + src/sexp.mli | 20 + src/sexp_with_comparable.ml | 3 + src/sexp_with_comparable.mli | 6 + src/sexpable.ml | 79 ++ src/sexpable.mli | 78 ++ src/sexplib.ml | 10 + src/sign.ml | 27 + src/sign.mli | 32 + src/sign0.ml | 85 ++ src/sign_or_nan.ml | 117 ++ src/sign_or_nan.mli | 34 + src/source_code_position.ml | 20 + src/source_code_position.mli | 32 + src/source_code_position0.ml | 179 +++ src/stack.ml | 209 ++++ src/stack.mli | 1 + src/stack_intf.ml | 79 ++ src/staged.ml | 6 + src/staged.mli | 43 + src/string.ml | 1267 +++++++++++++++++++ src/string.mli | 461 +++++++ src/string0.ml | 62 + src/stringable.ml | 10 + src/sys.ml | 3 + src/sys.mli | 95 ++ src/sys0.ml | 43 + src/t.ml | 10 + src/type_equal.ml | 172 +++ src/type_equal.mli | 236 ++++ src/uchar.ml | 81 ++ src/uchar.mli | 55 + src/uchar0.ml | 21 + src/uniform_array.ml | 119 ++ src/uniform_array.mli | 88 ++ src/unit.ml | 29 + src/unit.mli | 19 + src/validate.ml | 184 +++ src/validate.mli | 175 +++ src/variant.ml | 7 + src/variant.mli | 9 + src/variantslib.ml | 3 + src/with_return.ml | 35 + src/with_return.mli | 54 + src/word_size.ml | 19 + src/word_size.mli | 14 + test/avltree_unit_tests.ml | 282 +++++ test/avltree_unit_tests.mli | 1 + test/dune | 4 + test/hashtbl_tests.ml | 360 ++++++ test/hashtbl_tests.mli | 13 + test/import.ml | 56 + test/int_math_tests.ml | 39 + test/interfaces_tests.ml | 48 + test/test_am_testing.ml | 8 + test/test_am_testing.mli | 1 + test/test_am_testing.mlt | 7 + test/test_applicative.ml | 383 ++++++ test/test_applicative.mli | 1 + test/test_array.ml | 288 +++++ test/test_array.mli | 1 + test/test_backtrace.ml | 14 + test/test_backtrace.mli | 1 + test/test_base.ml | 17 + test/test_base.mli | 1 + test/test_blit.ml | 81 ++ test/test_blit.mli | 1 + test/test_bool.ml | 32 + test/test_bool.mli | 1 + test/test_bytes.ml | 14 + test/test_char.ml | 534 ++++++++ test/test_char.mli | 1 + test/test_compare.ml | 97 ++ test/test_compare.mli | 1 + test/test_container.ml | 174 +++ test/test_error.ml | 23 
+ test/test_error.mli | 1 + test/test_exn.ml | 18 + test/test_exn.mli | 1 + test/test_exported_int_conversions.ml | 235 ++++ test/test_exported_int_conversions.mli | 1 + test/test_float.ml | 1025 +++++++++++++++ test/test_float.mli | 1 + test/test_fn.ml | 12 + test/test_fn.mli | 1 + test/test_hash_set.ml | 64 + test/test_hash_set.mli | 1 + test/test_hashtbl.ml | 24 + test/test_hashtbl.mli | 1 + test/test_identifiable.ml | 17 + test/test_identifiable.mli | 1 + test/test_indexed_container.ml | 181 +++ test/test_indexed_container.mli | 1 + test/test_info.ml | 62 + test/test_info.mli | 1 + test/test_int.ml | 117 ++ test/test_int.mli | 1 + test/test_int32.ml | 8 + test/test_int32.mli | 1 + test/test_int32_pow2.ml | 114 ++ test/test_int32_pow2.mli | 1 + test/test_int63.ml | 76 ++ test/test_int63.mli | 1 + test/test_int63_emul.ml | 14 + test/test_int63_emul.mli | 1 + test/test_int64.ml | 7 + test/test_int64.mli | 1 + test/test_int64_pow2.ml | 123 ++ test/test_int64_pow2.mli | 1 + test/test_int_conversions.ml | 170 +++ test/test_int_conversions.mli | 1 + test/test_int_hash.ml | 8 + test/test_int_hash.mli | 1 + test/test_int_math.ml | 192 +++ test/test_int_math.mli | 1 + test/test_int_pow2.ml | 126 ++ test/test_int_pow2.mli | 1 + test/test_lazy.ml | 55 + test/test_lazy.mli | 1 + test/test_list.ml | 610 +++++++++ test/test_list.mli | 1 + test/test_map.ml | 62 + test/test_map.mlt | 6 + test/test_maybe_bound.ml | 115 ++ test/test_maybe_bound.mli | 1 + test/test_nativeint.ml | 7 + test/test_nativeint.mli | 1 + test/test_nativeint_pow2.ml | 127 ++ test/test_not_found.mlt | 18 + test/test_obj_array.ml | 108 ++ test/test_option.ml | 8 + test/test_option.mli | 1 + test/test_option_array.ml | 69 + test/test_or_error.ml | 26 + test/test_or_error.mli | 1 + test/test_ordered_collection_common.ml | 66 + test/test_ordered_collection_common.mli | 1 + test/test_ordering.ml | 15 + test/test_popcount.ml | 43 + test/test_popcount.mli | 1 + test/test_ppx_compare_lib.ml | 104 ++ 
test/test_ppx_compare_lib.mli | 1 + test/test_queue.ml | 902 ++++++++++++++ test/test_queue.mli | 1 + test/test_random.ml | 185 +++ test/test_random.mli | 1 + test/test_ref.ml | 21 + test/test_ref.mli | 1 + test/test_sequence.ml | 488 ++++++++ test/test_sequence.mli | 1 + test/test_set.ml | 28 + test/test_sexp.ml | 40 + test/test_sexp.mli | 1 + test/test_sign.ml | 17 + test/test_sign.mli | 1 + test/test_sign_or_nan.ml | 10 + test/test_sign_or_nan.mli | 1 + test/test_stack.ml | 235 ++++ test/test_stack.mli | 5 + test/test_stdlib_shadowing.mlt | 41 + test/test_string.ml | 678 ++++++++++ test/test_string.mli | 1 + test/test_type_equal.ml | 73 ++ test/test_type_equal.mli | 1 + test/test_uchar.ml | 81 ++ test/test_uchar.mli | 1 + test/test_uniform_array.ml | 13 + test/test_validate.ml | 95 ++ test/test_validate.mli | 1 + test/test_with_return.ml | 54 + test/test_with_return.mli | 1 + test/test_word_size.ml | 11 + test/test_word_size.mli | 1 + test/validate_fields_folder.mlt | 127 ++ 377 files changed, 45487 insertions(+) create mode 100644 .gitignore create mode 100644 CHANGES.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE.md create mode 100644 Makefile create mode 100644 README.org create mode 100644 ROADMAP.md create mode 100644 base.opam create mode 100644 compiler-stdlib/gen/dune create mode 100644 compiler-stdlib/gen/gen.ml create mode 100644 compiler-stdlib/src/dune create mode 100644 dune-project create mode 100644 generate/dune create mode 100644 generate/generate_pow_overflow_bounds.ml create mode 100644 lint/dune create mode 100644 lint/ppx_base_lint.ml create mode 100644 md5/src/dune create mode 100644 md5/src/md5_lib.ml create mode 100644 md5/src/md5_lib.mli create mode 100644 shadow-stdlib/gen/dune create mode 100644 shadow-stdlib/gen/gen.ml create mode 100644 shadow-stdlib/gen/mapper.mll create mode 100644 shadow-stdlib/src/dune create mode 100644 shadow-stdlib/src/shadow_stdlib.ml create mode 100644 src/am_testing.c create mode 100644 
src/am_testing.h create mode 100644 src/applicative.ml create mode 100644 src/applicative.mli create mode 100644 src/applicative_intf.ml create mode 100644 src/array.ml create mode 100644 src/array.mli create mode 100644 src/array0.ml create mode 100644 src/array_permute.ml create mode 100644 src/avltree.ml create mode 100644 src/avltree.mli create mode 100644 src/backtrace.ml create mode 100644 src/backtrace.mli create mode 100644 src/base.ml create mode 100644 src/base.mld create mode 100644 src/binary_search.ml create mode 100644 src/binary_search.mli create mode 100644 src/binary_searchable.ml create mode 100644 src/binary_searchable.mli create mode 100644 src/binary_searchable_intf.ml create mode 100644 src/blit.ml create mode 100644 src/blit.mli create mode 100644 src/blit_intf.ml create mode 100644 src/bool.ml create mode 100644 src/bool.mli create mode 100644 src/buffer.ml create mode 100644 src/buffer.mli create mode 100644 src/buffer_intf.ml create mode 100644 src/bytes.ml create mode 100644 src/bytes.mli create mode 100644 src/bytes0.ml create mode 100644 src/bytes_tr.ml create mode 100644 src/char.ml create mode 100644 src/char.mli create mode 100644 src/char0.ml create mode 100644 src/comparable.ml create mode 100644 src/comparable.mli create mode 100644 src/comparable_intf.ml create mode 100644 src/comparator.ml create mode 100644 src/comparator.mli create mode 100644 src/comparisons.ml create mode 100644 src/container.ml create mode 100644 src/container.mli create mode 100644 src/container_intf.ml create mode 100644 src/discover/discover.ml create mode 100644 src/discover/discover.mli create mode 100644 src/discover/dune create mode 100644 src/dune create mode 100644 src/either.ml create mode 100644 src/either.mli create mode 100644 src/either_intf.ml create mode 100644 src/equal.ml create mode 100644 src/error.ml create mode 100644 src/error.mli create mode 100644 src/exn.ml create mode 100644 src/exn.mli create mode 100644 src/exn_stubs.c create 
mode 100644 src/field.ml create mode 100644 src/field.mli create mode 100644 src/fieldslib.ml create mode 100644 src/float.ml create mode 100644 src/float.mli create mode 100644 src/float0.ml create mode 100644 src/floatable.ml create mode 100644 src/fn.ml create mode 100644 src/fn.mli create mode 100644 src/formatter.ml create mode 100644 src/formatter.mli create mode 100644 src/hash.ml create mode 100644 src/hash.mli create mode 100644 src/hash_intf.ml create mode 100644 src/hash_set.ml create mode 100644 src/hash_set.mli create mode 100644 src/hash_set_intf.ml create mode 100644 src/hash_stubs.c create mode 100644 src/hashable.ml create mode 100644 src/hashable.mli create mode 100644 src/hashable_intf.ml create mode 100644 src/hasher.ml create mode 100644 src/hashtbl.ml create mode 100644 src/hashtbl.mli create mode 100644 src/hashtbl_intf.ml create mode 100644 src/hex_lexer.mll create mode 100644 src/identifiable.ml create mode 100644 src/identifiable.mli create mode 100644 src/import.ml create mode 100644 src/import0.ml create mode 100644 src/indexed_container.ml create mode 100644 src/indexed_container.mli create mode 100644 src/indexed_container_intf.ml create mode 100644 src/info.ml create mode 100644 src/info.mli create mode 100644 src/info_intf.ml create mode 100644 src/int.ml create mode 100644 src/int.mli create mode 100644 src/int0.ml create mode 100644 src/int32.ml create mode 100644 src/int32.mli create mode 100644 src/int63.ml create mode 100644 src/int63.mli create mode 100644 src/int63_backends.ml create mode 100644 src/int63_emul.ml create mode 100644 src/int63_emul.mli create mode 100644 src/int64.ml create mode 100644 src/int64.mli create mode 100644 src/int_conversions.ml create mode 100644 src/int_conversions.mli create mode 100644 src/int_intf.ml create mode 100644 src/int_math.ml create mode 100644 src/int_math.mli create mode 100644 src/int_math_stubs.c create mode 100644 src/intable.ml create mode 100644 src/internalhash.h create mode 
100644 src/internalhash_stubs.c create mode 100644 src/invariant.ml create mode 100644 src/invariant.mli create mode 100644 src/invariant_intf.ml create mode 100644 src/lazy.ml create mode 100644 src/lazy.mli create mode 100644 src/linked_queue.ml create mode 100644 src/linked_queue.mli create mode 100644 src/linked_queue0.ml create mode 100644 src/list.ml create mode 100644 src/list.mli create mode 100644 src/list0.ml create mode 100644 src/list1.ml create mode 100644 src/map.ml create mode 100644 src/map.mli create mode 100644 src/map_intf.ml create mode 100644 src/maybe_bound.ml create mode 100644 src/maybe_bound.mli create mode 100644 src/monad.ml create mode 100644 src/monad.mli create mode 100644 src/monad_intf.ml create mode 100644 src/nativeint.ml create mode 100644 src/nativeint.mli create mode 100644 src/obj_array.ml create mode 100644 src/obj_array.mli create mode 100644 src/option.ml create mode 100644 src/option.mli create mode 100644 src/option_array.ml create mode 100644 src/option_array.mli create mode 100644 src/or_error.ml create mode 100644 src/or_error.mli create mode 100644 src/ordered_collection_common.ml create mode 100644 src/ordered_collection_common.mli create mode 100644 src/ordering.ml create mode 100644 src/ordering.mli create mode 100644 src/poly0.ml create mode 100644 src/poly0.mli create mode 100644 src/popcount.ml create mode 100644 src/popcount.mli create mode 100644 src/pow_overflow_bounds.ml create mode 100644 src/ppx_compare_lib.ml create mode 100644 src/ppx_compare_lib.mli create mode 100644 src/ppx_enumerate_lib.ml create mode 100644 src/ppx_hash_lib.ml create mode 100644 src/ppx_sexp_conv_lib.ml create mode 100644 src/pretty_printer.ml create mode 100644 src/pretty_printer.mli create mode 100644 src/printf.ml create mode 100644 src/printf.mli create mode 100644 src/queue.ml create mode 100644 src/queue.mli create mode 100644 src/queue_intf.ml create mode 100644 src/random.ml create mode 100644 src/random.mli create mode 
100644 src/ref.ml create mode 100644 src/ref.mli create mode 100644 src/result.ml create mode 100644 src/result.mli create mode 100644 src/runtime.js create mode 100644 src/select-bytes-set-primitives/select.ml create mode 100644 src/select-int63-backend/select.ml create mode 100644 src/sequence.ml create mode 100644 src/sequence.mli create mode 100644 src/set.ml create mode 100644 src/set.mli create mode 100644 src/set_intf.ml create mode 100644 src/sexp.ml create mode 100644 src/sexp.mli create mode 100644 src/sexp_with_comparable.ml create mode 100644 src/sexp_with_comparable.mli create mode 100644 src/sexpable.ml create mode 100644 src/sexpable.mli create mode 100644 src/sexplib.ml create mode 100644 src/sign.ml create mode 100644 src/sign.mli create mode 100644 src/sign0.ml create mode 100644 src/sign_or_nan.ml create mode 100644 src/sign_or_nan.mli create mode 100644 src/source_code_position.ml create mode 100644 src/source_code_position.mli create mode 100644 src/source_code_position0.ml create mode 100644 src/stack.ml create mode 100644 src/stack.mli create mode 100644 src/stack_intf.ml create mode 100644 src/staged.ml create mode 100644 src/staged.mli create mode 100644 src/string.ml create mode 100644 src/string.mli create mode 100644 src/string0.ml create mode 100644 src/stringable.ml create mode 100644 src/sys.ml create mode 100644 src/sys.mli create mode 100644 src/sys0.ml create mode 100644 src/t.ml create mode 100644 src/type_equal.ml create mode 100644 src/type_equal.mli create mode 100644 src/uchar.ml create mode 100644 src/uchar.mli create mode 100644 src/uchar0.ml create mode 100644 src/uniform_array.ml create mode 100644 src/uniform_array.mli create mode 100644 src/unit.ml create mode 100644 src/unit.mli create mode 100644 src/validate.ml create mode 100644 src/validate.mli create mode 100644 src/variant.ml create mode 100644 src/variant.mli create mode 100644 src/variantslib.ml create mode 100644 src/with_return.ml create mode 100644 
src/with_return.mli create mode 100644 src/word_size.ml create mode 100644 src/word_size.mli create mode 100644 test/avltree_unit_tests.ml create mode 100644 test/avltree_unit_tests.mli create mode 100644 test/dune create mode 100644 test/hashtbl_tests.ml create mode 100644 test/hashtbl_tests.mli create mode 100644 test/import.ml create mode 100644 test/int_math_tests.ml create mode 100644 test/interfaces_tests.ml create mode 100644 test/test_am_testing.ml create mode 100644 test/test_am_testing.mli create mode 100644 test/test_am_testing.mlt create mode 100644 test/test_applicative.ml create mode 100644 test/test_applicative.mli create mode 100644 test/test_array.ml create mode 100644 test/test_array.mli create mode 100644 test/test_backtrace.ml create mode 100644 test/test_backtrace.mli create mode 100644 test/test_base.ml create mode 100644 test/test_base.mli create mode 100644 test/test_blit.ml create mode 100644 test/test_blit.mli create mode 100644 test/test_bool.ml create mode 100644 test/test_bool.mli create mode 100644 test/test_bytes.ml create mode 100644 test/test_char.ml create mode 100644 test/test_char.mli create mode 100644 test/test_compare.ml create mode 100644 test/test_compare.mli create mode 100644 test/test_container.ml create mode 100644 test/test_error.ml create mode 100644 test/test_error.mli create mode 100644 test/test_exn.ml create mode 100644 test/test_exn.mli create mode 100644 test/test_exported_int_conversions.ml create mode 100644 test/test_exported_int_conversions.mli create mode 100644 test/test_float.ml create mode 100644 test/test_float.mli create mode 100644 test/test_fn.ml create mode 100644 test/test_fn.mli create mode 100644 test/test_hash_set.ml create mode 100644 test/test_hash_set.mli create mode 100644 test/test_hashtbl.ml create mode 100644 test/test_hashtbl.mli create mode 100644 test/test_identifiable.ml create mode 100644 test/test_identifiable.mli create mode 100644 test/test_indexed_container.ml create mode 100644 
test/test_indexed_container.mli create mode 100644 test/test_info.ml create mode 100644 test/test_info.mli create mode 100644 test/test_int.ml create mode 100644 test/test_int.mli create mode 100644 test/test_int32.ml create mode 100644 test/test_int32.mli create mode 100644 test/test_int32_pow2.ml create mode 100644 test/test_int32_pow2.mli create mode 100644 test/test_int63.ml create mode 100644 test/test_int63.mli create mode 100644 test/test_int63_emul.ml create mode 100644 test/test_int63_emul.mli create mode 100644 test/test_int64.ml create mode 100644 test/test_int64.mli create mode 100644 test/test_int64_pow2.ml create mode 100644 test/test_int64_pow2.mli create mode 100644 test/test_int_conversions.ml create mode 100644 test/test_int_conversions.mli create mode 100644 test/test_int_hash.ml create mode 100644 test/test_int_hash.mli create mode 100644 test/test_int_math.ml create mode 100644 test/test_int_math.mli create mode 100644 test/test_int_pow2.ml create mode 100644 test/test_int_pow2.mli create mode 100644 test/test_lazy.ml create mode 100644 test/test_lazy.mli create mode 100644 test/test_list.ml create mode 100644 test/test_list.mli create mode 100644 test/test_map.ml create mode 100644 test/test_map.mlt create mode 100644 test/test_maybe_bound.ml create mode 100644 test/test_maybe_bound.mli create mode 100644 test/test_nativeint.ml create mode 100644 test/test_nativeint.mli create mode 100644 test/test_nativeint_pow2.ml create mode 100644 test/test_not_found.mlt create mode 100644 test/test_obj_array.ml create mode 100644 test/test_option.ml create mode 100644 test/test_option.mli create mode 100644 test/test_option_array.ml create mode 100644 test/test_or_error.ml create mode 100644 test/test_or_error.mli create mode 100644 test/test_ordered_collection_common.ml create mode 100644 test/test_ordered_collection_common.mli create mode 100644 test/test_ordering.ml create mode 100644 test/test_popcount.ml create mode 100644 test/test_popcount.mli 
create mode 100644 test/test_ppx_compare_lib.ml create mode 100644 test/test_ppx_compare_lib.mli create mode 100644 test/test_queue.ml create mode 100644 test/test_queue.mli create mode 100644 test/test_random.ml create mode 100644 test/test_random.mli create mode 100644 test/test_ref.ml create mode 100644 test/test_ref.mli create mode 100644 test/test_sequence.ml create mode 100644 test/test_sequence.mli create mode 100644 test/test_set.ml create mode 100644 test/test_sexp.ml create mode 100644 test/test_sexp.mli create mode 100644 test/test_sign.ml create mode 100644 test/test_sign.mli create mode 100644 test/test_sign_or_nan.ml create mode 100644 test/test_sign_or_nan.mli create mode 100644 test/test_stack.ml create mode 100644 test/test_stack.mli create mode 100644 test/test_stdlib_shadowing.mlt create mode 100644 test/test_string.ml create mode 100644 test/test_string.mli create mode 100644 test/test_type_equal.ml create mode 100644 test/test_type_equal.mli create mode 100644 test/test_uchar.ml create mode 100644 test/test_uchar.mli create mode 100644 test/test_uniform_array.ml create mode 100644 test/test_validate.ml create mode 100644 test/test_validate.mli create mode 100644 test/test_with_return.ml create mode 100644 test/test_with_return.mli create mode 100644 test/test_word_size.ml create mode 100644 test/test_word_size.mli create mode 100644 test/validate_fields_folder.mlt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..85f39e5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +_build +*.install +*.merlin + diff --git a/CHANGES.md b/CHANGES.md new file mode 100644 index 0000000..606e4f0 --- /dev/null +++ b/CHANGES.md @@ -0,0 +1,275 @@ +## git version + +- `Ordered_collection_common.get_pos_len` now returns an `Or_error.t` + +- Added `Bool.Non_short_circuiting`. + +- Added `Float.square`. + +- Remove module `Or_error.Ok`. + +- module `Ref` doesn't implement `Container.S1` anymore. 
+ +- Rename parameter of `Sequence.merge` from `cmp` to `compare`. + +- Added `Info.of_lazy_t` + +- Added `List.partition_result` function, to partition a list of `Result.t` + values + +- Changed the signature of `equal` from `'a t -> 'a t -> equal:('a -> 'a -> + bool) -> bool` to `('a -> 'a -> bool) -> 'a t -> 'a t -> bool`. + +- Optimized `Lazy.compare` to check physical equality before forcing the lazy + values. + +- Deprecated `Args` in the `Applicative` interface in favor of using `ppx_let`. + +- Deprecated `Array.replace arr i ~f` in favor of using `arr.(i) <- (f (arr.(i)))` + +- Rename collection length parameter of `Ordered_collection_common` functions + from `length` to `total_length`, and add a unit argument to `get_pos_len` and + `get_pos_len_exn`. + +- Removed functions that were deprecated in 2016 from the `Array` and `Set` + modules. + +- [Int.Hex.of_string] and friends no longer silently ignore a suffix + of non-hexadecimal garbage. + +- Added `?backtrace` argument to `Or_error.of_exn_result`. + +- `List.zip` now returns a `List.Or_unequal_lengths.t` instead of an `option`. + +- Remove functions from the `Sequence` module that were deprecated in 2015. + +- `Container.Make` and `Container.Make0` now require callers to either provide a + custom `length` function or request that one be derived from `fold`. + `Container.to_array`'s signature is also changed to accept `length` and `iter` + instead of `fold`. + +## v0.11 + +- Deprecated `Not_found`, people who need it can use `Caml.Not_found`, but its + use isn't recommended. + +- Added the `Sexp.Not_found_s` exception which will replace `Caml.Not_found` as + the default exception in a future release. + +- Document that `Array.find_exn`, `Array.find_map_exn`, and `Array.findi_exn` + may throw `Caml.Not_found` _or_ `Not_found_s`. + +- Document that `Hashtbl.find_exn` may throw `Caml.Not_found` _or_ + `Not_found_s`. 
+ +- Document that `List.find_exn`, and `List.find_map_exn` may throw + `Caml.Not_found` _or_ `Not_found_s`. + +- Document that `List.find_exn` may throw `Caml.Not_found` _or_ `Not_found_s`. + +- Document that `String.lsplit2_exn`, and `String.rsplit2_exn` may throw + `Caml.Not_found` _or_ `Not_found_s`. + +- Added `Sys.backend_type`. + +- Removed unnecessary unit argument from `Hashtbl.create`. + +- Removed deprecated operations from `Hashtbl`. + +- Removed `Hashable.t` constructors from `Hashtbl` and `Hash_set`, instead + favoring the first-class module constructors. + +- Removed `Container` operations from `Either.First` and `Either.Second`. + +- Changed the type of `fold_until` in the `Container` interfaces. Rather than + returning a `Finished_or_stopped_early.t` (which has also been removed), the + function now takes a `finish` function that will be applied the result if `f` + never returned a `Stop _`. + +- Removed the `String_dict` module. + +- Added a `Queue` module that is backed by an `Option_array` for efficient and + (non-allocating) implementations of most operations. + +- Added a `Poly` submodule to `Map` and `Set` that exposes constructors that + use polymorphic compare. + +- Deprecated `all_ignore` in the `Monad` and `Applicative` interfaces in favor + of `all_unit`. + +- Deprecated `Array.replace_all` in favor of `Array.map_inplace`, which is the + standard name for that sort of operation within Base. + +- Document that `List.find_exn`, and `List.find_map_exn` may throw + `Caml.Not_found` _or_ `Not_found_s`. + +- Make `~compare` a required argument to `List.dedup_and_sort`, `List.dedup`, + `List.find_a_dup`, `List.contains_dup`, and `List.find_all_dups`. + +- Removed `List.exn_if_dup`. It is still available in core_kernel. + +- Removed "normalized" index operation `List.slice`. It is still available in + core_kernel. 
+ +- Remove "normalized" index operations from `Array`, which included + `Array.normalize`, `Array.slice`, `Array.nget` and `Array.nset`. These + operations are still available in core_kernel. + +- Added `Uniform_array` module that is just like an `Array` except guarantees + that the representation array is not tagged with `Double_array_tag`, the tag + for float arrays. + +- Added `Option_array` module that allows for a compact representation of `'a + option array`, which avoids allocating heap objects representing `Some a`. + +- Remove "normalized" index operations from `String`, which included + `String.normalize`, `String.slice`, `String.nget` and `String.nset`. These + operations are still available in core_kernel. + +- Added missing conversions between `Int63` and other integer types, + specifically, the versions that return options. + +- Added truncating versions of integer conversions, with a suffix of + `_trunc`. These allow fast conversions via bit arithmetic without + any conditional failure; excess bits beyond the width of the output + type are simply dropped. + +- Added `Sequence.group`, similar to `List.group`. + +- Reimplemented `String.Caseless.compare` so that it does not + allocate. + +- Added `String.is_substring_at string ~pos ~substring`. Used it as + back-end for `is_suffix` and `is_prefix`. + +- Moved all remaining `Replace_polymorphic_compare` submodules from Base + types and consolidated them in one place within `Import0`. + +- Removed `(<=.)` and its friends. + +- Added `Sys.argv`. + +- Added an infix exponentiation operator for int. + +- Added a `Formatter` module to reexport the `Format.formatter` type and updated + the deprecation message for `Format`. + +## v0.10 + +(Changes that can break existing programs are marked with a "\*") + +### Bugfixes + +- Generalized the type of `Printf.ifprintf` to reflect OCaml's stdlib. + +- Made `Sequence.fold_m` and `iter_m` respect `Skip` steps and explicitly bind + when they occur.
+ +- Changed `Float.is_negative` and `is_non_positive` on `NaN` to return `false` + rather than `true`. + +- Fixed the `Validate.protect` function, which was mistakenly raising exceptions. + +### API changes + +- Renamed `Map.add` as `set`, and deprecated `add`. A later feature will add + `add` and `add_exn` in the style of `Hashtbl`. + +- A different hash function is used to implement [Base.Int.hash]. + The old implementation was [Int.abs] but collision resistance is not enough, + we want avalanching as well. + The new function is an adaptation of one of the + [Thomas Wang](http://web.archive.org/web/20071223173210/http://www.concentric.net/~Ttwang/tech/inthash.htm) + hash functions to OCaml (63-bit integers), which results in reasonably good avalanching. + + +- Made `open Base` expose infix float operators (+., -., etc.). + +* Renamed `List.dedup` to `List.dedup_and_sort`, to better reflect its existing behavior. + +- Added `Hashtbl.find_multi` and `Map.find_multi`. + +- Added function `Map.of_increasing_sequence` for constructing a `Map.t` from an + ordered `Sequence.t` + +- Added function `List.chunks_of : 'a t -> length : int -> 'a t t`, for breaking + a list into chunks of equal length. + +- Add to module `Random` numeric functions that take upper and lower inclusive + bounds, e.g. `Random.int_incl : int -> int -> int`. + +* Replaced `Exn.Never_elide_backtrace` with `Backtrace.elide`, a `ref` cell that + determines whether `Backtrace.to_string` and `Backtrace.sexp_of_t` elide + backtraces. + +- Exposed infix operator `Base.( @@ )`. + +- Exposed modules `Base.Continue_or_stop` and `Finished_or_stopped_early`, used + with the `Container.fold_until` function. + +- Exposed module types Base.T, T1, T2, and T3. + +- Added `Sequence.Expert` functions `next_step` and + `delayed_fold_step`, for clients that want to explicitly handle `Skip` steps. + +- Added `Bytes` module. + This includes the submodules `From_string` and `To_string` with blit + functions. + N.B. 
the signature (and name) of `unsafe_to_string` and `unsafe_of_string` are + different from the one in the standard library (and hopefully more explicit). + +- Add bytes functions to `Buffer`. + Also added `Buffer.content_bytes`, the analog of `contents` but that returns + `bytes` rather than `string`. + +* Enabled `-safe-string`. + +- Added function `Int63.of_int32`, which was missing. + +* Deprecated a number of `String` mutating functions. + +- Added module `Obj_array`, moved in from `Core_kernel`. + +* In module type `Hashtbl.Accessors`, removed deprecated functions, moving them + into a new module type, `Deprecated`. + +- Exported `sexp_*` types that are recognized by `ppx_sexp_*` converters: + `sexp_array`, `sexp_list`, `sexp_opaque`, `sexp_option`. + +* Reworked the `Or_error` module's interface, moving the `Container.S` interface + to an `Ok` submodule, and adding functions `is_ok`, `is_error`, and `ok` to + more closely resemble the interface of the `Result` module. + +- Removed `Int.O.of_int_exn`. + +- Exposed `Base.force` function. + +- Changed the deprecation warning for `mod` to recommend `( % )` rather than + `Caml.( mod )`. + +### Performance related changes + +- Optimized `List.compare`, removing its closure allocation. + +- Optimized `String.mem` to not allocate. + +- Optimized `Float.is_negative`, `is_non_negative`, `is_positive`, and + `is_non_positive` to avoid some boxing. + +- Changed `Hashtbl.merge` to relax its equality check on the input tables' + `Hashable.t` records, checking physical equality componentwise if the records + aren't physically equal. + +- Added `Result.combine_errors`, similar to `Or_error.combine_errors`, with a + slightly different type. + +- Added `Result.combine_errors_unit`, similar to `Or_error.combine_errors_unit`. + +- Optimized the `With_return.return` type by adding the `[@@unboxed]` attribute. + +- Improved a number of deprecation warnings. + + +## v0.9 + +Initial release. 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..45e1a22 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,67 @@ +This repository contains open source software that is developed and +maintained by [Jane Street][js]. + +Contributions to this project are welcome and should be submitted via +GitHub pull requests. + +Signing contributions +--------------------- + +We require that you sign your contributions. Your signature certifies +that you wrote the patch or otherwise have the right to pass it on as +an open-source patch. The rules are pretty simple: if you can certify +the below (from [developercertificate.org][dco]): + +``` +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +1 Letterman Drive +Suite D4700 +San Francisco, CA, 94129 + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. 
+ +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. +``` + +Then you just add a line to every git commit message: + +``` +Signed-off-by: Joe Smith +``` + +Use your real name (sorry, no pseudonyms or anonymous contributions.) + +If you set your `user.name` and `user.email` git configs, you can sign +your commit automatically with git commit -s. + +[dco]: http://developercertificate.org/ +[js]: https://opensource.janestreet.com/ diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..0680c3e --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2016--2019 Jane Street Group, LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1965878 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +INSTALL_ARGS := $(if $(PREFIX),--prefix $(PREFIX),) + +default: + dune build + +install: + dune install $(INSTALL_ARGS) + +uninstall: + dune uninstall $(INSTALL_ARGS) + +reinstall: uninstall install + +clean: + dune clean + +.PHONY: default install uninstall reinstall clean diff --git a/README.org b/README.org new file mode 100644 index 0000000..83a9061 --- /dev/null +++ b/README.org @@ -0,0 +1,187 @@ +* Base + +Base is a standard library for OCaml. It provides a standard set of +general purpose modules that are well-tested, performant, and +fully-portable across any environment that can run OCaml code. Unlike +other standard library projects, Base is meant to be used as a +wholesale replacement of the standard library distributed with the +OCaml compiler. In particular it makes different choices and doesn't +re-export features that are not fully portable such as I/O, which are +left to other libraries. + +You also might want to browse the [[https://ocaml.janestreet.com/ocaml-core/latest/doc/base/index.html][API Documentation]]. + +** Installation + +Install Base via [[https://opam.ocaml.org][OPAM]]: + +#+begin_src +$ opam install base +#+end_src + +Base has no runtime dependencies and is fast to build. Its sole build +dependency is [[https://github.com/ocaml/dune][dune]], which itself requires nothing more than the +compiler. + +** Using the OCaml standard library with Base + +Base is intended as a full stdlib replacement. As a result, after an +=open Base=, all the modules, values, types, ... coming from the OCaml +standard library that one normally gets in the default environment are +deprecated. + +In order to access these values, one must use the =Caml= library, +which re-exports them all through the toplevel name =Caml=: +=Caml.String=, =Caml.print_string=, ... 
+ +The recommended way to build code using Base is as follows: + +#+begin_src ocaml +$ ocamlc -open Base +#+end_src + +** Differences between Base and the OCaml standard library + +Programmers who are used to the OCaml standard library should read +through this section to understand major differences between the two +libraries that one should be aware of when switching to Base. + +*** Comparison operators + +The comparison operators exposed by the OCaml standard library are +polymorphic: + +#+begin_src ocaml +val compare : 'a -> 'a -> int +val ( <= ) : 'a -> 'a -> bool +... +#+end_src + +What they implement is structural comparison of the runtime +representation of values. Since these are often error-prone, +i.e. they don't correspond to what the user expects, they are not +exposed directly by Base. + +To use polymorphic comparison with Base, one should use the +=Polymorphic_compare= module. The default comparison operators exposed +by Base are the integer ones, just like the default arithmetic +operators are the integer ones. + +The recommended way to compare arbitrary complex data structures is to +use the specific =compare= functions. For instance: + +#+begin_src ocaml +List.compare String.compare x y +#+end_src + +The [[https://github.com/janestreet/ppx_compare][ppx_compare]] rewriter +offers an alternative way to write this: + +#+begin_src ocaml +[%compare: string list] x y +#+end_src + +** Base and ppx code generators + +Base uses a few ppx code generators to implement: + +- reliable and customizable comparison of OCaml values +- reliable and customizable hash of OCaml values +- conversions between OCaml values and s-expression + +However, it doesn't need these code generators to build. What it does +instead is use ppx as a code verification tool during development. It +works in a very similar fashion to +[[https://github.com/janestreet/ppx_expect][expectation tests]]. + +Whenever you see this in the code source: + +#+begin_src ocaml +type t = ... 
[@@deriving_inline sexp_of] +let sexp_of_t = ... +[@@@end] +#+end_src + +the code between the =[@@deriving_inline]= and the =[@@@end]= is +generated code. The generated code is currently quite big and hard to +read, however we are working on making it look like human-written +code. + +You can put the following elisp code in your =~/.emacs= file to hide +these blocks: + +#+begin_src scheme +(defun deriving-inline-forward-sexp (&optional arg) + (search-forward-regexp "\\[@@@end\\]") nil nil arg) + +(defun setup-hide-deriving-inline () + (inline) + (hs-minor-mode t) + (let ((hs-hide-comments-when-hiding-all nil)) + (hs-hide-all))) + +(require 'hideshow) +(add-to-list 'hs-special-modes-alist + '(tuareg-mode "\\[@@deriving_inline[^]]*\\]" "\\[@@@end\\]" nil + deriving-inline-forward-sexp nil)) +(add-hook 'tuareg-mode-hook 'setup-hide-deriving-inline) +#+end_src + +Things are not yet setup in the git repository to make it convenient +to change types and update the generated code, but they will be setup +soon. + +** Base coding rules + +There are a few coding rules across the code base that are enforced by +lint tools. + +These rules are: + +- Opening the =Caml= module is not allowed. Inside Base, the OCaml + stdlib is shadowed and accessible through the =Caml= module. We + forbid opening =Caml= so that we know exactly where things come + from. +- =Caml.Foo= modules cannot be aliased, one must use =Caml.Foo= + explicitly. This is to avoid having to remember a list of aliases + at the beginning of each file. +- For some modules that are both in the OCaml stdlib and Base, such as + =String=, we define a module =String0= for common functions that + cannot be defined directly in =Base.String= to avoid creating a + circular dependency. Except for =String= itself, other modules + are not allowed to use =Caml.String= and must use either =String= or + =String0= instead. +- Indentation is exactly the one of =ocp-indent=. 
+- A few other coding style rules enforced by + [[https://github.com/janestreet/ppx_js_style][ppx_js_style]]. + +The Base specific coding rules are checked by =ppx_base_lint=, in the +=lint= subfolder. The indentation rules are checked by a wrapper around +=ocp-indent= and the coding style rules are checked by =ppx_js_style=. + +These checks are currently not run by =dune=, but it will soon get a +=-dev= flag to run them automatically. + +** Roadmap + +Following is the current plan for a stable version 1 of Base. + +*** Add more integer types + +Add support for ={,u}int{8,16,32,64}=. These are always useful when +implementing binary protocols. + +Initially they should be implemented with C stubs and eventually we +should propose their inclusion in the compiler. + +*** 80 columns limit + +Currently lines in Base are limited to a maximum width of 90 +characters. To make things more standard, we should use an 80 columns +limit. The only thing needed for this is to extend the style checker +to enforce a maximum line width. + +*** Improve the generated code + +Improve our code generators to produce code that looks more like +hand-written code. diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..d7644e8 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,112 @@ +# Stable Interface (v1.0) + + - [X] Make the entire library `-safe-string` compliant. This will involve + introducing a `Bytes` module, removing all direct mutation on strings from + the `String` module, and "re-typing" string values that require mutation to + `bytes`. + + - [X] Do not export the `\*\_intf` modules from Base. Instead, any signatures + should be exported by the `.ml` and `.mli`s. + + - [X] Only expose the first-class module interface of `Hashtbl`. Accompanying + this should be cleanup of `Hashtbl_intf`, moving anything that's still + required in core_kernel to the appropriate files in that project. 
+ + - [X] Replace `Hashtbl.create (module String) ()` by just + `Hashtbl.create (module String)` + + - [X] Remove `replace` from `Hashtbl_intf.Accessors`. + + - [X] Label one of the arguments of `Hashtbl_intf.merge_into` to indicate the + flow of data. + + - [X] Merge `Hashtbl_intf.Key_common` and `Hashtbl_intf.Key_plain`. + + - [ ] Use `Either.t` as the return value for `Map.partition`. + + - [X] Rename `Monad_intf.all_ignore` to `Monad_intf.all_unit`. + + - [ ] Eliminate all uses of `Not_found`, replacing them with descriptive error messages. + + - [X] Move the various private modules to `Base.Base_private` + instead of `Base.Exported_for_specific_uses` and `Base.Not_exposed_properly` + + - [X] Use `compare` rather than `cmp` as the label for comparison functions + throughout. + +# Implementation Cleanup + + - [ ] Remove `ignore` and `(=)` from `Sexp_conv`'s public interface. These + values are hidden from the documentation so their removal won't be + considered a breaking API change. + + - [ ] Do not expose the type equality `Int63_emul.W.t = int64`. + + - [ ] Replace the exception thrown by `Float.of_string` with a named + exception that's more descriptive. + + - [X] Delete the `Hashable` toplevel module. This is a vestige of the previous + `Map` and `Set` implementations and is no longer needed. + + - [ ] Ensure that `Map` operations that are effective NO-OPs return the same + `Map.t` they were provided. Candidate operations include e.g `add`, `remove`, + `filter`. + + - [ ] Simplify the implementation of `Option.value_exn`, if possible. + + - [ ] Eliminate all instances of `open! Polymorphic_compare` + + - [ ] Refactor common blit code in `String.replace_all` and `String.replace_first`. 
+ + - [ ] Delete unused function aliases in `Import0` + + - [ ] Put `Sexp_conv.Exn_converter` into its own file, with only an + alias in Sexp_conv, so that it doesn't get pulled unless used + + - [ ] Create a file with all the basic types and their associated + combinators (`sexp_of_t`, `compare`, `hash`), and expose the + declaration + + - [ ] Put all the exported private modules from + `Base.Exported_for_specific_uses` and `Base.Not_exposed_properly` + in `Base.Base_private` + + - [ ] Decide on a better name for `Polymorphic_compare`. + `Polymorphic_compare_intf` contains interface for comparison + of non-polymorphic types, which is weird. Get rid of it and + inline things in `Comparable_intf` + + - [X] `hashtbl_of_sexp` shouldn't live in Base.Sexp_conv since we + have our own hash tables. Move it to sexplib + +# Performance Improvements + + - [ ] In `Hash_set.diff`, use the size of each set to determine which to iterate + over. + + - [ ] Ensure that the correct `compare` function and other related functions are + exported by all modules. These functions should not be derived from + a functor application, in order to ensure proper inlining. Implementing + this change should also include benchmarks to verify the initial result, + and to maintain it on an ongoing basis. See `bench/bench_int.ml` for + examples. + + - [X] Optimize `Lazy.compare` by performing a `phys_equal` check before + forcing the lazy value. Note that this will also change the semantics of + `compare` and should be documented and rolled out with care. + + - [ ] Conduct a thorough performance review of the `Sequence` module. + +# Documentation + + - [ ] Consolidate documentation in the interface and implementation files + related to the `Hash` module. + + - [ ] Add documentation to the `Ref` toplevel module.
+ + - [ ] Document properly how `String.unescape_gen` handles malformed strings + +# Changes For The Distant Future + + - [ ] Make the various comparison functions return an `Ordering.t` + instead of an `int`. diff --git a/base.opam b/base.opam new file mode 100644 index 0000000..7662dad --- /dev/null +++ b/base.opam @@ -0,0 +1,35 @@ +opam-version: "2.0" +version: "v0.12.0" +maintainer: "opensource@janestreet.com" +authors: ["Jane Street Group, LLC "] +homepage: "https://github.com/janestreet/base" +bug-reports: "https://github.com/janestreet/base/issues" +dev-repo: "git+https://github.com/janestreet/base.git" +doc: "https://ocaml.janestreet.com/ocaml-core/latest/doc/base/index.html" +license: "MIT" +build: [ + ["dune" "build" "-p" name "-j" jobs] +] +depends: [ + "ocaml" {>= "4.04.2" & < "4.09.0"} + "sexplib0" {>= "v0.12" & < "v0.13"} + "dune" {build & >= "1.5.1"} +] +depopts: [ + "base-native-int63" +] +synopsis: "Full standard library replacement for OCaml" +description: " +Full standard library replacement for OCaml + +Base is a complete and portable alternative to the OCaml standard +library. It provides all standard functionalities one would expect +from a language standard library. It uses consistent conventions +across all of its modules. + +Base aims to be usable in any context. As a result system dependent +features such as I/O are not offered by Base.
They are instead +provided by companion libraries such as stdio: + + https://github.com/janestreet/stdio +" diff --git a/compiler-stdlib/gen/dune b/compiler-stdlib/gen/dune new file mode 100644 index 0000000..0f60968 --- /dev/null +++ b/compiler-stdlib/gen/dune @@ -0,0 +1,3 @@ +(executables (names gen) + (libraries compiler-libs.common compiler-libs.bytecomp) + (preprocess no_preprocessing)) \ No newline at end of file diff --git a/compiler-stdlib/gen/gen.ml b/compiler-stdlib/gen/gen.ml new file mode 100644 index 0000000..7b90ed3 --- /dev/null +++ b/compiler-stdlib/gen/gen.ml @@ -0,0 +1,88 @@ +(* Generator for an OCaml source file that re-exports the compiler's standard library; the output path is given with [-o]. *) +open StdLabels +(* Compiler versions reduced to a (major, minor) pair of integers. *) +module Ocaml_version : sig + type t + val parse : string -> t + val v407 : t + val v408 : t + val current : t + val compare : t -> t -> int +end = struct + type t = int * int + (* Parses the leading MAJOR.MINOR of [s]; when there is no second dot the minor part runs to the end of [s]. Any exception is turned into [Failure]. *) + let parse s = + try + let d1 = String.index_from s 0 '.' in + let d2 = try String.index_from s (d1 + 1) '.' with Not_found -> String.length s in + let p1 = int_of_string (String.sub s ~pos:0 ~len:d1) in + let p2 = int_of_string (String.sub s ~pos:(d1 + 1) ~len:(d2 - d1 - 1)) in + p1, p2 + with _ -> failwith (Printf.sprintf "Invalid ocaml version %S" s) + + let v407 = parse "4.07" + let v408 = parse "4.08" + + let current = parse Sys.ocaml_version + (* Lexicographic order: major component first, then minor. *) + let compare ((a1,b1): t) ((a2,b2):t) = + match compare a1 a2 with + | 0 -> compare b1 b2 + | c -> c +end +(* Entry point: accepts exactly [-ocaml-where DIR -o FILE]; any other command line fails. *) +let () = + let ocaml_where, oc = + match Sys.argv with + | [|_; "-ocaml-where"; ocaml_where; "-o"; fn|] -> + (ocaml_where, open_out fn) + | _ -> + failwith "bad command line arguments" + in + let pr fmt = Printf.fprintf oc (fmt ^^ "\n") in + pr "(* This file is automatically generated *)"; + pr ""; (* From 4.07 on, just include [Stdlib]; otherwise alias every unit named in stdlib.cma (sorted by name) and include [Pervasives]. *) + (if Ocaml_version.(compare current v407) >= 0 + then pr "include Stdlib" + else begin + (* The cma format is documented in typing/cmo_format.mli in the compiler sources *) + let ic = + let (^/) = Filename.concat in + try open_in_bin (ocaml_where ^/ "stdlib" ^/ "stdlib.cma") + with Sys_error _ -> open_in_bin (ocaml_where ^/
"stdlib.cma") + in + let len_magic_number = String.length Config.cma_magic_number in + let magic_number = really_input_string ic len_magic_number in + assert (magic_number = Config.cma_magic_number); + let toc_pos = input_binary_int ic in + seek_in ic toc_pos; + let toc = (input_value ic : Cmo_format.library) in + close_in ic; + let units = + List.map toc.lib_units ~f:(fun cu -> cu.Cmo_format.cu_name) + |> List.sort ~cmp:String.compare + in + let max_len = + List.fold_left units ~init:0 ~f:(fun acc unit -> + max acc (String.length unit)) + in + List.iter units ~f:(fun u -> pr "module %-*s = %s" max_len u u); + pr ""; + pr "include Pervasives"; + end); + pr ""; (* For compilers older than 4.07 (Float) or 4.08 (Bool, Int, Option, Result, Unit, Fun), emit empty modules under those names. *) + (if Ocaml_version.(compare current v407) < 0 + then + pr "module Float = struct end"); + (if Ocaml_version.(compare current v408) < 0 + then begin + pr "module Bool = struct end"; + pr "module Int = struct end"; + pr "module Option = struct end"; + pr "module Result = struct end"; + pr "module Unit = struct end"; + pr "module Fun = struct end" + end; + pr ""; + pr "exception Not_found = Not_found") + diff --git a/compiler-stdlib/src/dune b/compiler-stdlib/src/dune new file mode 100644 index 0000000..317f8e4 --- /dev/null +++ b/compiler-stdlib/src/dune @@ -0,0 +1,4 @@ +(library (name caml) (public_name base.caml) (preprocess no_preprocessing)) + +(rule (targets caml.ml) (deps (:first_dep ../gen/gen.exe)) + (action (run %{first_dep} -ocaml-where %{ocaml_where} -o %{targets}))) \ No newline at end of file diff --git a/dune-project b/dune-project new file mode 100644 index 0000000..598db56 --- /dev/null +++ b/dune-project @@ -0,0 +1 @@ +(lang dune 1.5) \ No newline at end of file diff --git a/generate/dune b/generate/dune new file mode 100644 index 0000000..8c94536 --- /dev/null +++ b/generate/dune @@ -0,0 +1,2 @@ +(executables (names generate_pow_overflow_bounds) (libraries num) + (preprocess no_preprocessing)) \ No newline at end of file diff --git a/generate/generate_pow_overflow_bounds.ml
b/generate/generate_pow_overflow_bounds.ml new file mode 100644 index 0000000..4f1f623 --- /dev/null +++ b/generate/generate_pow_overflow_bounds.ml @@ -0,0 +1,185 @@ +(* NB: This needs to be pure OCaml (no Base!), since we need this in order to build + Base. *) + +(* This module generates lookup tables to detect integer overflow when calculating integer + exponents. At index [e], [table.[e]^e] will not overflow, but [(table[e] + 1)^e] + will. *) +(* [Normal] writes to stdout; [Atomic] writes to a temporary file that is renamed to [out_fn] once generation is done. *) +type mode = Normal | Atomic of { out_fn : string; tmp_fn : string } + +let oc, mode = + match Sys.argv with + | [|_|] -> (stdout, Normal) + | [|_; "-o"; out_fn|] + | [|_; "-atomic"; "-o"; out_fn|] -> + (* Always produce the file atomically, we just have this option to remember that we + need to do it *) + let tmp_fn, oc = + Filename.open_temp_file + ~temp_dir:(Filename.dirname out_fn) + "generate_pow_overflow_bounds" ".ml.tmp" + in + (oc, Atomic { out_fn; tmp_fn }) + | _ -> failwith "bad command line arguments" +(* [Big_int] extended with short infix and function aliases for its comparison and arithmetic operations. *) +module Big_int = struct + include Big_int + type t = big_int + let (>) = gt_big_int + let (<=) = le_big_int + let (^) = power_big_int_positive_int + let (-) = sub_big_int + let (+) = add_big_int + let one = unit_big_int + let sqrt = sqrt_big_int + let to_string = string_of_big_int +end + +module Array = StdLabels.Array + +type generated_type = + | Int + | Int32 + | Int63 + | Int64 +(* Computes 2^(bits - 1) - 1, i.e. the largest value of a signed integer that is [bits] wide. *) +let max_big_int_for_bits bits = + let shift = bits - 1 in (* sign bit *) + Big_int.((shift_left_big_int one shift) - one) +;; +(* True when [x] does not exceed the 31-bit signed maximum; [format_entry] then prints it as a plain int literal. *) +let safe_to_print_as_int = + let int31_max = max_big_int_for_bits 31 in + fun x -> Big_int.(x <= int31_max) +(* Renders [b] as OCaml source text for [typ]: an [l] or [L] suffix for the fixed-width types; for [Int], a bare literal when it is safe, otherwise a [Caml.Int64.to_int] applied to an [L] literal. *) +let format_entry typ b = + let s = Big_int.to_string b in + match typ with + | Int -> + if safe_to_print_as_int b + then s + else Printf.sprintf "Caml.Int64.to_int %sL" s + | Int32 -> s ^ "l" + | Int63 + | Int64 -> s ^ "L" + +let bits = function + | Int -> assert false (* architecture dependent *) + | Int32 -> 32 + | Int63 -> 63 + | Int64 -> 64 + +let max_val typ = max_big_int_for_bits (bits typ)
+ +let name = function + | Int -> "int" + | Int32 -> "int32" + | Int63 -> "int63_on_int64" + | Int64 -> "int64" + +let ocaml_type_name = function + | Int -> "int" + | Int32 -> "int32" + | Int63 + | Int64 -> "int64" + +let generate_negative_bounds = function + | Int -> false + | Int32 -> false + | Int63 -> false + | Int64 -> true +(* Largest base [b] such that [b ^ exponent <= max_val]. Exponents 0 and 1 yield [max_val] itself, 2 uses the integer square root, and larger exponents are found by counting up from one. *) +let highest_base exponent max_val = + let open Big_int in + match exponent with + | 0 | 1 -> max_val + | 2 -> sqrt max_val + | _ -> + let rec search possible_base = + if possible_base ^ exponent > max_val then + begin + let res = possible_base - one in + assert (res ^ exponent <= max_val); + res + end + else + search (possible_base + one) + in + search one +;; + +type sign = Pos | Neg + +let pr fmt = Printf.fprintf oc (fmt ^^ "\n") +(* Prints a 64-element array literal indented by [indent] columns; element [i] is [highest_base i] for a [bits]-wide signed integer, negated when [sign] is [Neg]. *) +let gen_array ~typ ~bits ~sign ~indent = + let pr fmt = pr ("%*s" ^^ fmt) indent "" in + let max_val = max_big_int_for_bits bits in + let pos_bounds = Array.init 64 ~f:(fun i -> highest_base i max_val) in + let bounds = + match sign with + | Pos -> pos_bounds + | Neg -> Array.map pos_bounds ~f:Big_int.minus_big_int + in + pr "[| %s" (format_entry typ bounds.(0)); + for i = 1 to Array.length bounds - 1 do + pr "; %s" (format_entry typ bounds.(i)) + done; + pr "|]"; +;; + +(* Emits, for [typ], the max-value constant, the positive bounds table and, where [generate_negative_bounds] says so, the negative table. For [Int] the generated code selects a table on [Int_conversions.num_bits_int]. *) +let gen_bounds typ = + pr "let overflow_bound_max_%s_value : %s =" (name typ) (ocaml_type_name typ); + (match typ with + | Int -> pr " (-1) lsr 1" + | _ -> pr " %s" (format_entry typ (max_val typ))); + pr ""; + + let array_name typ sign = + Printf.sprintf "%s_%s_overflow_bounds" (name typ) + (match sign with Pos -> "positive" | Neg -> "negative") + in + + pr "let %s : %s array =" (array_name typ Pos) (ocaml_type_name typ); + (match typ with + | Int -> + pr " match Int_conversions.num_bits_int with"; + pr " | 32 -> Array.map %s ~f:Caml.Int32.to_int" (array_name Int32 Pos); + pr " | 63 ->"; + gen_array ~typ ~bits:63 ~sign:Pos ~indent:4; + pr " | 31 ->"; + gen_array ~typ ~bits:31 ~sign:Pos ~indent:4; + pr " | _ -> assert false" + | _ -> +
gen_array ~typ ~bits:(bits typ) ~sign:Pos ~indent:2); + pr ""; + + if generate_negative_bounds typ then begin + pr "let %s : %s array =" (array_name typ Neg) (ocaml_type_name typ); + gen_array ~typ ~bits:(bits typ) ~sign:Neg ~indent:2 + end; +;; +(* Writes the generated file: a header, then the tables for Int32, Int, Int63 and Int64 in that order (the Int table refers to the Int32 one by name, so Int32 comes first). *) +let () = + pr "(* This file was autogenerated by %s *)" Sys.argv.(0); + pr ""; + pr "open! Import"; + pr ""; + pr "module Array = Array0"; + pr ""; + pr "(* We have to use Int64.to_int_exn instead of int constants to make"; + pr " sure that file can be preprocessed on 32-bit machines. *)"; + pr ""; + gen_bounds Int32; + gen_bounds Int; + gen_bounds Int63; + gen_bounds Int64; +;; +(* In atomic mode, close the temporary file and rename it onto [out_fn]; nothing to do when writing to stdout. *) +let () = + match mode with + | Normal -> () + | Atomic { tmp_fn; out_fn } -> + close_out oc; + Sys.rename tmp_fn out_fn diff --git a/lint/dune b/lint/dune new file mode 100644 index 0000000..40445e6 --- /dev/null +++ b/lint/dune @@ -0,0 +1,3 @@ +(library (name ppx_base_lint) (kind ppx_rewriter) + (libraries compiler-libs.common base ppxlib) + (preprocess no_preprocessing)) \ No newline at end of file diff --git a/lint/ppx_base_lint.ml b/lint/ppx_base_lint.ml new file mode 100644 index 0000000..a012454 --- /dev/null +++ b/lint/ppx_base_lint.ml @@ -0,0 +1,105 @@ +open Ppxlib +open Base + +let error ~loc fmt = + Location.raise_errorf ~loc (Caml.(^^) "ppx_base_lint:" fmt) + +type suspicious_id = + | Caml_submodule of string + +let rec iter_suspicious (id : Longident.t) ~f = + match id with + | Ldot (Lident "Caml", s) when + String.(<>) s "" && + match s.[0] with + | 'A'..'Z' -> true + | _ -> false + -> + f (Caml_submodule s) + | Ldot (x, _) -> iter_suspicious x ~f + | Lapply (a, b) -> + iter_suspicious a ~f; + iter_suspicious b ~f + | Lident _ -> () + +let zero_modules () = + Caml.Sys.readdir "."
+ |> Array.to_list + |> List.filter ~f:(fun fn -> + Caml.Filename.check_suffix fn "0.ml") + |> List.map ~f:(fun fn -> + String.capitalize (String.sub fn ~pos:0 ~len:(String.length fn - 4))) + |> Set.of_list (module String) + +let check_open (id : Longident.t Asttypes.loc) = + match id.txt with + | Lident "Caml" -> + error ~loc:id.loc + "you are not allowed to open Caml inside Base" + | _ -> () + +let rec is_caml_dot_something : Longident.t -> bool = function + | Ldot (Lident "Caml", _) -> true + | Ldot (id, _) -> is_caml_dot_something id + | _ -> false + +let check current_module = + let zero_modules = zero_modules () in + object + inherit Ast_traverse.iter as super + + method! longident_loc { txt = id; loc } = + (* Note: we don't distinguish between module identifiers and constructors + names. Since there is no [Caml.String], [Caml.Array], ... constructors this is + not a problem. *) + iter_suspicious id ~f:(function + | Caml_submodule m -> + if not (Set.mem zero_modules m) then + () (* We are allowed to use Caml modules that don't have a Foo0 version *) + else if String.equal (m ^ "0") current_module then + () (* Foo0 is allowed to use Caml.Foo *) + else + match current_module with + | "Import0" | "Base" -> () + | _ -> + error ~loc + "you cannot use [Caml.%s] here, use [%s0] instead" m m) + + method! expression e = + super#expression e; + match e.pexp_desc with + | Pexp_open (_, id, _) -> check_open id + | _ -> () + + method! open_description op = + super#open_description op; + check_open op.popen_lid + + method! 
module_binding mb = + super#module_binding mb; + match current_module with + | "Import0" -> () + | _ -> + match mb.pmb_expr.pmod_desc with + | Pmod_ident { txt = id; _ } when is_caml_dot_something id -> + error ~loc:mb.pmb_loc + "you cannot alias [Caml] sub-modules, use them directly" + | _ -> () + end + +let module_of_loc (loc : Location.t) = + String.capitalize (Caml.Filename.chop_extension + (Caml.Filename.basename loc.loc_start.pos_fname)) + +let () = + Ppxlib.Driver.register_transformation "base_lint" + ~impl:(function + | [] -> [] + | { pstr_loc = loc; _ } :: _ as st -> + (check (module_of_loc loc))#structure st; + st) + ~intf:(function + | [] -> [] + | { psig_loc = loc; _ } :: _ as sg -> + (check (module_of_loc loc))#signature sg; + sg) diff --git a/md5/src/dune b/md5/src/dune new file mode 100644 index 0000000..a10038e --- /dev/null +++ b/md5/src/dune @@ -0,0 +1,2 @@ +(library (name md5_lib) (public_name base.md5) (preprocess no_preprocessing) + (libraries) (js_of_ocaml (javascript_files))) \ No newline at end of file diff --git a/md5/src/md5_lib.ml b/md5/src/md5_lib.ml new file mode 100644 index 0000000..52eccc1 --- /dev/null +++ b/md5/src/md5_lib.ml @@ -0,0 +1,26 @@ +type t = string + +(* Share the digest of the empty string *) +let empty = Digest.string "" +let make s = + if s = empty then + empty + else + s + +let compare = compare + +let length = 16 + +let to_binary s = s +let of_binary_exn s = assert (String.length s = length); make s +let unsafe_of_binary = make + +let to_hex = Digest.to_hex +let of_hex_exn s = make (Digest.from_hex s) + +let string s = make (Digest.string s) + +let bytes s = make (Digest.bytes s) + +let subbytes bytes ~pos ~len = make (Digest.subbytes bytes pos len) diff --git a/md5/src/md5_lib.mli b/md5/src/md5_lib.mli new file mode 100644 index 0000000..f784eb9 --- /dev/null +++ b/md5/src/md5_lib.mli @@ -0,0 +1,21 @@ +type t + +val compare : t -> t -> int + +(** [length = 16] is the size of the digest in bytes. 
*) +val length : int + +val to_binary : t -> string +val of_binary_exn : string -> t + +(** assumes the input is 16 bytes without checking *) +val unsafe_of_binary : string -> t + +val to_hex : t -> string +val of_hex_exn : string -> t + +val string : string -> t + +val bytes : bytes -> t + +val subbytes : bytes -> pos:int -> len:int -> t diff --git a/shadow-stdlib/gen/dune b/shadow-stdlib/gen/dune new file mode 100644 index 0000000..9e34b3a --- /dev/null +++ b/shadow-stdlib/gen/dune @@ -0,0 +1,4 @@ +(executables (names gen) (libraries str compiler-libs.common caml) + (link_flags -linkall) (preprocess no_preprocessing)) + +(ocamllex mapper) \ No newline at end of file diff --git a/shadow-stdlib/gen/gen.ml b/shadow-stdlib/gen/gen.ml new file mode 100644 index 0000000..a763c86 --- /dev/null +++ b/shadow-stdlib/gen/gen.ml @@ -0,0 +1,36 @@ +open StdLabels + +let () = + let cmi_fn, oc = + match Sys.argv with + | [|_; "-caml-cmi"; cmi_fn; "-o"; fn|] -> + (cmi_fn, open_out fn) + | [|_; "-caml-cmi"; cmi_fn1; cmi_fn2; "-o"; fn|] -> + let cmi_fn = + if Sys.file_exists cmi_fn1 then + cmi_fn1 + else + cmi_fn2 + in + (cmi_fn, open_out fn) + | _ -> + failwith "bad command line arguments" + in + + try + let cmi = Cmi_format.read_cmi cmi_fn in + let buf = Buffer.create 512 in + let pp = Format.formatter_of_buffer buf in + Format.pp_set_margin pp max_int; (* so we can parse line by line below *) + Format.fprintf pp "%a@." 
Printtyp.signature cmi.Cmi_format.cmi_sign; + let s = Buffer.contents buf in + let lines = Str.split (Str.regexp "\n") s in + Printf.fprintf oc "[@@@warning \"-3\"]\n\n"; + List.iter lines ~f:(fun line -> + let repl = Mapper.line (Lexing.from_string line) in + if repl <> "" then + Printf.fprintf oc "%s\n\n" repl); + flush oc + with exn -> + Location.report_exception Format.err_formatter exn; + exit 2 diff --git a/shadow-stdlib/gen/mapper.mll b/shadow-stdlib/gen/mapper.mll new file mode 100644 index 0000000..3de7c37 --- /dev/null +++ b/shadow-stdlib/gen/mapper.mll @@ -0,0 +1,267 @@ +{ +open StdLabels +open Printf + +module String = struct + [@@@warning "-32-3"] + let capitalize_ascii = String.capitalize + let uncapitalize_ascii = String.uncapitalize + let uppercase_ascii = String.uppercase + let lowercase_ascii = String.lowercase + include String +end + +let deprecated_msg ~is_exn what = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + Referring to the stdlib directly is discouraged by Base. 
You should either\n\ + use the equivalent functionality offered by Base, or if you really want to\n\ + refer to the stdlib, use Caml.%s instead\"]" + (if is_exn then "@" else "@@") + what + +let deprecated_msg_no_equivalent ~is_exn what = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + There is no equivalent functionality in Base or Stdio at the moment,\n\ + so you need to use [Caml.%s] instead\"]" + (if is_exn then "@" else "@@") + what + +let deprecated_msg_with_repl_text ~is_exn text = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + %s.\"]" + (if is_exn then "@" else "@@") + text + +let deprecated_msg_with_repl ~is_exn repl = + deprecated_msg_with_repl_text ~is_exn (sprintf "Use [%s] instead" (String.escaped repl)) (* [repl] may contain double quotes *) + +let deprecated_msg_with_approx_repl ~is_exn ~id repl = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + There is no equivalent functionality in Base or Stdio but you can use\n\ + [%s] instead.\n\ + Alternatively, if you really want to refer to the stdlib function you can\n\ + use [Caml.%s].\"]" + (if is_exn then "@" else "@@") + repl id + +type replacement = + | No_equivalent + | Repl of string + | Repl_text of string + | Approx of string + +let val_replacement = function + | "( != )" -> Repl "not (phys_equal ...)" + | "( == )" -> Repl "phys_equal" + | "( ** )" -> Repl "**." 
+ | "( mod )" -> Repl_text "Use (%), which has slightly different \ + semantics, or Int.rem which is equivalent" + | "acos" -> Repl "Float.acos" + | "asin" -> Repl "Float.asin" + | "atan" -> Repl "Float.atan" + | "atan2" -> Repl "Float.atan2" + | "bool_of_string" -> Repl "Bool.of_string" + | "ceil" -> Repl "Float.round_up" + | "char_of_int" -> Repl "Char.of_int_exn" + | "classify_float" -> Repl "Float.classify" + | "close_in" -> Repl "Stdio.In_channel.close" + | "close_in_noerr" -> Repl "Stdio.In_channel.close" + | "close_out" -> Repl "Stdio.Out_channel.close" + | "close_out_noerr" -> Repl "Stdio.Out_channel.close" + | "copysign" -> Repl "Float.copysign" + | "cos" -> Repl "Float.cos" + | "cosh" -> Repl "Float.cosh" + | "decr" -> Repl "Int.decr" + | "epsilon_float" -> Repl "Float.epsilon_float" + | "exp" -> Repl "Float.exp" + | "expm1" -> Repl "Float.expm1" + | "float" -> Repl "Float.of_int" + | "float_of_int" -> Repl "Float.of_int" + | "float_of_string" -> Repl "Float.of_string" + | "floor" -> Repl "Float.round_down" + | "flush" -> Repl "Stdio.Out_channel.flush" + | "flush_all" -> No_equivalent + | "frexp" -> Repl "Float.frexp" + | "hypot" -> Repl "Float.hypot" + | "in_channel_length" -> Repl "Stdio.In_channel.length" + | "infinity" -> Repl "Float.infinity" + | "incr" -> Repl "Int.incr" + | "input" -> Repl "Stdio.In_channel.input" + | "input_binary_int" -> Repl "Stdio.In_channel.input_binary_int" + | "input_byte" -> Repl "Stdio.In_channel.input_byte" + | "input_char" -> Repl "Stdio.In_channel.input_char" + | "input_line" -> Repl "Stdio.In_channel.input_line" + | "input_value" -> Repl "Stdio.In_channel.unsafe_input_value" + | "int_of_char" -> Repl "Char.to_int" + | "int_of_float" -> Repl "Int.of_float" + | "int_of_string" -> Repl "Int.of_string" + | "ldexp" -> Repl "Float.ldexp" + | "log" -> Repl "Float.log" + | "log10" -> Repl "Float.log10" + | "log1p" -> Repl "Float.log1p" + | "max_float" -> Repl "Float.max_finite_value" + | "max_int" -> Repl "Int.max_value" + | 
"min_float" -> Repl "Float.min_positive_normal_value" + | "min_int" -> Repl "Int.min_value" + | "mod_float" -> Repl "Float.mod_float" + | "modf" -> Repl "Float.modf" + | "nan" -> Repl "Float.nan" + | "neg_infinity" -> Repl "Float.neg_infinity" + | "open_in" -> Repl "Stdio.In_channel.create" + | "open_in_bin" -> Repl "Stdio.In_channel.create" + | "open_in_gen" -> No_equivalent + | "open_out" -> Repl "Stdio.Out_channel.create" + | "open_out_bin" -> Repl "Stdio.Out_channel.create" + | "open_out_gen" -> No_equivalent + | "out_channel_length" -> Repl "Stdio.Out_channel.length" + | "output" -> Repl "Stdio.Out_channel.output" + | "output_binary_int" -> Repl "Stdio.Out_channel.output_binary_int" + | "output_byte" -> Repl "Stdio.Out_channel.output_byte" + | "output_bytes" -> Repl "Stdio.Out_channel.output_bytes" + | "output_char" -> Repl "Stdio.Out_channel.output_char" + | "output_string" -> Repl "Stdio.Out_channel.output_string" + | "output_substring" -> Repl "Stdio.Out_channel.output" + | "output_value" -> Repl "Stdio.Out_channel.output_value" + | "pos_in" -> Repl "Stdio.In_channel.pos" + | "pos_out" -> Repl "Stdio.Out_channel.pos" + | "pred" -> Repl "Int.pred" + | "prerr_bytes" -> Repl "Stdio.Out_channel.output_bytes Stdio.stderr" + | "prerr_char" -> Repl "Stdio.Out_channel.output_char Stdio.stderr" + | "prerr_endline" -> Repl "Stdio.prerr_endline" + | "prerr_float" -> Repl "Stdio.eprintf \"%f\"" + | "prerr_int" -> Repl "Stdio.eprintf \"%d\"" + | "prerr_newline" -> Repl "Stdio.eprintf \"\\n%!\"" + | "prerr_string" -> Repl "Stdio.Out_channel.output_string Stdio.stderr" + | "print_bytes" -> Repl "Stdio.Out_channel.output_bytes Stdio.stdout" + | "print_char" -> Repl "Stdio.Out_channel.output_char Stdio.stdout" + | "print_endline" -> Repl "Stdio.print_endline" + | "print_float" -> Repl "Stdio.printf \"%f\"" + | "print_int" -> Repl "Stdio.printf \"%d\"" + | "print_newline" -> Repl "Stdio.printf \"\\n%!\"" + | "print_string" -> Repl "Stdio.Out_channel.output_string 
Stdio.stdout" + | "read_float" -> No_equivalent + | "read_int" -> No_equivalent + | "read_line" -> Repl "Stdio.In_channel.input_line" + | "really_input" -> Repl "Stdio.In_channel.really_input" + | "really_input_string" -> Approx "Stdio.In_channel" + | "seek_in" -> Repl "Stdio.In_channel.seek" + | "seek_out" -> Repl "Stdio.Out_channel.seek" + | "set_binary_mode_in" -> Repl "Stdio.In_channel.set_binary_mode" + | "set_binary_mode_out" -> Repl "Stdio.Out_channel.set_binary_mode" + | "sin" -> Repl "Float.sin" + | "sinh" -> Repl "Float.sinh" + | "sqrt" -> Repl "Float.sqrt" + | "stderr" -> Repl "Stdio.stderr" + | "stdin" -> Repl "Stdio.stdin" + | "stdout" -> Repl "Stdio.stdout" + | "string_of_bool" -> Repl "Bool.to_string" + | "string_of_float" -> Repl "Float.to_string" + | "string_of_int" -> Repl "Int.to_string" + | "succ" -> Repl "Int.succ" + | "tan" -> Repl "Float.tan" + | "tanh" -> Repl "Float.tanh" + | "truncate" -> Repl "Int.of_float" + (* This is documented as DO-NOT-USE in the stdlib *) + | "unsafe_really_input" -> No_equivalent + | _ -> No_equivalent +;; + +let exception_replacement = function + | "Not_found" -> + Some (Repl_text "\ +Instead of raising [Not_found], consider using [raise_s] with an informative error\n\ +message. If code needs to distinguish [Not_found] from other exceptions, please change\n\ +it to handle both [Not_found] and [Not_found_s]. 
Then, instead of raising [Not_found],\n\ +raise [Not_found_s] with an informative error message") + | _ -> None + +let module_replacement = function + | "Printexc" -> Some (Repl_text "Use [Exn] or [Backtrace] instead") + | "Format" -> + let repl_text = + "[Base] doesn't export a [Format] module, although the \n\ + [Caml.Format.formatter] type is available (as [Formatter.t])\n\ + for interaction with other libraries" + in + Some (Repl_text repl_text) + | "Fun" -> Some (Repl_text "Use [Fn] instead") + | _ -> None + +let replace ~is_exn id replacement line = + let msg = + match replacement with + | No_equivalent -> deprecated_msg_no_equivalent ~is_exn id + | Repl repl -> deprecated_msg_with_repl ~is_exn repl + | Repl_text text -> deprecated_msg_with_repl_text ~is_exn text + | Approx repl -> deprecated_msg_with_approx_repl ~is_exn repl ~id + in + sprintf "%s\n%s" line msg +;; +} + +let id_trail = ['a'-'z' 'A'-'Z' '_' '0'-'9']* + let id = ['a'-'z' 'A'-'Z' '_' '0'-'9'] id_trail +let val_id = id | '(' [^ ')']* ')' +let params = ('(' [^')']* ')' | ['+' '-']? '\'' id) " " + +let val_ = "val " | "external " + +rule line = parse + | "module Camlinternal" _* + { "" (* We can't deprecate these *) } + | "module Bigarray" _* { "" (* Don't deprecate it yet *) } + | "type " (params? 
(id as id) _* as def) + { sprintf "type nonrec %s\n%s" def + (match id with + | "in_channel" -> deprecated_msg_with_repl ~is_exn:false "Stdio.In_channel.t" + | "out_channel" -> deprecated_msg_with_repl ~is_exn:false "Stdio.Out_channel.t" + | _ -> deprecated_msg ~is_exn:false id) + } + + | val_ (val_id as id) _* as line { replace ~is_exn:false id (val_replacement id) line } + + | "module " (id as id) " = Stdlib__" (id as id2) (_* as line) + { + let line = + Printf.sprintf "module %s = Stdlib.%s %s" + id (String.capitalize_ascii id2) line in + match module_replacement id with + | Some replacement -> replace ~is_exn:false id replacement line + | None -> sprintf "%s\n%s" line (deprecated_msg ~is_exn:false id) } + + | "exception " (id as id) _* as line + { match exception_replacement id with + | Some replacement -> replace ~is_exn:true id replacement line + | None -> + let predefined_exceptions = + [ "Out_of_memory" + ; "Sys_error" + ; "Failure" + ; "Invalid_argument" + ; "End_of_file" + ; "Division_by_zero" + ; "Not_found" + ; "Match_failure" + ; "Stack_overflow" + ; "Sys_blocked_io" + ; "Assert_failure" + ; "Undefined_recursive_module" ] + in + if List.mem id ~set:predefined_exceptions + then "" + else sprintf "%s\n%s" line (deprecated_msg ~is_exn:true id) + } + | "module " (id as id) _* as line + { match module_replacement id with + | Some replacement -> replace ~is_exn:false id replacement line + | None -> sprintf "%s\n%s" line (deprecated_msg ~is_exn:false id) } + | _* as line + { ksprintf failwith "cannot parse this: %s" line } diff --git a/shadow-stdlib/src/dune b/shadow-stdlib/src/dune new file mode 100644 index 0000000..edea282 --- /dev/null +++ b/shadow-stdlib/src/dune @@ -0,0 +1,10 @@ +(library (name shadow_stdlib) (public_name base.shadow_stdlib) + (libraries caml) (preprocess no_preprocessing)) + +(rule (targets shadow_stdlib.mli) + (deps (:first_dep ../gen/gen.exe) + ../../compiler-stdlib/src/caml.cma) + (action + (run %{first_dep} -caml-cmi 
../../compiler-stdlib/src/.caml.objs/caml.cmi + ../../compiler-stdlib/src/.caml.objs/byte/caml.cmi + -o %{targets}))) diff --git a/shadow-stdlib/src/shadow_stdlib.ml b/shadow-stdlib/src/shadow_stdlib.ml new file mode 100644 index 0000000..62ab439 --- /dev/null +++ b/shadow-stdlib/src/shadow_stdlib.ml @@ -0,0 +1 @@ +include Caml diff --git a/src/am_testing.c b/src/am_testing.c new file mode 100644 index 0000000..30ce05c --- /dev/null +++ b/src/am_testing.c @@ -0,0 +1,9 @@ +#include <caml/mlvalues.h> + +/* The default [Base_am_testing] value is [false]. [ppx_inline_test] overrides + the default by linking against an implementation of [Base_am_testing] that + returns [true]. */ +CAMLprim CAMLweakdef value Base_am_testing() +{ + return Val_false; +} diff --git a/src/am_testing.h b/src/am_testing.h new file mode 100644 index 0000000..cb5d923 --- /dev/null +++ b/src/am_testing.h @@ -0,0 +1,11 @@ +#ifndef BASE_AM_TESTING_H +#define BASE_AM_TESTING_H +#include <caml/mlvalues.h> + +CAMLprim value Base_am_testing (); + +static inline int am_testing () { + return Bool_val (Base_am_testing ()); +} + +#endif diff --git a/src/applicative.ml b/src/applicative.ml new file mode 100644 index 0000000..3a32955 --- /dev/null +++ b/src/applicative.ml @@ -0,0 +1,159 @@ +open! Import + +include Applicative_intf + +(** This module serves mostly as a partial check that [S2] and [S] are in sync, but + actually calling it is occasionally useful. 
*) +module S_to_S2 (X : S) : (S2 with type ('a, 'e) t = 'a X.t) = struct + type ('a, 'e) t = 'a X.t + include (X : S with type 'a t := 'a X.t) +end + +module S2_to_S (X : S2) : (S with type 'a t = ('a, unit) X.t) = struct + type 'a t = ('a, unit) X.t + include (X : S2 with type ('a, 'e) t := ('a, 'e) X.t) +end + +module Args_to_Args2 (X : Args) : ( + Args2 with type ('a, 'e) arg = 'a X.arg + with type ('f, 'r, 'e) t = ('f, 'r) X.t +) = struct + type ('a, 'e) arg = 'a X.arg + type ('f, 'r, 'e) t = ('f, 'r) X.t + include (X : Args with type 'a arg := 'a X.arg and type ('f, 'r) t := ('f, 'r) X.t) +end +[@@warning "-3"] + +module Make2 (X : Basic2) : S2 with type ('a, 'e) t := ('a, 'e) X.t = struct + + include X + + let (<*>) = apply + + let derived_map t ~f = return f <*> t + + let map = + match X.map with + | `Define_using_apply -> derived_map + | `Custom x -> x + + let ( >>|) t f = map t ~f + + let map2 ta tb ~f = map ~f ta <*> tb + + let map3 ta tb tc ~f = map ~f ta <*> tb <*> tc + + let all ts = List.fold_right ts ~init:(return []) ~f:(map2 ~f:(fun x xs -> x :: xs)) + + let both ta tb = map2 ta tb ~f:(fun a b -> (a, b)) + + let ( *> ) u v = return (fun () y -> y) <*> u <*> v + let ( <* ) u v = return (fun x () -> x) <*> u <*> v + + let all_unit ts = List.fold ts ~init:(return ()) ~f:( *> ) + let all_ignore = all_unit + + module Applicative_infix = struct + let ( <*> ) = ( <*> ) + let ( *> ) = ( *> ) + let ( <* ) = ( <* ) + let ( >>| ) = ( >>| ) + end +end + +module Make (X : Basic) : S with type 'a t := 'a X.t = + Make2 (struct + type ('a, 'e) t = 'a X.t + include (X : Basic with type 'a t := 'a X.t) + end) + +module Make_let_syntax (X : For_let_syntax) (Intf : sig module type S end) (Impl : Intf.S) = struct + module Let_syntax = struct + include X + module Let_syntax = struct + include X + module Open_on_rhs = Impl + end + end +end + +module Make2_using_map2 (X : Basic2_using_map2) = + Make2 (struct + include X + let apply tf tx = map2 tf tx ~f:(fun f x -> f x) + 
let map = + match map with + | `Custom map -> `Custom map + | `Define_using_map2 -> `Define_using_apply + end) + +module Make_using_map2 (X : Basic_using_map2) : S with type 'a t := 'a X.t = + Make2_using_map2 (struct + type ('a, 'e) t = 'a X.t + include (X : Basic_using_map2 with type 'a t := 'a X.t) + end) + +module Make_args' (X : S2) = struct + open X + + type ('f, 'r, 'e) t_ = { applyN : ('f, 'e) X.t -> ('r, 'e) X.t } + + let nil = { applyN = Fn.id } + + let cons arg t = { applyN = fun d -> t.applyN (apply d arg) } + + let step t ~f = { applyN = fun d -> t.applyN (map ~f d) } + + let (@>) = cons + + let applyN arg t = t.applyN arg + + let mapN ~f t = applyN (return f) t +end + +module Make_args (X : S) : Args with type 'a arg := 'a X.t = struct + include Make_args' (struct + type ('a, 'e) t = 'a X.t + include (X : S with type 'a t := 'a X.t) + end) + + type ('f, 'r) t = ('f, 'r, unit) t_ +end +[@@warning "-3"] + +module Make_args2 (X : S2) : Args2 with type ('a, 'e) arg := ('a, 'e) X.t = struct + include Make_args' (X) + + type ('f, 'r, 'e) t = ('f, 'r, 'e) t_ +end +[@@warning "-3"] + +module Of_monad (M : Monad.S) : S with type 'a t := 'a M.t = + Make (struct + type 'a t = 'a M.t + let return = M.return + let apply mf mx = M.bind mf ~f:(fun f -> M.map mx ~f) + let map = `Custom M.map + end) + +module Compose (F : S) (G : S) : S with type 'a t = 'a F.t G.t = struct + type 'a t = 'a F.t G.t + include Make (struct + type nonrec 'a t = 'a t + let return a = G.return (F.return a) + let apply tf tx = G.apply (G.map ~f:F.apply tf) tx + let custom_map t ~f = G.map ~f:(F.map ~f) t + let map = `Custom custom_map + end) +end + +module Pair (F : S) (G : S) : S with type 'a t = 'a F.t * 'a G.t = struct + type 'a t = 'a F.t * 'a G.t + include Make (struct + type nonrec 'a t = 'a t + let return a = (F.return a, G.return a) + let apply tf tx = (F.apply (fst tf) (fst tx), G.apply (snd tf) (snd tx)) + let custom_map t ~f = (F.map ~f (fst t), G.map ~f (snd t)) + let map = 
`Custom custom_map + end) +end diff --git a/src/applicative.mli b/src/applicative.mli new file mode 100644 index 0000000..3a4f2d2 --- /dev/null +++ b/src/applicative.mli @@ -0,0 +1 @@ +include Applicative_intf.Applicative (** @inline *) diff --git a/src/applicative_intf.ml b/src/applicative_intf.ml new file mode 100644 index 0000000..e866442 --- /dev/null +++ b/src/applicative_intf.ml @@ -0,0 +1,318 @@ +(** Applicatives model computations in which values computed by subcomputations cannot + affect what subsequent computations will take place. + + Relative to monads, this restriction takes power away from the user of the interface + and gives it to the implementation. In particular, because the structure of the + entire computation is known, one can augment its definition with some description of + that structure. + + For more information, see: + + {v + Applicative Programming with Effects. + Conor McBride and Ross Paterson. + Journal of Functional Programming 18:1 (2008), pages 1-13. + http://staff.city.ac.uk/~ross/papers/Applicative.pdf + v} *) + +open! Import + +module type Basic = sig + type 'a t + val return : 'a -> 'a t + val apply : ('a -> 'b) t -> 'a t -> 'b t + (** The following identities ought to hold for every Applicative (for some value of =): + + - identity: [return Fn.id <*> t = t] + - composition: [return Fn.compose <*> tf <*> tg <*> tx = tf <*> (tg <*> tx)] + - homomorphism: [return f <*> return x = return (f x)] + - interchange: [tf <*> return x = return (fun f -> f x) <*> tf] + + Note: <*> is the infix notation for apply. *) + + (** The [map] argument to [Applicative.Make] says how to implement the applicative's + [map] function. [`Define_using_apply] means to define [map t ~f = return f <*> t]. + [`Custom] overrides the default implementation, presumably with something more + efficient. 
+ + Some other functions returned by [Applicative.Make] are defined in terms of [map], + so passing in a more efficient [map] will improve their efficiency as well. *) + val map : [`Define_using_apply | `Custom of ('a t -> f:('a -> 'b) -> 'b t)] +end + +module type Basic_using_map2 = sig + type 'a t + val return : 'a -> 'a t + val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + val map : [`Define_using_map2 | `Custom of ('a t -> f:('a -> 'b) -> 'b t)] +end + +module type Applicative_infix = sig + type 'a t + + val ( <*> ) : ('a -> 'b) t -> 'a t -> 'b t (** same as [apply] *) + + val ( <* ) : 'a t -> unit t -> 'a t + val ( *> ) : unit t -> 'a t -> 'a t + + val ( >>| ) : 'a t -> ('a -> 'b) -> 'b t +end + +module type For_let_syntax = sig + type 'a t + + val return : 'a -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + + val both : 'a t -> 'b t -> ('a * 'b) t + + include Applicative_infix with type 'a t := 'a t + +end + +module type S = sig + + include For_let_syntax + + val apply : ('a -> 'b) t -> 'a t -> 'b t + + val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + + val map3 : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t + + val all : 'a t list -> 'a list t + + val all_unit : unit t list -> unit t + + val all_ignore : unit t list -> unit t [@@deprecated "[since 2018-02] Use [all_unit]"] + + module Applicative_infix : Applicative_infix with type 'a t := 'a t +end + +module type Let_syntax = sig + type 'a t + + module Open_on_rhs_intf : sig + module type S + end + + module Let_syntax : sig + + val return : 'a -> 'a t + include Applicative_infix with type 'a t := 'a t + + module Let_syntax : sig + + val return : 'a -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + + val both : 'a t -> 'b t -> ('a * 'b) t + + module Open_on_rhs : Open_on_rhs_intf.S + end + end +end + +(** Argument lists and associated N-ary map and apply functions. 
*) +module type Args = sig + + type 'a arg (** the underlying applicative *) + + (** ['f] is the type of a function that consumes the list of arguments and returns an + ['r]. *) + type ('f, 'r) t + + (** the empty argument list **) + val nil : ('r, 'r) t + + (** prepend an argument *) + val cons : 'a arg -> ('f, 'r) t -> ('a -> 'f, 'r) t + + (** infix operator for [cons] *) + val (@>) : 'a arg -> ('f, 'r) t -> ('a -> 'f, 'r) t + + (** Transform argument values in some way. For example, one can label a function + argument like so: + + {[ + step ~f:(fun f x -> f ~foo:x) : ('a -> 'r1, 'r2) t -> (foo:'a -> 'r1, 'r2) t + ]} *) + val step : ('f1, 'r) t -> f:('f2 -> 'f1) -> ('f2, 'r) t + + (** The preferred way to factor out an [Args] sub-sequence: + + {[ + let args = + Foo.Args.( + bar "A" + (* TODO: factor out the common baz qux sub-sequence *) + @> baz "B" + @> qux "C" + @> zap "D" + @> nil + ) + ]} + + is to write a function that prepends the sub-sequence: + + {[ + let baz_qux remaining_args = + Foo.Args.( + baz "B" + @> qux "C" + @> remaining_args + ) + ]} + + and splice it back into the original sequence using [@@] so that things line up + nicely: + + {[ + let args = + Foo.Args.( + bar "A" + @> baz_qux + @@ zap "D" + @> nil + ) + ]} *) + + val mapN : f:'f -> ('f, 'r) t -> 'r arg + + val applyN : 'f arg -> ('f, 'r) t -> 'r arg + +end +[@@deprecated "[since 2018-09] Use [ppx_let] instead."] + +module type Basic2 = sig + type ('a, 'e) t + val return : 'a -> ('a, _) t + val apply : ('a -> 'b, 'e) t -> ('a, 'e) t -> ('b, 'e) t + val map : [`Define_using_apply | `Custom of (('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t)] +end + +module type Basic2_using_map2 = sig + type ('a, 'e) t + val return : 'a -> ('a, _) t + val map2 : ('a, 'e) t -> ('b, 'e) t -> f:('a -> 'b -> 'c) -> ('c, 'e) t + val map : [`Define_using_map2 | `Custom of (('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t)] +end + +module type S2 = sig + type ('a, 'e) t + + val return : 'a -> ('a, _) t + + val apply : ('a -> 'b, 
'e) t -> ('a, 'e) t -> ('b, 'e) t + + val map : ('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t + + val map2 : ('a, 'e) t -> ('b, 'e) t -> f:('a -> 'b -> 'c) -> ('c, 'e) t + + val map3 + : ('a, 'e) t + -> ('b, 'e) t + -> ('c, 'e) t + -> f:('a -> 'b -> 'c -> 'd) + -> ('d, 'e) t + + val all : ('a, 'e) t list -> ('a list, 'e) t + + val all_unit : (unit, 'e) t list -> (unit, 'e) t + + val all_ignore : (unit, 'e) t list -> (unit, 'e) t + [@@deprecated "[since 2018-02] Use [all_unit]"] + + val both : ('a, 'e) t -> ('b, 'e) t -> ('a * 'b, 'e) t + + module Applicative_infix : sig + val ( <*> ) : ('a -> 'b, 'e) t -> ('a, 'e) t -> ('b, 'e) t + val ( <* ) : ('a, 'e) t -> (unit, 'e) t -> ('a, 'e) t + val ( *> ) : (unit, 'e) t -> ('a, 'e) t -> ('a, 'e) t + val ( >>| ) : ('a, 'e) t -> ('a -> 'b) -> ('b, 'e) t + end + + include module type of Applicative_infix +end + +module type Args2 = sig + type ('a, 'e) arg + + type ('f, 'r, 'e) t + + val nil : ('r, 'r, _) t + + val cons : ('a, 'e) arg -> ('f, 'r, 'e) t -> ('a -> 'f, 'r, 'e) t + val (@>) : ('a, 'e) arg -> ('f, 'r, 'e) t -> ('a -> 'f, 'r, 'e) t + + val step : ('f1, 'r, 'e) t -> f:('f2 -> 'f1) -> ('f2, 'r, 'e) t + + val mapN : f:'f -> ('f, 'r, 'e) t -> ('r, 'e) arg + val applyN : ('f, 'e) arg -> ('f, 'r, 'e) t -> ('r, 'e) arg +end +[@@deprecated "[since 2018-09] Use [ppx_let] instead."] + +module type Applicative = sig + + module type Applicative_infix = Applicative_infix + module type Args = Args + [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + module type Args2 = Args2 + [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + module type Basic = Basic + module type Basic2 = Basic2 + module type Basic2_using_map2 = Basic2_using_map2 + module type Basic_using_map2 = Basic_using_map2 + module type Let_syntax = Let_syntax + module type S = S + module type S2 = S2 + + module Args_to_Args2 (X : Args) : + Args2 + with type ('a, 'e) arg = 'a X.arg + with type ('f, 'r, 'e) t = ('f, 'r) X.t + 
[@@warning "-3"] + + module S2_to_S (X : S2) : S with type 'a t = ('a, unit) X.t + + module S_to_S2 (X : S) : S2 with type ('a, 'e) t = 'a X.t + + module Make (X : Basic ) : S with type 'a t := 'a X.t + module Make2 (X : Basic2) : S2 with type ('a, 'e) t := ('a, 'e) X.t + + module Make_let_syntax + (X : For_let_syntax) + (Intf : sig module type S end) + (Impl : Intf.S) + : Let_syntax with type 'a t := 'a X.t + with module Open_on_rhs_intf := Intf + + module Make_using_map2 (X : Basic_using_map2 ) : S with type 'a t := 'a X.t + module Make2_using_map2 (X : Basic2_using_map2) : S2 with type ('a, 'e) t := ('a, 'e) X.t + + module Make_args (X : S ) : Args with type 'a arg := 'a X.t [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + module Make_args2 (X : S2) : Args2 with type ('a, 'e) arg := ('a, 'e) X.t [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + + (** The following functors give a sense of what Applicatives one can define. + + Of these, [Of_monad] is likely the most useful. The others are mostly didactic. *) + + (** Every monad is Applicative via: + + {[ + let apply mf mx = + mf >>= fun f -> + mx >>| fun x -> + f x + ]} *) + module Of_monad (M : Monad.S) : S with type 'a t := 'a M.t + module Compose (F : S) (G : S) : S with type 'a t = 'a F.t G.t + module Pair (F : S) (G : S) : S with type 'a t = 'a F.t * 'a G.t + +end diff --git a/src/array.ml b/src/array.ml new file mode 100644 index 0000000..5c94b1f --- /dev/null +++ b/src/array.ml @@ -0,0 +1,731 @@ +open! Import + +include Array0 +module Int = Int0 + +let raise_s = Error.raise_s + +type 'a t = 'a array [@@deriving_inline compare, sexp] +let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = compare_array +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + array_of_sexp +let sexp_of_t : + 'a . 
('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_array +[@@@end] + +(* This module implements a new in-place, constant heap sorting algorithm to replace the + one used by the standard libraries. Its only purpose is to be faster (hopefully + strictly faster) than the base sort and stable_sort. + + At a high level the algorithm is: + - pick two pivot points by: + - pick 5 arbitrary elements from the array + - sort them within the array + - take the elements on either side of the middle element of the sort as the pivots + - sort the array with: + - all elements less than pivot1 to the left (range 1) + - all elements >= pivot1 and <= pivot2 in the middle (range 2) + - all elements > pivot2 to the right (range 3) + - if pivot1 and pivot2 are equal, then the middle range is sorted, so ignore it + - recurse into range 1, 2 (if pivot1 and pivot2 are unequal), and 3 + - during recursion there are two inflection points: + - if the size of the current range is small, use insertion sort to sort it + - if the stack depth is large, sort the range with heap-sort to avoid n^2 worst-case + behavior + + See the following for more information: + - "Dual-Pivot Quicksort" by Vladimir Yaroslavskiy. + Available at + http://www.kriche.com.ar/root/programming/spaceTimeComplexity/DualPivotQuicksort.pdf + - "Quicksort is Optimal" by Sedgewick and Bentley. + Slides at http://www.cs.princeton.edu/~rs/talks/QuicksortIsOptimal.pdf + - http://www.sorting-algorithms.com/quick-sort-3-way *) + +module Sort = struct + (* For the sake of speed we could use unsafe get/set throughout, but speed tests don't + show a significant improvement. 
*) + let get = get + let set = set + + let swap arr i j = + let tmp = get arr i in + set arr i (get arr j); + set arr j tmp + ;; + + module type Sort = sig + val sort + : 'a t + -> compare:('a -> 'a -> int) + -> left:int (* leftmost index of sub-array to sort *) + -> right:int (* rightmost index of sub-array to sort *) + -> unit + end + + (* http://en.wikipedia.org/wiki/Insertion_sort *) + module Insertion_sort : Sort = struct + let sort arr ~compare ~left ~right = + (* loop invariant: + [arr] is sorted from [left] to [pos - 1], inclusive *) + for pos = left + 1 to right do + (* loop invariants: + 1. the subarray arr[left .. i-1] is sorted + 2. the subarray arr[i+1 .. pos] is sorted and contains only elements > v + 3. arr[i] may be thought of as containing v + + Note that this does not allocate a closure, but is left in the for + loop for the readability of the documentation. *) + let rec loop arr ~left ~compare i v = + let i_next = i - 1 in + if i_next >= left && compare (get arr i_next) v > 0 then begin + set arr i (get arr i_next); + loop arr ~left ~compare i_next v + end else + i + in + let v = get arr pos in + let final_pos = loop arr ~left ~compare pos v in + set arr final_pos v + done + ;; + end + + (* http://en.wikipedia.org/wiki/Heapsort *) + module Heap_sort : Sort = struct + (* loop invariant: + root's children are both either roots of max-heaps or > right *) + let rec heapify arr ~compare root ~left ~right = + let relative_root = root - left in + let left_child = (2 * relative_root) + left + 1 in + let right_child = (2 * relative_root) + left + 2 in + let largest = + if left_child <= right && compare (get arr left_child) (get arr root) > 0 + then left_child + else root + in + let largest = + if right_child <= right && compare (get arr right_child) (get arr largest) > 0 + then right_child + else largest + in + if largest <> root then begin + swap arr root largest; + heapify arr ~compare largest ~left ~right + end; + ;; + + let build_heap arr ~compare 
~left ~right = + (* Elements in the second half of the array are already heaps of size 1. We move + through the first half of the array from back to front examining the element at + hand, and the left and right children, fixing the heap property as we go. *) + for i = (left + right) / 2 downto left do + heapify arr ~compare i ~left ~right; + done; + ;; + + let sort arr ~compare ~left ~right = + build_heap arr ~compare ~left ~right; + (* loop invariants: + 1. the subarray arr[left ... i] is a max-heap H + 2. the subarray arr[i+1 ... right] is sorted (call it S) + 3. every element of H is less than every element of S *) + for i = right downto left + 1 do + swap arr left i; + heapify arr ~compare left ~left ~right:(i - 1); + done; + ;; + end + + (* http://en.wikipedia.org/wiki/Introsort *) + module Intro_sort : sig + include Sort + val five_element_sort + : 'a t -> compare:('a -> 'a -> int) -> int -> int -> int -> int -> int -> unit + end = struct + + let five_element_sort arr ~compare m1 m2 m3 m4 m5 = + let compare_and_swap i j = + if compare (get arr i) (get arr j) > 0 then swap arr i j + in + (* optimal 5-element sorting network *) + compare_and_swap m1 m2; (* 1--o-----o-----o--------------1 *) + compare_and_swap m4 m5; (* | | | *) + compare_and_swap m1 m3; (* 2--o-----|--o--|-----o--o-----2 *) + compare_and_swap m2 m3; (* | | | | | *) + compare_and_swap m1 m4; (* 3--------o--o--|--o--|--o-----3 *) + compare_and_swap m3 m4; (* | | | *) + compare_and_swap m2 m5; (* 4-----o--------o--o--|-----o--4 *) + compare_and_swap m2 m3; (* | | | *) + compare_and_swap m4 m5; (* 5-----o--------------o-----o--5 *) + ;; + + (* choose pivots for the array by sorting 5 elements and examining the center three + elements. 
The goal is to choose two pivots that will either: + - break the range up into 3 even partitions + or + - eliminate a commonly appearing element by sorting it into the center partition + by itself + To this end we look at the center 3 elements of the 5 and return pairs of equal + elements or the widest range *) + let choose_pivots arr ~compare ~left ~right = + let sixth = (right - left) / 6 in + let m1 = left + sixth in + let m2 = m1 + sixth in + let m3 = m2 + sixth in + let m4 = m3 + sixth in + let m5 = m4 + sixth in + five_element_sort arr ~compare m1 m2 m3 m4 m5; + let m2_val = get arr m2 in + let m3_val = get arr m3 in + let m4_val = get arr m4 in + if compare m2_val m3_val = 0 then (m2_val, m3_val, true) + else if compare m3_val m4_val = 0 then (m3_val, m4_val, true) + else (m2_val, m4_val, false) + ;; + + let dual_pivot_partition arr ~compare ~left ~right = + let pivot1, pivot2, pivots_equal = choose_pivots arr ~compare ~left ~right in + (* loop invariants: + 1. left <= l < r <= right + 2. l <= p <= r + 3. l <= x < p implies arr[x] >= pivot1 + and arr[x] <= pivot2 + 4. left <= x < l implies arr[x] < pivot1 + 5. r < x <= right implies arr[x] > pivot2 *) + let rec loop l p r = + let pv = get arr p in + if compare pv pivot1 < 0 then begin + swap arr p l; + cont (l + 1) (p + 1) r + end else if compare pv pivot2 > 0 then begin + (* loop invariants: same as those of the outer loop *) + let rec scan_backwards r = + if r > p && compare (get arr r) pivot2 > 0 + then scan_backwards (r - 1) + else r + in + let r = scan_backwards r in + swap arr r p; + cont l p (r - 1) + end else + cont l (p + 1) r + and cont l p r = + if p > r then (l, r) else loop l p r + in + let (l, r) = cont left left right in + (l, r, pivots_equal) + ;; + + let rec intro_sort arr ~max_depth ~compare ~left ~right = + let len = right - left + 1 in + (* This takes care of some edge cases, such as left > right or very short arrays, + since Insertion_sort.sort handles these cases properly. 
Thus we don't need to + make sure that left and right are valid in recursive calls. *) + if len <= 32 then begin + Insertion_sort.sort arr ~compare ~left ~right + end else if max_depth < 0 then begin + Heap_sort.sort arr ~compare ~left ~right; + end else begin + let max_depth = max_depth - 1 in + let (l, r, middle_sorted) = dual_pivot_partition arr ~compare ~left ~right in + intro_sort arr ~max_depth ~compare ~left ~right:(l - 1); + if not middle_sorted then intro_sort arr ~max_depth ~compare ~left:l ~right:r; + intro_sort arr ~max_depth ~compare ~left:(r + 1) ~right; + end + ;; + + let log10_of_3 = Caml.log10 3. + + let log3 x = Caml.log10 x /. log10_of_3 + + let sort arr ~compare ~left ~right = + let len = right - left + 1 in + let heap_sort_switch_depth = + (* with perfect 3-way partitioning, this is the recursion depth *) + Int.of_float (log3 (Int.to_float len)) + in + intro_sort arr ~max_depth:heap_sort_switch_depth ~compare ~left ~right; + ;; + end +end + +let sort ?pos ?len arr ~compare = + let pos, len = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length arr) + in + Sort.Intro_sort.sort arr ~compare ~left:pos ~right:(pos + len - 1) + +let to_array t = t + +let is_empty t = length t = 0 + +let is_sorted t ~compare = + let rec is_sorted_loop t ~compare i = + if i < 1 + then true + else + compare t.(i - 1) t.(i) <= 0 + && is_sorted_loop t ~compare (i - 1) + in + is_sorted_loop t ~compare (length t - 1) +;; + +let is_sorted_strictly t ~compare = + let rec is_sorted_strictly_loop t ~compare i = + if i < 1 + then true + else + compare t.(i - 1) t.(i) < 0 + && is_sorted_strictly_loop t ~compare (i - 1) + in + is_sorted_strictly_loop t ~compare (length t - 1) +;; + +let folding_map t ~init ~f = + let acc = ref init in + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) +;; + +let fold_map t ~init ~f = + let acc = ref init in + let result = + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) + 
in + !acc, result +;; + +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t +let count t ~f = Container.count ~fold t ~f +let sum m t ~f = Container.sum ~fold m t ~f +let min_elt t ~compare = Container.min_elt ~fold t ~compare +let max_elt t ~compare = Container.max_elt ~fold t ~compare + +let foldi t ~init ~f = + let rec foldi_loop t i ac ~f = + if i = length t + then ac + else foldi_loop t (i + 1) (f i ac t.(i)) ~f + in + foldi_loop t 0 init ~f +;; + +let folding_mapi t ~init ~f = + let acc = ref init in + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) +;; + +let fold_mapi t ~init ~f = + let acc = ref init in + let result = + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) + in + !acc, result +;; + +let counti t ~f = foldi t ~init:0 ~f:(fun idx count a -> if f idx a then count + 1 else count) + +let concat_map t ~f = concat (to_list (map ~f t)) +let concat_mapi t ~f = concat (to_list (mapi ~f t)) + +let rev_inplace t = + let i = ref 0 in + let j = ref (length t - 1) in + while !i < !j; do + swap t !i !j; + incr i; + decr j; + done +;; + +let of_list_rev l = + match l with + | [] -> [||] + | a :: l -> + let len = 1 + List.length l in + let t = create ~len a in + let r = ref l in + (* We start at [len - 2] because we already put [a] at [t.(len - 1)]. *) + for i = len - 2 downto 0 do + match !r with + | [] -> assert false + | a :: l -> t.(i) <- a; r := l + done; + t +;; + +(* [of_list_map] and [of_list_rev_map] are based on functions from the OCaml + distribution. 
*) + +let of_list_map xs ~f = + match xs with + | [] -> [||] + | hd::tl -> + let a = create ~len:(1 + List.length tl) (f hd) in + let rec fill i = function + | [] -> a + | hd::tl -> unsafe_set a i (f hd); fill (i+1) tl in + fill 1 tl + +let of_list_mapi xs ~f = + match xs with + | [] -> [||] + | hd::tl -> + let a = create ~len:(1 + List.length tl) (f 0 hd) in + let rec fill a i = function + | [] -> a + | hd::tl -> + unsafe_set a i (f i hd); + fill a (i+1) tl + in + fill a 1 tl + +let of_list_rev_map xs ~f = + let t = of_list_map xs ~f in + rev_inplace t; + t + +let of_list_rev_mapi xs ~f = + let t = of_list_mapi xs ~f in + rev_inplace t; + t + +(* [Obj.truncate] reduces the size of a block on the ocaml heap. For arrays, the block + size is the array length. This holds even for float arrays. *) +let unsafe_truncate t ~len = + if len <= 0 || len > length t then + raise_s (Sexp.message "Array.unsafe_truncate got invalid len" + ["len", sexp_of_int len]); + if len < length t then Caml.Obj.truncate (Caml.Obj.repr t) len; +;; + +let filter_mapi t ~f = + let r = ref [||] in + let k = ref 0 in + for i = 0 to length t - 1 do + match f i (unsafe_get t i) with + | None -> () + | Some a -> + if !k = 0 then begin + r := create ~len:(length t) a + end; + unsafe_set !r !k a; + incr k; + done; + if !k > 0 then begin + unsafe_truncate !r ~len:!k; + !r + end else + [||] + +let filter_map t ~f = + filter_mapi t ~f:(fun _i a -> f a) + +let filter_opt t = + filter_map t ~f:Fn.id + +let iter2_exn t1 t2 ~f = + if length t1 <> length t2 then invalid_arg "Array.iter2_exn"; + iteri t1 ~f:(fun i x1 -> f x1 t2.(i)) + +let map2_exn t1 t2 ~f = + let len = length t1 in + if length t2 <> len then invalid_arg "Array.map2_exn"; + init len ~f:(fun i -> f t1.(i) t2.(i)) + +let fold2_exn t1 t2 ~init ~f = + if length t1 <> length t2 then invalid_arg "Array.fold2_exn"; + foldi t1 ~init ~f:(fun i ac x -> f ac x t2.(i)) +;; + +let filter t ~f = filter_map t ~f:(fun x -> if f x then Some x else None) + +let 
filteri t ~f = filter_mapi t ~f:(fun i x -> if f i x then Some x else None) + +let exists t ~f = + let rec exists_loop t ~f i = + if i < 0 + then false + else f t.(i) || exists_loop t ~f (i - 1) + in + exists_loop t ~f (length t - 1) + +let existsi t ~f = + let rec existsi_loop t ~f i = + if i < 0 + then false + else f i t.(i) || existsi_loop t ~f (i - 1) + in + existsi_loop t ~f (length t - 1) + +let mem t a ~equal = exists t ~f:(equal a) + +let for_all t ~f = + let rec for_all_loop t ~f i = + if i < 0 + then true + else f t.(i) && for_all_loop t ~f (i - 1) + in + for_all_loop t ~f (length t - 1) + +let for_alli t ~f = + let rec for_alli_loop t ~f i = + if i < 0 + then true + else f i t.(i) && for_alli_loop t ~f (i - 1) + in + for_alli_loop t ~f (length t - 1) + +let exists2_exn t1 t2 ~f = + let rec exists2_exn_loop t1 t2 ~f i = + if i < 0 + then false + else f t1.(i) t2.(i) || exists2_exn_loop t1 t2 ~f (i - 1) + in + let len = length t1 in + if length t2 <> len then invalid_arg "Array.exists2_exn"; + exists2_exn_loop t1 t2 ~f (len - 1) + +let for_all2_exn t1 t2 ~f = + let rec for_all2_loop t1 t2 ~f i = + if i < 0 + then true + else f t1.(i) t2.(i) && for_all2_loop t1 t2 ~f (i - 1) + in + let len = length t1 in + if length t2 <> len then invalid_arg "Array.for_all2_exn"; + for_all2_loop t1 t2 ~f (len - 1) + +let equal equal t1 t2 = length t1 = length t2 && for_all2_exn t1 t2 ~f:equal + +let replace t i ~f = t.(i) <- f t.(i) + +let map_inplace t ~f = + for i = 0 to length t - 1 do + t.(i) <- f t.(i) + done + +let replace_all = map_inplace + +let findi t ~f = + let rec findi_loop t ~f ~length i = + if i >= length then None + else if f i t.(i) then Some (i, t.(i)) + else findi_loop t ~f ~length (i + 1) + in + let length = length t in + findi_loop t ~f ~length 0 +;; + +let findi_exn t ~f = + match findi t ~f with + | None -> raise Caml.Not_found + | Some x -> x +;; + +let find_exn t ~f = + match findi t ~f:(fun _i x -> f x) with + | None -> raise Caml.Not_found + | 
Some (_i, x) -> x +;; + +let find t ~f = Option.map (findi t ~f:(fun _i x -> f x)) ~f:(fun (_i, x) -> x) + +let find_map t ~f = + let rec find_map_loop t ~f ~length i = + if i >= length then None + else + match f t.(i) with + | None -> find_map_loop t ~f ~length (i + 1) + | Some _ as res -> res + in + let length = length t in + find_map_loop t ~f ~length 0 +;; + +let find_map_exn t ~f = + match find_map t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +let find_mapi t ~f = + let rec find_mapi_loop t ~f ~length i = + if i >= length then None + else + match f i t.(i) with + | None -> find_mapi_loop t ~f ~length (i + 1) + | Some _ as res -> res + in + let length = length t in + find_mapi_loop t ~f ~length 0 +;; + +let find_mapi_exn t ~f = + match find_mapi t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +let find_consecutive_duplicate t ~equal = + let n = length t in + if n <= 1 + then None + else begin + let result = ref None in + let i = ref 1 in + let prev = ref t.(0) in + while !i < n do + let cur = t.(!i) in + if equal cur !prev + then (result := Some (!prev, cur); i := n) + else (prev := cur; incr i) + done; + !result + end +;; + +let reduce t ~f = + if length t = 0 then None + else begin + let r = ref t.(0) in + for i = 1 to length t - 1 do + r := f !r t.(i) + done; + Some !r + end + +let reduce_exn t ~f = + match reduce t ~f with + | None -> invalid_arg "Array.reduce_exn" + | Some v -> v + +let permute = Array_permute.permute + +let random_element_exn ?(random_state = Random.State.default) t = + if is_empty t + then failwith "Array.random_element_exn: empty array" + else t.(Random.State.int random_state (length t)) + +let random_element ?(random_state = Random.State.default) t = + if is_empty t then None (* explicit emptiness test instead of [try ... with _], which turned every exception into [None] *) + else Some (random_element_exn ~random_state t) + +let zip t1 t2 = + if length t1 <> length t2 then None + else Some (map2_exn t1 t2 ~f:(fun x1 x2 -> x1, x2)) + +let zip_exn t1 t2 = + if length t1 <> length t2 then failwith "Array.zip_exn" + 
else map2_exn t1 t2 ~f:(fun x1 x2 -> x1, x2) + +let unzip t = + let n = length t in + if n = 0 then [||], [||] + else + let x, y = t.(0) in + let res1 = create ~len:n x in + let res2 = create ~len:n y in + for i = 1 to n - 1 do + let x, y = t.(i) in + res1.(i) <- x; + res2.(i) <- y; + done; + res1, res2 + +let sorted_copy t ~compare = + let t1 = copy t in + sort t1 ~compare; + t1 + +let partitioni_tf t ~f = + let both = mapi t ~f:(fun i x -> if f i x then Either.First x else Either.Second x) in + let trues = filter_map both ~f:(function First x -> Some x | Second _ -> None) in + let falses = filter_map both ~f:(function First _ -> None | Second x -> Some x) in + (trues, falses) + +let partition_tf t ~f = + partitioni_tf t ~f:(fun _i x -> f x) + +let last t = t.(length t - 1) + +(* Convert to a sequence but does not attempt to protect against modification + in the array. *) +let to_sequence_mutable t = + Sequence.unfold_step ~init:0 ~f:(fun i -> + if i >= length t + then Sequence.Step.Done + else Sequence.Step.Yield (t.(i), i+1)) + +let to_sequence t = to_sequence_mutable (copy t) + +let cartesian_product t1 t2 = + if is_empty t1 || is_empty t2 then + [||] + else + let n1 = length t1 in + let n2 = length t2 in + let t = create ~len:(n1 * n2) (t1.(0), t2.(0)) in + let r = ref 0 in + for i1 = 0 to n1 - 1 do + for i2 = 0 to n2 - 1 do + t.(!r) <- (t1.(i1), t2.(i2)); + incr r; + done + done; + t +;; + +let transpose tt = + if length tt = 0 + then Some [||] + else + let width = length tt in + let depth = length tt.(0) in + if exists tt ~f:(fun t -> length t <> depth) + then None + else Some (init depth ~f:(fun d -> init width ~f:(fun w -> tt.(w).(d)))) + +let transpose_exn tt = + match transpose tt with + | None -> invalid_arg "Array.transpose_exn"; + | Some tt' -> tt' + +include Binary_searchable.Make1 (struct + type nonrec 'a t = 'a t + + let get = get + let length = length + end) + +include + Blit.Make1 + (struct + type nonrec 'a t = 'a t + let length = length + let 
create_like ~len t = + if len = 0 + then [||] + else (assert (length t > 0); create ~len t.(0)) + ;; + let unsafe_blit = blit + end) +;; + +let invariant invariant_a t = iter t ~f:invariant_a + +module Private = struct + module Sort = Sort +end diff --git a/src/array.mli b/src/array.mli new file mode 100644 index 0000000..befce86 --- /dev/null +++ b/src/array.mli @@ -0,0 +1,339 @@ +(** Mutable vector of elements of type ['a] with O(1) [get] and [set] operations. *) + +open! Import + +type 'a t = 'a array [@@deriving_inline compare, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Binary_searchable.S1 with type 'a t := 'a t + +include Container.S1 with type 'a t := 'a t + +include Invariant.S1 with type 'a t := 'a t + +(** Maximum length of a normal array. The maximum length of a float array is + [max_length/2] on 32-bit machines and [max_length] on 64-bit machines. *) +val max_length : int + +(** [Array.get a n] returns the element number [n] of array [a]. + The first element has number 0. + The last element has number [Array.length a - 1]. + You can also write [a.(n)] instead of [Array.get a n]. + + Raise [Invalid_argument "index out of bounds"] + if [n] is outside the range 0 to [(Array.length a - 1)]. *) +external get : 'a t -> int -> 'a = "%array_safe_get" + +(** [Array.set a n x] modifies array [a] in place, replacing + element number [n] with [x]. + You can also write [a.(n) <- x] instead of [Array.set a n x]. + + Raise [Invalid_argument "index out of bounds"] + if [n] is outside the range 0 to [Array.length a - 1]. *) +external set : 'a t -> int -> 'a -> unit = "%array_safe_set" + +(** Unsafe version of [get]. Can cause arbitrary behavior when used for an out-of-bounds + array access. *) +external unsafe_get : 'a t -> int -> 'a = "%array_unsafe_get" + +(** Unsafe version of [set]. 
Can cause arbitrary behavior when used for an out-of-bounds + array access. *) +external unsafe_set : 'a t -> int -> 'a -> unit = "%array_unsafe_set" + +(** [create ~len x] creates an array of length [len] with the value [x] populated in + each element. *) +val create : len:int -> 'a -> 'a t + +(** [init n ~f] creates an array of length [n] where the [i]th element (starting at zero) + is initialized with [f i]. *) +val init : int -> f:(int -> 'a) -> 'a t + +(** [Array.make_matrix dimx dimy e] returns a two-dimensional array (an array of arrays) + with first dimension [dimx] and second dimension [dimy]. All the elements of this new + matrix are initially physically equal to [e]. The element ([x,y]) of a matrix [m] is + accessed with the notation [m.(x).(y)]. + + Raise [Invalid_argument] if [dimx] or [dimy] is negative or greater than + [Array.max_length]. + + If the value of [e] is a floating-point number, then the maximum size is only + [Array.max_length / 2]. *) +val make_matrix : dimx:int -> dimy:int -> 'a -> 'a t t + +(** [Array.append v1 v2] returns a fresh array containing the concatenation of the arrays + [v1] and [v2]. *) +val append : 'a t -> 'a t -> 'a t + +(** Like [Array.append], but concatenates a list of arrays. *) +val concat : 'a t list -> 'a t + +(** [Array.copy a] returns a copy of [a], that is, a fresh array + containing the same elements as [a]. *) +val copy : 'a t -> 'a t + +(** [Array.fill a ofs len x] modifies the array [a] in place, storing [x] in elements + number [ofs] to [ofs + len - 1]. + + Raise [Invalid_argument "Array.fill"] if [ofs] and [len] do not designate a valid + subarray of [a]. *) +val fill : 'a t -> pos:int -> len:int -> 'a -> unit + +(** [Array.blit v1 o1 v2 o2 len] copies [len] elements from array [v1], starting at + element number [o1], to array [v2], starting at element number [o2]. It works + correctly even if [v1] and [v2] are the same array, and the source and destination + chunks overlap. 
+ + Raise [Invalid_argument "Array.blit"] if [o1] and [len] do not designate a valid + subarray of [v1], or if [o2] and [len] do not designate a valid subarray of [v2]. + + [int_blit] and [float_blit] provide fast bound-checked blits for immediate + data types. The unsafe versions do not bound-check the arguments. *) +include Blit.S1 with type 'a t := 'a t + +(** [Array.of_list l] returns a fresh array containing the elements of [l]. *) +val of_list : 'a list -> 'a t + +(** [Array.map t ~f] applies function [f] to all the elements of [t], and builds an array + with the results returned by [f]: [[| f t.(0); f t.(1); ...; f t.(Array.length t - 1) + |]]. *) +val map : 'a t -> f:('a -> 'b) -> 'b t + +(** [folding_map] is a version of [map] that threads an accumulator through calls to + [f]. *) +val folding_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'c t +val folding_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'c t + +(** [Array.fold_map] is a combination of [Array.fold] and [Array.map] that threads an + accumulator through calls to [f]. *) +val fold_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'b * 'c t +val fold_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'b * 'c t + +(** Like {!Array.iter}, but the function is applied to the index of the element as first + argument, and the element itself as second argument. *) +val iteri : 'a t -> f:(int -> 'a -> unit) -> unit + +(** Like {!Array.map}, but the function is applied to the index of the element as first + argument, and the element itself as second argument. *) +val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +val foldi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b) -> 'b + +(** [Array.fold_right f a ~init] computes [f a.(0) (f a.(1) ( ... (f a.(n-1) init) ...))], + where [n] is the length of the array [a]. *) +val fold_right : 'a t -> f:('a -> 'b -> 'b) -> init:'b -> 'b + +(** All sort functions in this module sort in increasing order by default. 
*) + +(** [sort] uses constant heap space. [stable_sort] uses linear heap space. + + To sort only part of the array, specify [pos] to be the index to start sorting from + and [len] indicating how many elements to sort. *) +val sort : ?pos:int -> ?len:int -> 'a t -> compare:('a -> 'a -> int) -> unit +val stable_sort : 'a t -> compare:('a -> 'a -> int) -> unit + +val is_sorted : 'a t -> compare:('a -> 'a -> int) -> bool + +(** [is_sorted_strictly xs ~compare] iff [is_sorted xs ~compare] and no two + consecutive elements in [xs] are equal according to [compare]. *) +val is_sorted_strictly : 'a t -> compare:('a -> 'a -> int) -> bool + +(** Like [List.concat_map], [List.concat_mapi]. *) +val concat_map : 'a t -> f:( 'a -> 'b array) -> 'b array +val concat_mapi : 'a t -> f:(int -> 'a -> 'b array) -> 'b array + +val partition_tf : 'a t -> f:('a -> bool) -> 'a t * 'a t + +val partitioni_tf : 'a t -> f:(int -> 'a -> bool) -> 'a t * 'a t + +val cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** [transpose] in the sense of a matrix transpose. It returns [None] if the arrays are + not all the same length. *) +val transpose : 'a t t -> 'a t t option +val transpose_exn : 'a t t -> 'a t t + +(** [filter_opt array] returns a new array where [None] entries are omitted and [Some x] + entries are replaced with [x]. Note that this changes the index at which elements + will appear. *) +val filter_opt : 'a option t -> 'a t + +(** [filter_map ~f array] maps [f] over [array] and filters [None] out of the + results. *) +val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** Like [filter_map] but uses {!Array.mapi}. *) +val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** Like [for_all], but passes the index as an argument. *) +val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like [exists], but passes the index as an argument. *) +val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like [count], but passes the index as an argument. 
*) +val counti : 'a t -> f:(int -> 'a -> bool) -> int + +(** Functions with the 2 suffix raise an exception if the lengths of the two given arrays + aren't the same. *) + +val iter2_exn : 'a t -> 'b t -> f:('a -> 'b -> unit) -> unit + +val map2_exn : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + +val fold2_exn : 'a t -> 'b t -> init:'c -> f:('c -> 'a -> 'b -> 'c) -> 'c + +(** [for_all2_exn t1 t2 ~f] fails if [length t1 <> length t2]. *) +val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool + +(** [exists2_exn t1 t2 ~f] fails if [length t1 <> length t2]. *) +val exists2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool + +(** [filter t ~f] removes the elements for which [f] returns false. *) +val filter : 'a t -> f:('a -> bool) -> 'a t + +(** Like [filter] except [f] also receives the index. *) +val filteri : 'a t -> f:(int -> 'a -> bool) -> 'a t + +(** [swap arr i j] swaps the value at index [i] with that at index [j]. *) +val swap : 'a t -> int -> int -> unit + +(** [rev_inplace t] reverses [t] in place. *) +val rev_inplace : 'a t -> unit + +(** [of_list_rev l] converts from list then reverses in place. *) +val of_list_rev : 'a list -> 'a t + +(** [of_list_map l ~f] is the same as [of_list (List.map l ~f)]. *) +val of_list_map : 'a list -> f:('a -> 'b) -> 'b t + +(** [of_list_mapi l ~f] is the same as [of_list (List.mapi l ~f)]. *) +val of_list_mapi : 'a list -> f:(int -> 'a -> 'b) -> 'b t + +(** [of_list_rev_map l ~f] is the same as [of_list (List.rev_map l ~f)]. *) +val of_list_rev_map : 'a list -> f:('a -> 'b) -> 'b t + +(** [of_list_rev_mapi l ~f] is the same as [of_list (List.rev_mapi l ~f)]. *) +val of_list_rev_mapi : 'a list -> f:(int -> 'a -> 'b) -> 'b t + +(** [replace t i ~f] = [t.(i) <- f (t.(i))]. *) +val replace : 'a t -> int -> f:('a -> 'a) -> unit +[@@deprecated "[since 2018-09] use [t.(i) <- f (t.(i))] instead"] + +(** Modifies an array in place -- [ar.(i)] will be set to [f(ar.(i))]. 
*) +val replace_all : 'a t -> f:('a -> 'a) -> unit +[@@deprecated "[since 2018-03] use [map_inplace] instead"] + +(** Modifies an array in place, applying [f] to every element of the array *) +val map_inplace : 'a t -> f:('a -> 'a) -> unit + +(** [find_exn t ~f] returns the first [a] in [t] for which [f a] is true. It raises + [Caml.Not_found] or [Not_found_s] if there is no such [a]. *) +val find_exn : 'a t -> f:('a -> bool) -> 'a + +(** Returns the first evaluation of [f] that returns [Some]. Raises [Caml.Not_found] or + [Not_found_s] if [f] always returns [None]. *) +val find_map_exn : 'a t -> f:('a -> 'b option) -> 'b + +(** [findi t ~f] returns [Some (i, t.(i))] for the first index [i] of [t] for which [f i t.(i)] is true, or [None] if there is no such index. *) +val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + +(** [findi_exn t ~f] returns [(i, t.(i))] for the first index [i] of [t] for which [f i t.(i)] is true. It + raises [Caml.Not_found] or [Not_found_s] if there is no such element. *) +val findi_exn : 'a t -> f:(int -> 'a -> bool) -> int * 'a + +(** [find_mapi t ~f] is like [find_map] but passes the index as an argument. *) +val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option + +(** [find_mapi_exn] is like [find_map_exn] but passes the index as an argument. *) +val find_mapi_exn : 'a t -> f:(int -> 'a -> 'b option) -> 'b + +(** [find_consecutive_duplicate t ~equal] returns the first pair of consecutive elements + [(a1, a2)] in [t] such that [equal a1 a2]. They are returned in the same order as + they appear in [t]. *) +val find_consecutive_duplicate : 'a t -> equal:('a -> 'a -> bool) -> ('a * 'a) option + +(** [reduce [|a1; ...; an|] ~f] is [Some (f (... (f (f a1 a2) a3) ...) an)]. Returns [None] + on the empty array. *) +val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option +val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a + +(** [permute ?random_state t] randomly permutes [t] in place. + + [permute] side-effects [random_state] by repeated calls to [Random.State.int]. 
If + [random_state] is not supplied, [permute] uses [Random.State.default]. *) +val permute : ?random_state:Random.State.t -> 'a t -> unit + +(** [random_element ?random_state t] is [None] if [t] is empty, else it is [Some x] for + some [x] chosen uniformly at random from [t]. + + [random_element] side-effects [random_state] by calling [Random.State.int]. If + [random_state] is not supplied, [random_element] uses [Random.State.default]. *) +val random_element : ?random_state:Random.State.t -> 'a t -> 'a option +val random_element_exn : ?random_state:Random.State.t -> 'a t -> 'a + +(** [zip] is like [List.zip], but for arrays. *) +val zip : 'a t -> 'b t -> ('a * 'b) t option +val zip_exn : 'a t -> 'b t -> ('a * 'b) t + +(** [unzip] is like [List.unzip], but for arrays. *) +val unzip : ('a * 'b) t -> 'a t * 'b t + +(** [sorted_copy ar compare] returns a shallow copy of [ar] that is sorted. Similar to + List.sort *) +val sorted_copy : 'a t -> compare:('a -> 'a -> int) -> 'a t + +val last : 'a t -> 'a + +val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + +(** [unsafe_truncate t ~len] drops [length t - len] elements from the end of [t], changing + [t] so that [length t = len] afterwards. + + [unsafe_truncate] raises if [len <= 0 || len > length t]. + + It is not safe to do [unsafe_truncate] in the middle of a call to [map], [iter], etc., + or if you have given this array out to anything not under your control: in general, + code can rely on an array's length not changing. One must ensure code that calls + [unsafe_truncate] on an array does not interfere with other code that manipulates the + array. *) +val unsafe_truncate : _ t -> len:int -> unit + + +(** The input array is copied internally so that future modifications of it do not change + the sequence. *) +val to_sequence : 'a t -> 'a Sequence.t + +(** The input array is shared with the sequence and modifications of it will result in + modification of the sequence. 
*) +val to_sequence_mutable : 'a t -> 'a Sequence.t + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + module Sort : sig + module type Sort = sig + val sort + : 'a t + -> compare:('a -> 'a -> int) + -> left:int + -> right:int + -> unit + end + + module Insertion_sort : Sort + module Heap_sort : Sort + module Intro_sort : sig + include Sort + val five_element_sort + : 'a t -> compare:('a -> 'a -> int) -> int -> int -> int -> int -> int -> unit + end + end +end diff --git a/src/array0.ml b/src/array0.ml new file mode 100644 index 0000000..dc03510 --- /dev/null +++ b/src/array0.ml @@ -0,0 +1,62 @@ +(* [Array0] defines array functions that are primitives or can be simply defined in terms + of [Caml.Array]. [Array0] is intended to completely express the part of [Caml.Array] + that [Base] uses -- no other file in Base other than array0.ml should use [Caml.Array]. + [Array0] has few dependencies, and so is available early in Base's build order. All + Base files that need to use arrays and come before [Base.Array] in build order should + do [module Array = Array0]. This includes uses of subscript syntax ([x.(i)], [x.(i) <- + e]), which the OCaml parser desugars into calls to [Array.get] and [Array.set]. + Defining [module Array = Array0] is also necessary because it prevents ocamldep from + mistakenly causing a file to depend on [Base.Array]. *) + +open! 
Import0 + +module Sys = Sys0 + +let invalid_argf = Printf.invalid_argf + +module Array = struct + external create : int -> 'a -> 'a array = "caml_make_vect" + external get : 'a array -> int -> 'a = "%array_safe_get" + external length : 'a array -> int = "%array_length" + external set : 'a array -> int -> 'a -> unit = "%array_safe_set" + external unsafe_get : 'a array -> int -> 'a = "%array_unsafe_get" + external unsafe_set : 'a array -> int -> 'a -> unit = "%array_unsafe_set" +end + +include Array + +let max_length = Sys.max_array_length + +let create ~len x = + try create len x + with Invalid_argument _ -> + invalid_argf "Array.create ~len:%d: invalid length" len () +;; + +let append = Caml.Array.append +let blit = Caml.Array.blit +let concat = Caml.Array.concat +let copy = Caml.Array.copy +let fill = Caml.Array.fill +let init = Caml.Array.init +let make_matrix = Caml.Array.make_matrix +let of_list = Caml.Array.of_list +let sub = Caml.Array.sub +let to_list = Caml.Array.to_list + + +(* These are eta expanded in order to permute parameter order to follow Base + conventions. *) +let fold t ~init ~f = Caml.Array.fold_left t ~init ~f +let fold_right t ~f ~init = Caml.Array.fold_right t ~f ~init +let iter t ~f = Caml.Array.iter t ~f +let iteri t ~f = Caml.Array.iteri t ~f +let map t ~f = Caml.Array.map t ~f +let mapi t ~f = Caml.Array.mapi t ~f +let stable_sort t ~compare = Caml.Array.stable_sort t ~cmp:compare + +let swap t i j = + let tmp = t.(i) in + t.(i) <- t.(j); + t.(j) <- tmp; +;; diff --git a/src/array_permute.ml b/src/array_permute.ml new file mode 100644 index 0000000..c10a950 --- /dev/null +++ b/src/array_permute.ml @@ -0,0 +1,12 @@ +(** An internal-only module factored out due to a circular dependency between core_array + and core_list. Contains code for permuting an array. *) + +open! Import + +include Array0 + +(** randomly permute an array. 
*) +let permute ?(random_state = Random.State.default) t = + for i = length t downto 2 do + swap t (i - 1) (Random.State.int random_state i) + done diff --git a/src/avltree.ml b/src/avltree.ml new file mode 100644 index 0000000..d49c8cd --- /dev/null +++ b/src/avltree.ml @@ -0,0 +1,411 @@ +(* A few small things copied from other parts of Base because they depend on us, so we + can't use them. *) + +open! Import + +module Int = struct + type t = int + + let max (x : t) y = if x > y then x else y +end + +(* Its important that Empty have no args. It's tempting to make this type a record + (e.g. to hold the compare function), but a lot of memory is saved by Empty being an + immediate, since all unused buckets in the hashtbl don't use any memory (besides the + array cell) *) +type ('k, 'v) t = + | Empty + | Node of { mutable left : ('k, 'v) t + ; key : 'k + ; mutable value : 'v + ; mutable height : int + ; mutable right : ('k, 'v) t + } + | Leaf of { key : 'k + ; mutable value : 'v + } + + +let empty = Empty + +let height = function + | Empty -> 0 + | Leaf _ -> 1 + | Node { left = _; key = _; value = _; height; right = _ } -> height + +let invariant compare = + let legal_left_key key = function + | Empty -> () + | Leaf { key = left_key; value = _; } + | Node { left = _; key = left_key; value = _; height = _; right = _ } -> + assert (compare left_key key < 0) + in + let legal_right_key key = function + | Empty -> () + | Leaf { key = right_key; value = _; } + | Node { left = _; key = right_key; value = _; height = _; right = _ } -> + assert (compare right_key key > 0) + in + let rec inv = function + | Empty | Leaf _ -> () + | Node { left; key = k; value = _; height = h; right } -> + let (hl, hr) = (height left, height right) in + inv left; + inv right; + legal_left_key k left; + legal_right_key k right; + assert (h = Int.max hl hr + 1); + assert (abs (hl - hr) <= 2) + in inv + +let invariant t ~compare = invariant compare t + +(* In the following comments, + 't is 
balanced' means that 'invariant t' does not + raise an exception. This implies of course that each node's height field is + correct. + 't is balanceable' means that height of the left and right subtrees of t + differ by at most 3. *) + +(* @pre: left and right subtrees have correct heights + @post: output has the correct height *) +let update_height = function + | Node ({ left; key = _; value = _; height = old_height; right } as x) -> + let new_height = (Int.max (height left) (height right)) + 1 in + if new_height <> old_height then x.height <- new_height + | Empty | Leaf _ -> assert false + +(* @pre: left and right subtrees are balanced + @pre: tree is balanceable + @post: output is balanced (in particular, height is correct) *) +let balance tree = + match tree with + | Empty | Leaf _ -> tree + | Node ({ left; key = _; value = _; height = _; right } as root_node) -> + let hl = height left and hr = height right in + (* + 2 is critically important, lowering it to 1 will break the Leaf + assumptions in the code below, and will force us to promote leaf nodes in + the balance routine. It's also faster, since it will balance less often. + Note that the following code is delicate. The update_height calls must + occur in the correct order, since update_height assumes its children have + the correct heights. *) + if hl > hr + 2 then begin + match left with + (* It cannot be a leaf, because even if right is empty, a leaf + is only height 1 *) + | Empty | Leaf _ -> assert false + | Node ({ left = left_node_left; key = _; value = _; height = _; + right = left_node_right; } + as left_node) -> + if height left_node_left >= height left_node_right then begin + root_node.left <- left_node_right; + left_node.right <- tree; + update_height tree; + update_height left; + left + end else begin + (* if right is a leaf, then left must be empty. That means + height is 2. Even if hr is empty we still can't get here. 
*) + match left_node_right with + | Empty | Leaf _ -> assert false + | Node ({ left = lr_left; key = _; value = _; height = _; right = lr_right; } + as lr_node) -> + left_node.right <- lr_left; + root_node.left <- lr_right; + lr_node .right <- tree; + lr_node .left <- left; + update_height left; + update_height tree; + update_height left_node_right; + left_node_right + end + end else if hr > hl + 2 then begin + (* see above for an explanation of why right cannot be a leaf *) + match right with + | Empty | Leaf _ -> assert false + | Node ({ left = right_node_left; key = _; value = _; height = _; + right = right_node_right } + as right_node) -> + if height right_node_right >= height right_node_left then begin + root_node .right <- right_node_left; + right_node.left <- tree; + update_height tree; + update_height right; + right + end else begin + (* see above for an explanation of why this cannot be a leaf *) + match right_node_left with + | Empty | Leaf _ -> assert false + | Node ({ left = rl_left; key = _; value = _; height = _; right = rl_right } + as rl_node) + -> + right_node.left <- rl_right; + root_node .right <- rl_left; + rl_node .left <- tree; + rl_node .right <- right; + update_height right; + update_height tree; + update_height right_node_left; + right_node_left + end + end else begin + update_height tree; + tree + end +;; + +(* @pre: tree is balanceable + @pre: abs (height (right node) - height (balance tree)) <= 3 + @post: result is balanceable *) + +(* @pre: tree is balanceable + @pre: abs (height (right node) - height (balance tree)) <= 3 + @post: result is balanceable *) +let set_left node tree = + let tree = balance tree in + match node with + | Node ({ left; key = _; value = _; height = _; right = _ } as r) -> + if phys_equal left tree then () + else + r.left <- tree; + update_height node + | _ -> assert false + +(* @pre: tree is balanceable + @pre: abs (height (left node) - height (balance tree)) <= 3 + @post: result is balanceable *) +let set_right 
node tree = + let tree = balance tree in + match node with + | Node ({ left = _; key = _; value = _; height = _; right } as r) -> + if phys_equal right tree then () + else + r.right <- tree; + update_height node + | _ -> assert false + +(* @pre: t is balanced. + @post: result is balanced, with new node inserted + @post: !added = true iff the shape of the input tree changed. *) +let add = + let rec add t replace added compare k v = + match t with + | Empty -> + added := true; + Leaf { key = k; value = v } + | Leaf ({ key = k'; value = _ } as r) -> + let c = compare k' k in + (* This compare is reversed on purpose, we are pretending + that the leaf was just inserted instead of the other way + round, that way we only allocate one node. *) + if c = 0 then begin + added := false; + if replace then r.value <- v; + t + end else begin + added := true; + if c < 0 then + Node { left = t; key = k; value = v; height = 2; right = Empty } + else + Node { left = Empty; key = k; value = v; height = 2; right = t } + end + | Node ({left; key = k'; value = _; height = _; right } as r) -> + let c = compare k k' in + if c = 0 then begin + added := false; + if replace then r.value <- v; + end else if c < 0 then + set_left t (add left replace added compare k v) + else + set_right t (add right replace added compare k v); + t + in + fun t ~replace ~compare ~added ~key ~data -> + let t = add t replace added compare key data in + if !added then balance t else t +;; + +let rec first t = + match t with + | Empty -> None + | Leaf { key = k; value = v } + | Node { left = Empty; key = k; value = v; height = _; right = _ } -> Some (k, v) + | Node { left = l; key = _; value = _; height = _; right = _ } -> first l +;; + +let rec last t = + match t with + | Empty -> None + | Leaf { key = k; value = v } + | Node { left = _; key = k; value = v; height = _; right = Empty } -> Some (k, v) + | Node { left = _; key = _; value = _; height = _; right = r } -> last r +;; + + +let[@inline always] rec 
findi_and_call_impl t ~compare k ~call_if_found ~if_found ~if_not_found = + (* A little manual unrolling of the recursion. + This is really worth 5% on average *) + match t with + | Empty -> if_not_found k + | Leaf { key = k'; value = v } -> + if compare k k' = 0 then call_if_found ~if_found ~key:k' ~data:v + else if_not_found k + | Node { left; key = k'; value = v; height = _; right } -> + let c = compare k k' in + if c = 0 then call_if_found ~if_found ~key:k' ~data:v + else if c < 0 then begin + match left with + | Empty -> if_not_found k + | Leaf { key = k'; value = v }-> + if compare k k' = 0 then call_if_found ~if_found ~key:k' ~data:v + else if_not_found k + | Node { left; key = k'; value = v; height = _; right } -> + let c = compare k k' in + if c = 0 then call_if_found ~if_found ~key:k' ~data:v + else + findi_and_call_impl (if c < 0 then left else right) ~compare k ~call_if_found ~if_found ~if_not_found + end else begin + match right with + | Empty -> if_not_found k + | Leaf { key = k'; value = v } -> + if compare k k' = 0 then call_if_found ~if_found ~key:k' ~data:v + else if_not_found k + | Node { left; key = k'; value = v; height = _; right } -> + let c = compare k k' in + if c = 0 then call_if_found ~if_found ~key:k' ~data:v + else + findi_and_call_impl (if c < 0 then left else right) ~compare k ~call_if_found ~if_found ~if_not_found + end +;; + +let find_and_call = + let call_if_found ~if_found ~key:_ ~data = if_found data in + fun t ~compare k ~if_found ~if_not_found -> + findi_and_call_impl t ~compare k ~call_if_found ~if_found ~if_not_found + +let findi_and_call = + let call_if_found ~if_found ~key ~data = if_found ~key ~data in + fun t ~compare k ~if_found ~if_not_found -> + findi_and_call_impl t ~compare k ~call_if_found ~if_found ~if_not_found + +let find = + let if_found v = Some v in + let if_not_found _ = None in + fun t ~compare k -> + find_and_call t ~compare k ~if_found ~if_not_found + +let mem = + let if_found _ = true in + let 
if_not_found _ = false in + fun t ~compare k -> + find_and_call t ~compare k ~if_found ~if_not_found + +let remove = + let rec min_elt tree = + match tree with + | Empty -> Empty + | Leaf _ -> tree + | Node { left = Empty; key = _; value = _; height = _; right = _ } -> tree + | Node { left; key = _; value = _; height = _; right = _ } -> min_elt left + in + let rec remove_min_elt tree = + match tree with + | Empty -> assert false + | Leaf _ -> Empty (* This must be the root *) + | Node { left = Empty; key = _; value = _; height = _; right } -> right + | Node { left = Leaf _; key = k; value = v; height = _; right = Empty } -> + Leaf { key = k; value = v } + | Node { left = Leaf _; key = _; value = _; height = _; right = _ } as node -> + set_left node Empty; tree + | Node { left; key = _; value = _; height = _; right = _ } as node -> + set_left node (remove_min_elt left); tree + in + let merge t1 t2 = + match (t1, t2) with + | (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> + let tree = min_elt t2 in + match tree with + | Empty -> assert false + | Leaf { key = k; value = v } -> + let t2 = balance (remove_min_elt t2) in + Node { left = t1; key = k; value = v; + height = Int.max (height t1) (height t2) + 1; right = t2 + } + | Node _ as node -> + set_right node (remove_min_elt t2); + set_left node t1; + node + in + let rec remove t removed compare k = + match t with + | Empty -> + removed := false; + Empty + | Leaf { key = k'; value = _ } -> + if compare k k' = 0 then begin + removed := true; + Empty + end else begin + removed := false; + t + end + | Node { left; key = k'; value = _; height = _; right } -> + let c = compare k k' in + if c = 0 then begin + removed := true; + merge left right + end else if c < 0 then begin + set_left t (remove left removed compare k); + t + end else begin + set_right t (remove right removed compare k); + t + end + in + fun t ~removed ~compare k -> balance (remove t removed compare k) +;; + +let rec fold t ~init ~f = + match t with + | 
Empty -> init + | Leaf { key; value = data } -> f ~key ~data init + | Node { left = Leaf { key = lkey; value = ldata }; + key; value = data; height = _; + right = Leaf { key = rkey; value = rdata } + } -> + f ~key:rkey ~data:rdata (f ~key ~data (f ~key:lkey ~data:ldata init)) + | Node { left = Leaf { key = lkey; value = ldata}; key; value = data; height = _; + right = Empty } + -> f ~key ~data (f ~key:lkey ~data:ldata init) + | Node { left = Empty; key; value = data; height = _; + right = Leaf { key = rkey; value = rdata } + } -> + f ~key:rkey ~data:rdata (f ~key ~data init) + | Node { left; key; value = data; height = _; right = Leaf { key = rkey; value = rdata } + } -> + f ~key:rkey ~data:rdata (f ~key ~data (fold left ~init ~f)) + | Node { left = Leaf { key = lkey; value = ldata}; key; value = data; height = _; right } -> + fold right ~init:(f ~key ~data (f ~key:lkey ~data:ldata init)) ~f + | Node { left; key; value = data; height = _; right } -> + fold right ~init:(f ~key ~data (fold left ~init ~f)) ~f + +let rec iter t ~f = + match t with + | Empty -> () + | Leaf { key; value = data } -> f ~key ~data + | Node { left; key; value = data; height = _; right } -> + iter left ~f; + f ~key ~data; + iter right ~f + +let rec mapi_inplace t ~f = + match t with + | Empty -> () + | Leaf ({ key ; value } as t) -> + t.value <- f ~key ~data:value + | Node ({ left ; key ; value ; height = _ ; right } as t) -> + mapi_inplace ~f left; + t.value <- f ~key ~data:value; + mapi_inplace ~f right diff --git a/src/avltree.mli b/src/avltree.mli new file mode 100644 index 0000000..382467b --- /dev/null +++ b/src/avltree.mli @@ -0,0 +1,129 @@ +(** A low-level, mutable AVL tree. + + It is not intended to be used directly by casual users. It is used for implementing + other data structures. The interface is somewhat ugly, and it's that way for a + reason: the goal of this module is minimum memory overhead and maximum performance. + + {2 Caveats} + + 1. 
[compare] is passed to every function where it is used. If you pass a different + [compare] to functions on the same tree, then behavior is indeterminate. Why? Because + otherwise we'd need a top-level record to store [compare], and when building a hash + table, or other structure, that little [t] is a block that increases memory + overhead. However, if an empty tree is just a constructor [Empty], then it's just a + number, and uses no extra memory beyond the array bucket that holds it. That's the + first secret of how Hashtbl's memory overhead isn't higher than INRIA's, even though + it uses a tree instead of a list for buckets. + + 2. But if it's mutable, why do all the "mutators" return [t]? Answer: it is mutable, + but the root node might change due to balancing. Since we have no top-level record to + hold the current root node (see point 1), you have to do it. If you fail to do it, and + use an old root node, you're responsible for the (sure to be nasty) consequences. + + 3. What on earth is up with the [~removed] argument to some functions? See point 1: + since there is no top-level node, it isn't possible to keep track of how many nodes + are in the tree unless each mutator tells you whether or not it added or removed a + node (vs. replacing an existing one). If you intend to keep a count (as you must in a + hash table), then you will need to pay attention to this flag. + + After all this, you're probably asking yourself whether all these hacks are worth + it. Yes! They are! With them, we built a hash table that is faster than INRIA's (no + small feat) with the same memory overhead, sane add semantics (the add semantics they + used were a performance hack), and worst-case log(N) insertion, lookup, and + removal. *) + +open! Import + +(** We expose [t] to allow an optimization in Hashtbl that makes iter and fold more than + twice as fast. We keep the type private to reduce opportunities for external code to + violate avltree invariants. 
*) +type ('k, 'v) t = private + | Empty + | Node of { mutable left : ('k, 'v) t + ; key : 'k + ; mutable value : 'v + ; mutable height : int + ; mutable right : ('k, 'v) t + } + | Leaf of { key : 'k + ; mutable value : 'v + } + +val empty : ('k, 'v) t + +(** Checks invariants, raising an exception if any invariants fail. *) +val invariant : ('k, 'v) t -> compare:('k -> 'k -> int) -> unit + +(** Adds the specified key and data to the tree destructively (previous [t]'s are no + longer valid) using the specified comparison function. O(log(N)) time, O(1) space. + + The returned [t] is the new root node of the tree, and should be used on all further + calls to any other function in this module. The [added] ref will be set to + [true] if a new node is added to the tree, or [false] if the key already exists + (whether or not its value is replaced). + + If [replace] is true then [add] will overwrite any existing mapping for + [key]. If [replace] is false, and there is an existing mapping for key, then [add] has + no effect. *) +val add + : ('k, 'v) t + -> replace:bool + -> compare:('k -> 'k -> int) + -> added:bool ref + -> key:'k + -> data:'v + -> ('k, 'v) t + +(** Returns the first (leftmost) or last (rightmost) element in the tree. *) + +val first : ('k, 'v) t -> ('k * 'v) option +val last : ('k, 'v) t -> ('k * 'v) option + +(** If the specified key exists in the tree, returns the corresponding value. O(log(N)) + time and O(1) space. *) +val find : ('k, 'v) t -> compare:('k -> 'k -> int) -> 'k -> 'v option + +(** [find_and_call t ~compare k ~if_found ~if_not_found] + + is equivalent to: + + [match find t ~compare k with Some v -> if_found v | None -> if_not_found k] + + except that it doesn't allocate the option. 
*) +val find_and_call + : ('k, 'v) t + -> compare:('k -> 'k -> int) + -> 'k + -> if_found:('v -> 'a) + -> if_not_found:('k -> 'a) + -> 'a + +val findi_and_call + : ('k, 'v) t + -> compare:('k -> 'k -> int) + -> 'k + -> if_found:(key:'k -> data:'v -> 'a) + -> if_not_found:('k -> 'a) + -> 'a + +(** Returns true if key is present in the tree, and false otherwise. *) +val mem : ('k, 'v) t -> compare:('k -> 'k -> int) -> 'k -> bool + +(** Removes key destructively from the tree if it exists, returning the new root node. + Previous root nodes are not usable anymore; use them at your peril. The [removed] ref + will be set to true if a node was actually removed, and false otherwise. *) +val remove + : ('k, 'v) t + -> removed:bool ref + -> compare:('k -> 'k -> int) + -> 'k + -> ('k, 'v) t + +(** Folds over the tree. *) +val fold : ('k, 'v) t -> init:'a -> f:(key:'k -> data:'v -> 'a -> 'a) -> 'a + +(** Iterates over the tree. *) +val iter : ('k, 'v) t -> f:(key:'k -> data:'v -> unit) -> unit + +(** Maps over the tree, changing the data in place. *) +val mapi_inplace : ('k, 'v) t -> f:(key:'k -> data:'v -> 'v) -> unit diff --git a/src/backtrace.ml b/src/backtrace.ml new file mode 100644 index 0000000..5a268a5 --- /dev/null +++ b/src/backtrace.ml @@ -0,0 +1,51 @@ +open! 
Import + +module Sys = Sys0 + +type t = Caml.Printexc.raw_backtrace + +let elide = ref am_testing +let elided_message = "" + +let get ?(at_most_num_frames = Int.max_value) () = + Caml.Printexc.get_callstack at_most_num_frames +;; + +let to_string t = + if !elide + then elided_message + else Caml.Printexc.raw_backtrace_to_string t +;; + +let to_string_list t = String.split_lines (to_string t) + +let sexp_of_t t = + Sexp.List (List.map (to_string_list t) ~f:(fun x -> Sexp.Atom x)) +;; + +module Exn = struct + + let set_recording = Caml.Printexc.record_backtrace + let am_recording = Caml.Printexc.backtrace_status + + let most_recent () = + Caml.Printexc.get_raw_backtrace () + ;; + + (* We turn on backtraces by default if OCAMLRUNPARAM isn't set. *) + let maybe_set_recording () = + match Sys.getenv "OCAMLRUNPARAM" with + | exception _ -> set_recording true + | (_ : string) -> () (* the caller set something, they are responsible *) + ;; + + let with_recording b ~f = + let saved = am_recording () in + set_recording b; + Exn.protect ~f ~finally:(fun () -> set_recording saved) + ;; +end + +let initialize_module () = + Exn.maybe_set_recording (); +;; diff --git a/src/backtrace.mli b/src/backtrace.mli new file mode 100644 index 0000000..e929816 --- /dev/null +++ b/src/backtrace.mli @@ -0,0 +1,79 @@ +(** Module for managing stack backtraces. + + The [Backtrace] module deals with two different kinds of backtraces: + + + Snapshots of the stack obtained on demand ([Backtrace.get]) + + The stack frames unwound when an exception is raised ([Backtrace.Exn]) +*) + +open! Import + +(** A [Backtrace.t] is a snapshot of the stack obtained by calling [Backtrace.get]. It is + represented as a string with newlines separating the frames. [sexp_of_t] splits the + string at newlines and removes some of the cruft, leaving a human-friendly list of + frames, but [to_string] does not. 
*) +type t [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +val get : ?at_most_num_frames:int -> unit -> t + +val to_string : t -> string + +val to_string_list : t -> string list + +(** The value of [elide] controls the behavior of backtrace serialization functions such + as {!to_string}, {!to_string_list}, and {!sexp_of_t}. When set to [false], these + functions behave as expected, returning a faithful representation of their argument. + When set to [true], these functions will ignore their argument and return a message + indicating that behavior. + + The default value is {!am_testing}. *) +val elide : bool ref + +(** [Backtrace.Exn] has functions for controlling and printing the backtrace of the most + recently raised exception. + + When an exception is raised, the runtime "unwinds" the stack, i.e., removes stack + frames, until it reaches a frame with an exception handler. It then matches the + exception against the patterns in the handler. If the exception matches, then the + program continues. If not, then the runtime continues unwinding the stack to the next + handler. + + If [am_recording () = true], then while the runtime is unwinding the stack, it keeps + track of the part of the stack that is unwound. This is available as a backtrace via + [most_recent ()]. Calling [most_recent] if [am_recording () = false] will yield the + empty backtrace. + + With [am_recording () = true], OCaml keeps only a backtrace for the most recently + raised exception. When one raises an exception, OCaml checks if it is physically equal + to the most recently raised exception. If it is, then OCaml appends the string + representation of the stack unwound by the current raise to the stored backtrace. If + the exception being raised is not physically equal to the most recently raised + exception, then OCaml starts recording a new backtrace. 
Thus one must call + [most_recent] before a subsequent [raise] of a (physically) distinct exception, or the + backtrace is lost. + + The initial value of [am_recording ()] is determined by the setting of the environment + variable OCAMLRUNPARAM. If OCAMLRUNPARAM is set, then [am_recording () = true] iff the + character "b" occurs in OCAMLRUNPARAM. If OCAMLRUNPARAM is not set (as is always the + case when running in a web browser), then [am_recording ()] is initially true. + + This is the same functionality as provided by the OCaml stdlib [Printexc] functions + [backtrace_status], [record_backtrace], [get_backtrace]. *) +module Exn : sig + val am_recording : unit -> bool + val set_recording : bool -> unit + + val with_recording : bool -> f:(unit -> 'a) -> 'a + + (** [most_recent ()] returns a backtrace containing the stack that was unwound by the + most recently raised exception. *) + val most_recent : unit -> t +end + +(** User code never calls this. It is called only in [std_kernel.ml], as a top-level side + effect, to initialize [am_recording ()] as specified above. *) +val initialize_module : unit -> unit diff --git a/src/base.ml b/src/base.ml new file mode 100644 index 0000000..1ef404b --- /dev/null +++ b/src/base.ml @@ -0,0 +1,471 @@ +(** This module is the toplevel of the Base library; it's what you get when you write + [open Base]. + + The recommended way to use Base is to build with [-open Base]. Files compiled this + way will have the environment described in this file as their initial environment. + + Base extends some modules and data structures from the standard library, like [Array], + [Buffer], [Bytes], [Char], [Hashtbl], [Int32], [Int64], [Lazy], [List], [Map], + [Nativeint], [Printf], [Random], [Set], [String], [Sys], and [Uchar]. 
One key + difference is that Base doesn't use exceptions as much as the standard library and + instead makes heavy use of the [Result] type, as in: + + {[ type ('a,'b) result = Ok of 'a | Error of 'b ]} + + Base also adds entirely new modules, most notably: + + - [Comparable], [Comparator], and [Comparisons] in lieu of polymorphic compare. + - [Container], which provides a consistent interface across container-like data + structures (arrays, lists, strings). + - [Result], [Error], and [Or_error], supporting the or-error pattern. + + Broadly the goal of Base is both to be a more complete standard library, with richer + APIs, and to be more consistent in its design. For instance, in the standard library + some things have modules and others don't; in Base, everything is a module. +*) + +(*_ We hide this from the web docs because the line wrapping is bad, making it + pretty much inscrutable. *) +(**/**) + +(** The intent is to shadow all of INRIA's standard library. Modules below would cause + compilation errors without being removed from [Shadow_stdlib] before inclusion. 
*) +include (Shadow_stdlib + : module type of struct include Shadow_stdlib end + (* Modules defined in Base *) + with module Array := Shadow_stdlib.Array + with module Bool := Shadow_stdlib.Bool + with module Buffer := Shadow_stdlib.Buffer + with module Bytes := Shadow_stdlib.Bytes + with module Char := Shadow_stdlib.Char + with module Float := Shadow_stdlib.Float + with module Hashtbl := Shadow_stdlib.Hashtbl + with module Int := Shadow_stdlib.Int + with module Int32 := Shadow_stdlib.Int32 + with module Int64 := Shadow_stdlib.Int64 + with module Lazy := Shadow_stdlib.Lazy + with module List := Shadow_stdlib.List + with module Map := Shadow_stdlib.Map + with module Nativeint := Shadow_stdlib.Nativeint + with module Option := Shadow_stdlib.Option + with module Printf := Shadow_stdlib.Printf + with module Queue := Shadow_stdlib.Queue + with module Random := Shadow_stdlib.Random + with module Result := Shadow_stdlib.Result + with module Set := Shadow_stdlib.Set + with module Stack := Shadow_stdlib.Stack + with module String := Shadow_stdlib.String + with module Sys := Shadow_stdlib.Sys + with module Uchar := Shadow_stdlib.Uchar + with module Unit := Shadow_stdlib.Unit + + (* Support for generated lexers *) + with module Lexing := Shadow_stdlib.Lexing + + with type ('a, 'b, 'c) format := ('a, 'b, 'c) format + with type ('a, 'b, 'c, 'd) format4 := ('a, 'b, 'c, 'd) format4 + with type ('a, 'b, 'c, 'd, 'e, 'f) format6 := ('a, 'b, 'c, 'd, 'e, 'f) format6 + + with type 'a ref := 'a ref + ) [@ocaml.warning "-3"] + +(**/**) + +open! 
Import + +module Applicative = Applicative +module Array = Array +module Avltree = Avltree +module Backtrace = Backtrace +module Binary_search = Binary_search +module Binary_searchable = Binary_searchable +module Blit = Blit +module Bool = Bool +module Buffer = Buffer +module Bytes = Bytes +module Char = Char +module Comparable = Comparable +module Comparator = Comparator +module Comparisons = Comparisons +module Container = Container +module Either = Either +module Equal = Equal +module Error = Error +module Exn = Exn +module Field = Field +module Float = Float +module Floatable = Floatable +module Fn = Fn +module Formatter = Formatter +module Hash = Hash +module Hash_set = Hash_set +module Hashable = Hashable +module Hasher = Hasher +module Hashtbl = Hashtbl +module Identifiable = Identifiable +module Indexed_container = Indexed_container +module Info = Info +module Int = Int +module Int32 = Int32 +module Int63 = Int63 +module Int64 = Int64 +module Intable = Intable +module Invariant = Invariant +module Lazy = Lazy +module List = List +module Map = Map +module Maybe_bound = Maybe_bound +module Monad = Monad +module Nativeint = Nativeint +module Option = Option +module Option_array = Option_array +module Or_error = Or_error +module Ordered_collection_common = Ordered_collection_common +module Ordering = Ordering +module Poly = Poly +module Polymorphic_compare = Poly +[@@deprecated "[since 2018-11] use [Poly] instead"] +module Popcount = Popcount +[@@deprecated "[since 2018-10] use [popcount] functions in the individual int modules"] +module Pretty_printer = Pretty_printer +module Printf = Printf +module Linked_queue = Linked_queue +module Queue = Queue +module Random = Random +module Ref = Ref +module Result = Result +module Sequence = Sequence +module Set = Set +module Sexpable = Sexpable +module Sign = Sign +module Sign_or_nan = Sign_or_nan +module Source_code_position = Source_code_position +module Stack = Stack +module Staged = Staged +module String = String 
+module Stringable = Stringable +module Sys = Sys +module T = T +module Type_equal = Type_equal +module Uniform_array = Uniform_array +module Unit = Unit +module Uchar = Uchar +module Validate = Validate +module Variant = Variant +module With_return = With_return +module Word_size = Word_size + +(* Avoid a level of indirection for uses of the signatures defined in [T]. *) +include T + +(* This is a hack so that odoc creates better documentation. *) +module Sexp = struct + include Sexp_with_comparable (** @inline *) +end + +(**/**) +module Exported_for_specific_uses = struct + module Fieldslib = Fieldslib + module Ppx_hash_lib = Ppx_hash_lib + module Sexplib = Sexplib + module Variantslib = Variantslib + module Ppx_compare_lib = Ppx_compare_lib + module Ppx_sexp_conv_lib = Ppx_sexp_conv_lib + let am_testing = am_testing +end +(**/**) + +module Export = struct + (* [deriving hash] is missing for [array] and [ref] since these types are mutable. + (string is also mutable, but we pretend it isn't for hashing purposes) *) + type 'a array = 'a Array. t [@@deriving_inline compare, equal, sexp] + let compare_array : 'a . ('a -> 'a -> int) -> 'a array -> 'a array -> int = + Array.compare + let equal_array : 'a . ('a -> 'a -> bool) -> 'a array -> 'a array -> bool = + Array.equal + let array_of_sexp : + 'a . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a array + = Array.t_of_sexp + let sexp_of_array : + 'a . + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a array -> Ppx_sexp_conv_lib.Sexp.t + = Array.sexp_of_t + [@@@end] + type bool = Bool. 
t [@@deriving_inline compare, equal, hash, sexp] + let compare_bool : bool -> bool -> int = Bool.compare + let equal_bool : bool -> bool -> bool = Bool.equal + let (hash_fold_bool : + Ppx_hash_lib.Std.Hash.state -> bool -> Ppx_hash_lib.Std.Hash.state) = + Bool.hash_fold_t + and (hash_bool : bool -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Bool.hash in fun x -> func x + let bool_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> bool = Bool.t_of_sexp + let sexp_of_bool : bool -> Ppx_sexp_conv_lib.Sexp.t = Bool.sexp_of_t + [@@@end] + type char = Char. t [@@deriving_inline compare, equal, hash, sexp] + let compare_char : char -> char -> int = Char.compare + let equal_char : char -> char -> bool = Char.equal + let (hash_fold_char : + Ppx_hash_lib.Std.Hash.state -> char -> Ppx_hash_lib.Std.Hash.state) = + Char.hash_fold_t + and (hash_char : char -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Char.hash in fun x -> func x + let char_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> char = Char.t_of_sexp + let sexp_of_char : char -> Ppx_sexp_conv_lib.Sexp.t = Char.sexp_of_t + [@@@end] + type exn = Exn. t [@@deriving_inline sexp_of] + let sexp_of_exn : exn -> Ppx_sexp_conv_lib.Sexp.t = Exn.sexp_of_t + [@@@end] + type float = Float. t [@@deriving_inline compare, equal, hash, sexp] + let compare_float : float -> float -> int = Float.compare + let equal_float : float -> float -> bool = Float.equal + let (hash_fold_float : + Ppx_hash_lib.Std.Hash.state -> float -> Ppx_hash_lib.Std.Hash.state) = + Float.hash_fold_t + and (hash_float : float -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Float.hash in fun x -> func x + let float_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> float = Float.t_of_sexp + let sexp_of_float : float -> Ppx_sexp_conv_lib.Sexp.t = Float.sexp_of_t + [@@@end] + type int = Int. 
t [@@deriving_inline compare, equal, hash, sexp] + let compare_int : int -> int -> int = Int.compare + let equal_int : int -> int -> bool = Int.equal + let (hash_fold_int : + Ppx_hash_lib.Std.Hash.state -> int -> Ppx_hash_lib.Std.Hash.state) = + Int.hash_fold_t + and (hash_int : int -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Int.hash in fun x -> func x + let int_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> int = Int.t_of_sexp + let sexp_of_int : int -> Ppx_sexp_conv_lib.Sexp.t = Int.sexp_of_t + [@@@end] + type int32 = Int32. t [@@deriving_inline compare, equal, hash, sexp] + let compare_int32 : int32 -> int32 -> int = Int32.compare + let equal_int32 : int32 -> int32 -> bool = Int32.equal + let (hash_fold_int32 : + Ppx_hash_lib.Std.Hash.state -> int32 -> Ppx_hash_lib.Std.Hash.state) = + Int32.hash_fold_t + and (hash_int32 : int32 -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Int32.hash in fun x -> func x + let int32_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> int32 = Int32.t_of_sexp + let sexp_of_int32 : int32 -> Ppx_sexp_conv_lib.Sexp.t = Int32.sexp_of_t + [@@@end] + type int64 = Int64. t [@@deriving_inline compare, equal, hash, sexp] + let compare_int64 : int64 -> int64 -> int = Int64.compare + let equal_int64 : int64 -> int64 -> bool = Int64.equal + let (hash_fold_int64 : + Ppx_hash_lib.Std.Hash.state -> int64 -> Ppx_hash_lib.Std.Hash.state) = + Int64.hash_fold_t + and (hash_int64 : int64 -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Int64.hash in fun x -> func x + let int64_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> int64 = Int64.t_of_sexp + let sexp_of_int64 : int64 -> Ppx_sexp_conv_lib.Sexp.t = Int64.sexp_of_t + [@@@end] + type 'a list = 'a List. t [@@deriving_inline compare, equal, hash, sexp] + let compare_list : 'a . ('a -> 'a -> int) -> 'a list -> 'a list -> int = + List.compare + let equal_list : 'a . ('a -> 'a -> bool) -> 'a list -> 'a list -> bool = + List.equal + let hash_fold_list : + 'a . 
+ (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a list -> Ppx_hash_lib.Std.Hash.state + = List.hash_fold_t + let list_of_sexp : + 'a . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a list + = List.t_of_sexp + let sexp_of_list : + 'a . + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a list -> Ppx_sexp_conv_lib.Sexp.t + = List.sexp_of_t + [@@@end] + type nativeint = Nativeint. t [@@deriving_inline compare, equal, hash, sexp] + let compare_nativeint : nativeint -> nativeint -> int = Nativeint.compare + let equal_nativeint : nativeint -> nativeint -> bool = Nativeint.equal + let (hash_fold_nativeint : + Ppx_hash_lib.Std.Hash.state -> nativeint -> Ppx_hash_lib.Std.Hash.state) = + Nativeint.hash_fold_t + and (hash_nativeint : nativeint -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Nativeint.hash in fun x -> func x + let nativeint_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> nativeint = + Nativeint.t_of_sexp + let sexp_of_nativeint : nativeint -> Ppx_sexp_conv_lib.Sexp.t = + Nativeint.sexp_of_t + [@@@end] + type 'a option = 'a Option. t [@@deriving_inline compare, equal, hash, sexp] + let compare_option : 'a . ('a -> 'a -> int) -> 'a option -> 'a option -> int + = Option.compare + let equal_option : 'a . ('a -> 'a -> bool) -> 'a option -> 'a option -> bool + = Option.equal + let hash_fold_option : + 'a . + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a option -> Ppx_hash_lib.Std.Hash.state + = Option.hash_fold_t + let option_of_sexp : + 'a . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a option + = Option.t_of_sexp + let sexp_of_option : + 'a . + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a option -> Ppx_sexp_conv_lib.Sexp.t + = Option.sexp_of_t + [@@@end] + type 'a ref = 'a Ref. t [@@deriving_inline compare, equal, sexp] + let compare_ref : 'a . ('a -> 'a -> int) -> 'a ref -> 'a ref -> int = + Ref.compare + let equal_ref : 'a . 
('a -> 'a -> bool) -> 'a ref -> 'a ref -> bool = + Ref.equal + let ref_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a ref + = Ref.t_of_sexp + let sexp_of_ref : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a ref -> Ppx_sexp_conv_lib.Sexp.t + = Ref.sexp_of_t + [@@@end] + type string = String. t [@@deriving_inline compare, equal, hash, sexp] + let compare_string : string -> string -> int = String.compare + let equal_string : string -> string -> bool = String.equal + let (hash_fold_string : + Ppx_hash_lib.Std.Hash.state -> string -> Ppx_hash_lib.Std.Hash.state) = + String.hash_fold_t + and (hash_string : string -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = String.hash in fun x -> func x + let string_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> string = String.t_of_sexp + let sexp_of_string : string -> Ppx_sexp_conv_lib.Sexp.t = String.sexp_of_t + [@@@end] + type bytes = Bytes. t [@@deriving_inline compare, equal, sexp] + let compare_bytes : bytes -> bytes -> int = Bytes.compare + let equal_bytes : bytes -> bytes -> bool = Bytes.equal + let bytes_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> bytes = Bytes.t_of_sexp + let sexp_of_bytes : bytes -> Ppx_sexp_conv_lib.Sexp.t = Bytes.sexp_of_t + [@@@end] + type unit = Unit. 
t [@@deriving_inline compare, equal, hash, sexp] + let compare_unit : unit -> unit -> int = Unit.compare + let equal_unit : unit -> unit -> bool = Unit.equal + let (hash_fold_unit : + Ppx_hash_lib.Std.Hash.state -> unit -> Ppx_hash_lib.Std.Hash.state) = + Unit.hash_fold_t + and (hash_unit : unit -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Unit.hash in fun x -> func x + let unit_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> unit = Unit.t_of_sexp + let sexp_of_unit : unit -> Ppx_sexp_conv_lib.Sexp.t = Unit.sexp_of_t + [@@@end] + + (** Format stuff *) + + type nonrec ('a, 'b, 'c) format = ('a, 'b, 'c) format + type nonrec ('a, 'b, 'c, 'd) format4 = ('a, 'b, 'c, 'd) format4 + type nonrec ('a, 'b, 'c, 'd, 'e, 'f) format6 = ('a, 'b, 'c, 'd, 'e, 'f) format6 + + (** {2 Sexp} + + Exporting the ad-hoc types that are recognized by [ppx_sexp_*] converters. + [sexp_array], [sexp_list], and [sexp_option] allow a record field to be absent when + converting from a sexp, and if absent, the field will take a default value of the + appropriate type: + + {v + sexp_array [||] + sexp_bool false + sexp_list [] + sexp_option None + v} + + [sexp_opaque] causes the conversion to sexp to produce the atom [<opaque>]. + + For more documentation, see sexplib/README.md. *) + + type 'a sexp_array = 'a array + type 'a sexp_list = 'a list + type 'a sexp_opaque = 'a + type 'a sexp_option = 'a option + + (** List operators *) + + include List.Infix + + (** Int operators and comparisons *) + + include Int.O + include Int_replace_polymorphic_compare + + (** Float operators *) + + include Float.O_dot + + (** Reverse application operator. [x |> g |> f] is equivalent to [f (g (x))]. *) + (* This is declared as an external to be optimized away in more contexts. *) + external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + + (** Application operator. [g @@ f @@ x] is equivalent to [g (f (x))]. 
*) + external ( @@ ) : ('a -> 'b) -> 'a -> 'b = "%apply" + + (** Boolean operations *) + + (* These need to be declared as an external to get the lazy behavior *) + external ( && ) : bool -> bool -> bool = "%sequand" + external ( || ) : bool -> bool -> bool = "%sequor" + external not : bool -> bool = "%boolnot" + + (* This must be declared as an external for the warnings to work properly. *) + external ignore : _ -> unit = "%ignore" + + (** Common string operations *) + let ( ^ ) = String.( ^ ) + + (** Reference operations *) + + (* Declared as an externals so that the compiler skips the caml_modify when possible and + to keep reference unboxing working *) + external ( ! ) : 'a ref -> 'a = "%field0" + external ref : 'a -> 'a ref = "%makemutable" + external ( := ) : 'a ref -> 'a -> unit = "%setfield0" + + (** Pair operations *) + + let fst = fst + let snd = snd + + (** Exceptions stuff *) + + (* Declared as an external so that the compiler may rewrite '%raise' as '%reraise'. *) + external raise : exn -> _ = "%raise" + + let failwith = failwith + let invalid_arg = invalid_arg + let raise_s = Error.raise_s + + (** Misc *) + + let phys_equal = phys_equal + + external force : 'a Lazy.t -> 'a = "%lazy_force" +end + +include Export +include Container_intf.Export (** @inline *) + +exception Not_found_s = Not_found_s + +(* Various things to cleanup that were used without going through Base. *) +module Not_exposed_properly = struct + module Int63_emul = Int63_emul + module Float0 = Float0 + module Import = Import + module Int_conversions = Int_conversions + module Int_math = Int_math + module Pow_overflow_bounds = Pow_overflow_bounds + module Sexp_conv = Sexplib0.Sexp_conv + module Obj_array = Obj_array +end + +(* We perform these side effects here because we want them to run for any code that uses + [Base]. If this were in another module in [Base] that was not used in some program, + then the side effects might not be run in that program. 
This will run as long as the + program refers to at least one value directly in [Base]; referring to values in + [Base.Bool], for example, is not sufficient. *) +let () = + Backtrace.initialize_module (); +;; diff --git a/src/base.mld b/src/base.mld new file mode 100644 index 0000000..4a09082 --- /dev/null +++ b/src/base.mld @@ -0,0 +1,173 @@ +{1 Base} + +{b {{!Base} The full API is browsable here}}. + +Base is a standard library for OCaml. It provides a standard set of +general-purpose modules that are well tested, performant, and fully +portable across any environment that can run OCaml code. + +Unlike other standard library projects, Base is meant to be used as a +wholesale replacement of the standard library distributed with the +OCaml compiler. In particular, it makes different choices and doesn't +re-export features that are not fully portable such as I/O, which are +left to other libraries. + +Note that an API for OCaml's channel-based I/O can be found in the +{{!module:Stdio}[Stdio]} library. + +{2 Relationship to Core_kernel and Core} + +Base is the smallest, most self-contained version of Jane Street's +family of three standard library replacements. It is extended by +[Core_kernel], which is in turn extended by [Core]. + +In sum: + +- {b {!Base}}: Minimal stdlib replacement. Portable and lightweight and + intended to be highly stable. +- {b {!Core_kernel}}: Extension of Base. More fully featured, with more + code and dependencies, and APIs that evolve more quickly. Portable, + and works on Javascript. +- {b {!Core}}: Core_kernel extended with UNIX APIs. + + +{2 Using the OCaml standard library with Base} + +Base is intended as a full stdlib replacement. As a result, after an +[open Base], all the modules, values, types, etc., coming from the OCaml +standard library that one normally gets in the default environment are +deprecated. 
+ +In order to access these values, one must use the [Caml] library, +which re-exports them all through the toplevel name +{{!module:Caml}[Caml]}: [Caml.String], [Caml.print_string], ... + +The new modules and values made available by Base are documented +{{!Base} here}. + +{2 Differences between Base and the OCaml standard library} + +Programmers who are used to the OCaml standard library should read +through this section to understand major differences between the two +libraries that one should be aware of when switching to Base. + +{3 Comparison operators} + +The comparison operators exposed by the OCaml standard library are +polymorphic: + +{[ +val compare : 'a -> 'a -> int +val ( <= ) : 'a -> 'a -> bool +(* ... *) +]} + +What they implement is structural comparison of the runtime +representation of values. Since these are often error-prone, +i.e., they don't correspond to what the user expects, they are not +exposed directly by Base. + +To use polymorphic comparison with Base, one should use the +{{!Base.Poly}[Poly]} module. The default +comparison operators exposed by Base are the integer ones, just like +the default arithmetic operators are the integer ones. + +The recommended way to compare arbitrarily complex data structures is to +use the specific [compare] functions. For instance: + +{[ List.compare String.compare x y ]} + +The [ppx_compare] rewriter offers an alternative way to write this: + +{[ [%compare: string list] x y ]} + +{2 Base and ppx code generators} + +Base uses a few ppx code generators to implement: + +- reliable and customizable comparison of OCaml values; +- reliable and customizable hash of OCaml values; and +- conversions between OCaml values and s-expressions. + +However, it doesn't need these code generators to build. Instead, it +uses ppx as a code verification tool during development. It works in a +very similar fashion to {{: https://github.com/janestreet/ppx_expect} +expect tests}. 
+ +Whenever you see this in the source code: + +{[ +type t = ... [@@deriving_inline sexp_of] +let sexp_of_t = ... +[@@@end] +]} + +the code between the [[@@deriving_inline]] and the [[@@@end]] is +generated code. The generated code is currently quite big and hard to +read; however, we are working on making it look like human-written +code. + +You can put the following elisp code in your [~/.emacs] file to hide +these blocks: + +{v +(defun deriving-inline-forward-sexp (&optional arg) + (search-forward-regexp "\\[@@@end\\]" nil nil arg)) + +(defun setup-hide-deriving-inline () + (inline) + (hs-minor-mode t) + (let ((hs-hide-comments-when-hiding-all nil)) + (hs-hide-all))) + +(require 'hideshow) +(add-to-list 'hs-special-modes-alist + '(tuareg-mode "\\[@@deriving_inline[^]]*\\]" "\\[@@@end\\]" nil + deriving-inline-forward-sexp nil)) +(add-hook 'tuareg-mode-hook 'setup-hide-deriving-inline) +v} + +Things are not yet set up in the git repository to make it convenient +to change types and update the generated code, but they will be set up +soon. + +{2 Base coding rules} + +There are a few coding rules across the code base that are enforced by +lint tools. + +These rules are: + +{ul +{- Opening the [Caml] module is not allowed. Inside Base, the OCaml + stdlib is shadowed and accessible through the [Caml] module. We + forbid opening [Caml] so that we know exactly where things come + from.} +{- [Caml.Foo] modules cannot be aliased; one must use [Caml.Foo] + explicitly. This is to avoid having to remember a list of aliases + at the beginning of each file.} +{- For some modules that are both in the OCaml stdlib and Base, such as + [String], we define a module [String0] for common functions that + cannot be defined directly in [Base.String] to avoid creating a + circular dependency. 
Except for [String] itself, other modules + are not allowed to use [Caml.String] and must use either [String] or + [String0] instead.} +{- Indentation is exactly the one of [ocp-indent].} +{- A few other coding style rules enforced by + {{: https://github.com/janestreet/ppx_js_style} ppx_js_style}.} +} + + +The Base-specific coding rules are checked by [ppx_base_lint], in the +[lint] subfolder. The indentation rules are checked by a wrapper around +[ocp-indent] and the coding style rules are checked by [ppx_js_style]. + +These checks are currently not run by [jbuilder], but it will soon get +a [-dev] flag to run them automatically. + +{2 Roadmap} + +Base is still under active development and there are several missing +features that are yet to be added. Consult the +{{:https://github.com/janestreet/base/blob/master/ROADMAP.md}roadmap} to +see what is happening. diff --git a/src/binary_search.ml b/src/binary_search.ml new file mode 100644 index 0000000..0011158 --- /dev/null +++ b/src/binary_search.ml @@ -0,0 +1,109 @@ +open! Import + +(* These functions implement a search for the first (resp. last) element + satisfying a predicate, assuming that the predicate is increasing on + the container, meaning that, if the container is [u1...un], there exists a + k such that p(u1)=....=p(uk) = false and p(uk+1)=....=p(un)= true. + If this k = 0 (resp n), find_last_not_satisfying (resp find_first_satisfying) + will return None. *) + +let rec linear_search_first_satisfying t ~get ~lo ~hi ~pred = + if lo > hi + then None + else + if pred (get t lo) + then Some lo + else linear_search_first_satisfying t ~get ~lo:(lo + 1) ~hi ~pred +;; + +(* Takes a container [t], a predicate [pred] and two indices [lo < hi], such that + [pred] is increasing on [t] between [lo] and [hi]. 
+ + return a range (lo, hi) where: + - lo and hi are close enough together for a linear search + - If [pred] is not constantly [false] on [t] between [lo] and [hi], the first element + on which [pred] is [true] is between [lo] and [hi]. *) +(* Invariant: the first element satisfying [pred], if it exists is between [lo] and [hi] *) +let rec find_range_near_first_satisfying t ~get ~lo ~hi ~pred = + (* Warning: this function will not terminate if the constant (currently 8) is + set <= 1 *) + if hi - lo <= 8 + then (lo,hi) + else + let mid = lo + ((hi - lo) / 2) in + if pred (get t mid) + (* INVARIANT check: it means the first satisfying element is between [lo] and [mid] *) + then find_range_near_first_satisfying t ~get ~lo ~hi:mid ~pred + (* INVARIANT check: it means the first satisfying element, if it exists, + is between [mid+1] and [hi] *) + else find_range_near_first_satisfying t ~get ~lo:(mid+1) ~hi ~pred +;; + +let find_first_satisfying ?pos ?len t ~get ~length ~pred = + let pos, len = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length t) + in + let lo = pos in + let hi = pos + len - 1 in + let (lo, hi) = find_range_near_first_satisfying t ~get ~lo ~hi ~pred in + linear_search_first_satisfying t ~get ~lo ~hi ~pred +;; + +(* Takes an array with shape [true,...true,false,...false] (i.e., the _reverse_ of what + is described above) and returns the index of the last true or None if there are no + true*) +let find_last_satisfying ?pos ?len t ~pred ~get ~length = + let pos, len = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length t) + in + if len = 0 + then None + else begin + (* The last satisfying is the one just before the first not satisfying *) + match find_first_satisfying ~pos ~len t ~get ~length ~pred:(Fn.non pred) with + | None -> Some (pos + len - 1) (* This means that all elements satisfy pred. 
+ There is at least an element as (len > 0) *) + | Some i when i = pos -> None (* no element satisfies pred *) + | Some i -> Some (i - 1) + end +;; + +let binary_search ?pos ?len t ~length ~get ~compare how v = + match how with + | `Last_strictly_less_than -> + find_last_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v < 0) + | `Last_less_than_or_equal_to -> + find_last_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v <= 0) + | `First_equal_to -> + begin + match + find_first_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v >= 0) + with + | Some x when compare (get t x) v = 0 -> Some x + | None | Some _ -> None + end + | `Last_equal_to -> + begin + match + find_last_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v <= 0) + with + | Some x when compare (get t x) v = 0 -> Some x + | None | Some _ -> None + end + | `First_greater_than_or_equal_to -> + find_first_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v >= 0) + | `First_strictly_greater_than -> + find_first_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v > 0) +;; + +let binary_search_segmented ?pos ?len t ~length ~get ~segment_of how = + let is_left x = + match segment_of x with + | `Left -> true + | `Right -> false + in + let is_right x = not (is_left x) in + match how with + | `Last_on_left -> find_last_satisfying ?pos ?len t ~length ~get ~pred:is_left + | `First_on_right -> find_first_satisfying ?pos ?len t ~length ~get ~pred:is_right +;; diff --git a/src/binary_search.mli b/src/binary_search.mli new file mode 100644 index 0000000..abbe3da --- /dev/null +++ b/src/binary_search.mli @@ -0,0 +1,86 @@ +(** General functions for performing binary searches over ordered sequences given + [length] and [get] functions. 
+ + These functions can be specialized and added to a data structure using the functors + supplied in {{!Base.Binary_searchable}[Binary_searchable]} and described in + {{!Base.Binary_searchable_intf}[Binary_searchable_intf]}. + + {2:examples Examples} + + Below we assume that the functions [get], [length] and [compare] are in scope: + + {[ + (* Find the index of an element [e] in [t] *) + binary_search t ~get ~length ~compare `First_equal_to e; + + (* Find the index where an element [e] should be inserted *) + binary_search t ~get ~length ~compare `First_greater_than_or_equal_to e; + + (* Find the index in [t] where all elements to the left are less than or equal to [e] *) + binary_search_segmented t ~get ~length ~segment_of:(fun e' -> + if compare e' e <= 0 then `Left else `Right) `First_on_right + ]} *) + +open! Import + +(** [binary_search ?pos ?len t ~length ~get ~compare which elt] takes [t] that is sorted + in increasing order according to [compare], where [compare] and [elt] divide [t] into + three (possibly empty) segments: + + {v + | < elt | = elt | > elt | + v} + + [binary_search] returns the index in [t] of an element on the boundary of segments + as specified by [which]. See the diagram below next to the [which] variants. + + By default, [binary_search] searches the entire [t]. One can supply [?pos] or + [?len] to search a slice of [t]. + + [binary_search] does not check that [compare] orders [t], and behavior is + unspecified if [compare] doesn't order [t]. Behavior is also unspecified if + [compare] mutates [t]. 
*) +val binary_search + : ?pos:int + -> ?len:int + -> 't + -> length:('t -> int) + -> get:('t -> int -> 'elt) + -> compare:('elt -> 'key -> int) + -> [ `Last_strictly_less_than (** {v | < elt X | v} *) + | `Last_less_than_or_equal_to (** {v | <= elt X | v} *) + | `Last_equal_to (** {v | = elt X | v} *) + | `First_equal_to (** {v | X = elt | v} *) + | `First_greater_than_or_equal_to (** {v | X >= elt | v} *) + | `First_strictly_greater_than (** {v | X > elt | v} *) + ] + -> 'key + -> int option + +(** [binary_search_segmented ?pos ?len t ~length ~get ~segment_of which] takes a + [segment_of] function that divides [t] into two (possibly empty) segments: + + {v + | segment_of elt = `Left | segment_of elt = `Right | + v} + + [binary_search_segmented] returns the index of the element on the boundary of the + segments as specified by [which]: [`Last_on_left] yields the index of the last + element of the left segment, while [`First_on_right] yields the index of the first + element of the right segment. It returns [None] if the requested segment is empty. + + By default, [binary_search_segmented] searches the entire [t]. One can supply [?pos] or + [?len] to search a slice of [t]. + + [binary_search_segmented] does not check that [segment_of] segments [t] as in the + diagram, and behavior is unspecified if [segment_of] doesn't segment [t]. Behavior + is also unspecified if [segment_of] mutates [t]. *) +val binary_search_segmented + : ?pos:int + -> ?len:int + -> 't + -> length:('t -> int) + -> get:('t -> int -> 'elt) + -> segment_of:('elt -> [ `Left | `Right ]) + -> [ `Last_on_left | `First_on_right ] + -> int option diff --git a/src/binary_searchable.ml b/src/binary_searchable.ml new file mode 100644 index 0000000..8779b98 --- /dev/null +++ b/src/binary_searchable.ml @@ -0,0 +1,36 @@ +open! 
Import + +include Binary_searchable_intf + +module type Arg = sig + type 'a elt + type 'a t + val get : 'a t -> int -> 'a elt + val length : _ t -> int +end + +module Make_gen (T : Arg) = struct + let get = T.get + let length = T.length + + let binary_search ?pos ?len t ~compare how v = + Binary_search.binary_search ?pos ?len t ~get ~length ~compare how v + + let binary_search_segmented ?pos ?len t ~segment_of how = + Binary_search.binary_search_segmented ?pos ?len t ~get ~length ~segment_of how +end + +module Make (T : Indexable) = + Make_gen (struct + type 'a elt = T.elt + type 'a t = T.t + include (T : Indexable with type elt := T.elt with type t := T.t) + end) + +module Make1 (T : Indexable1) = + Make_gen (struct + type 'a elt = 'a + type 'a t = 'a T.t + let get = T.get + let length = T.length + end) diff --git a/src/binary_searchable.mli b/src/binary_searchable.mli new file mode 100644 index 0000000..85837ea --- /dev/null +++ b/src/binary_searchable.mli @@ -0,0 +1 @@ +include Binary_searchable_intf.Binary_searchable (** @inline *) diff --git a/src/binary_searchable_intf.ml b/src/binary_searchable_intf.ml new file mode 100644 index 0000000..03b0725 --- /dev/null +++ b/src/binary_searchable_intf.ml @@ -0,0 +1,76 @@ +(** Module types for a [binary_search] function for a sequence, and functors for building + [binary_search] functions. *) + +open! Import + +(** An [Indexable] type is a finite sequence of elements indexed by consecutive integers + [0] ... [length t - 1]. [get] and [length] must be O(1) for the resulting + [binary_search] to be lg(n). 
*) +module type Indexable = sig + type elt + type t + + val get : t -> int -> elt + val length : t -> int +end + +module type Indexable1 = sig + type 'a t + + val get : 'a t -> int -> 'a + val length : _ t -> int +end + +type ('t, 'elt, 'key) binary_search = + ?pos:int + -> ?len:int + -> 't + -> compare:('elt -> 'key -> int) + -> [ `Last_strictly_less_than (** {v | < elt X | v} *) + | `Last_less_than_or_equal_to (** {v | <= elt X | v} *) + | `Last_equal_to (** {v | = elt X | v} *) + | `First_equal_to (** {v | X = elt | v} *) + | `First_greater_than_or_equal_to (** {v | X >= elt | v} *) + | `First_strictly_greater_than (** {v | X > elt | v} *) + ] + -> 'key + -> int option + +type ('t, 'elt) binary_search_segmented = + ?pos:int + -> ?len:int + -> 't + -> segment_of:('elt -> [ `Left | `Right ]) + -> [ `Last_on_left | `First_on_right ] + -> int option + +module type S = sig + type elt + type t + + (** See [Binary_search.binary_search] in binary_search.ml *) + val binary_search : (t, elt, 'key) binary_search + + (** See [Binary_search.binary_search_segmented] in binary_search.ml *) + val binary_search_segmented : (t, elt) binary_search_segmented +end + +module type S1 = sig + type 'a t + + val binary_search : ('a t, 'a, 'key) binary_search + val binary_search_segmented : ('a t, 'a) binary_search_segmented +end + +module type Binary_searchable = sig + module type S = S + module type S1 = S1 + module type Indexable = Indexable + module type Indexable1 = Indexable1 + + type nonrec ('t, 'elt, 'key) binary_search = ('t, 'elt, 'key) binary_search + type nonrec ('t, 'elt) binary_search_segmented = ('t, 'elt) binary_search_segmented + + module Make (T : Indexable) : S with type t := T.t with type elt := T.elt + module Make1 (T : Indexable1) : S1 with type 'a t := 'a T.t +end diff --git a/src/blit.ml b/src/blit.ml new file mode 100644 index 0000000..c2a3ae9 --- /dev/null +++ b/src/blit.ml @@ -0,0 +1,113 @@ +open! 
Import + + +include Blit_intf + +module type Sequence_gen = sig + type 'a t + val length : _ t -> int +end + +module Make_gen + (Src : Sequence_gen) + (Dst : sig + include Sequence_gen + val create_like : len:int -> 'a Src.t -> 'a t + val unsafe_blit : ('a Src.t, 'a t) blit + end) = struct + + let unsafe_blit = Dst.unsafe_blit + + let blit ~src ~src_pos ~dst ~dst_pos ~len = + Ordered_collection_common.check_pos_len_exn + ~pos:src_pos ~len ~total_length:(Src.length src); + Ordered_collection_common.check_pos_len_exn + ~pos:dst_pos ~len ~total_length:(Dst.length dst); + if len > 0 then unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len; + ;; + + let blito + ~src ?(src_pos = 0) ?(src_len = Src.length src - src_pos) ~dst ?(dst_pos = 0) + () = + blit ~src ~src_pos ~len:src_len ~dst ~dst_pos; + ;; + + (* [sub] and [subo] ensure that every position of the created sequence is populated by + an element of the source array. Thus every element of [dst] below is well + defined. *) + let sub src ~pos ~len = + Ordered_collection_common.check_pos_len_exn ~pos ~len + ~total_length:(Src.length src); + let dst = Dst.create_like ~len src in + if len > 0 then unsafe_blit ~src ~src_pos:pos ~dst ~dst_pos:0 ~len; + dst + ;; + + let subo ?(pos = 0) ?len src = + sub src ~pos ~len:(match len with Some i -> i | None -> Src.length src - pos) + ;; +end + +module Make1 + (Sequence : sig + include Sequence_gen + val create_like : len:int -> 'a t -> 'a t + val unsafe_blit : ('a t, 'a t) blit + end) = + Make_gen + (Sequence) + (Sequence) + +module Make1_generic + (Sequence : Sequence1) = + Make_gen + (Sequence) + (Sequence) + +module Make + (Sequence : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (t, t) blit + end) = struct + module Sequence = struct + type 'a t = Sequence.t + open Sequence + let create_like ~len _ = create ~len + let length = length + let unsafe_blit = unsafe_blit + end + include Make_gen (Sequence) (Sequence) +end + +module Make_distinct + (Src : Sequence) + 
(Dst : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (Src.t, t) blit + end) = + Make_gen + (struct + type 'a t = Src.t + open Src + let length = length + end) + (struct + type 'a t = Dst.t + open Dst + let length = length + let create_like ~len _ = create ~len + let unsafe_blit = unsafe_blit + end) + +module Make_to_string + (T : sig type t end) + (To_bytes : S_distinct with type src := T.t with type dst := bytes) += struct + open To_bytes + let sub src ~pos ~len = + Bytes0.unsafe_to_string ~no_mutation_while_string_reachable:(sub src ~pos ~len) + let subo ?pos ?len src = + Bytes0.unsafe_to_string ~no_mutation_while_string_reachable:(subo ?pos ?len src) +end diff --git a/src/blit.mli b/src/blit.mli new file mode 100644 index 0000000..ea4c29c --- /dev/null +++ b/src/blit.mli @@ -0,0 +1 @@ +include Blit_intf.Blit (** @inline *) diff --git a/src/blit_intf.ml b/src/blit_intf.ml new file mode 100644 index 0000000..b99f2a3 --- /dev/null +++ b/src/blit_intf.ml @@ -0,0 +1,167 @@ +(** Standard type for [blit] functions, and reusable code for validating [blit] + arguments. *) + +open! Import + +(** If [blit : (src, dst) blit], then [blit ~src ~src_pos ~len ~dst ~dst_pos] blits [len] + values from [src] starting at position [src_pos] to [dst] at position [dst_pos]. + Furthermore, [blit] raises if [src_pos], [len], and [dst_pos] don't specify valid + slices of [src] and [dst]. *) +type ('src, 'dst) blit + = src : 'src + -> src_pos : int + -> dst : 'dst + -> dst_pos : int + -> len : int + -> unit + +(** [blito] is like [blit], except that the [src_pos], [src_len], and [dst_pos] are + optional (hence the "o" in "blito"). Also, we use [src_len] rather than [len] as a + reminder that if [src_len] isn't supplied, then the default is to take the slice + running from [src_pos] to the end of [src]. 
*) +type ('src, 'dst) blito + = src : 'src + -> ?src_pos : int (** default is [0] *) + -> ?src_len : int (** default is [length src - src_pos] *) + -> dst : 'dst + -> ?dst_pos : int (** default is [0] *) + -> unit + -> unit + +(** If [sub : (src, dst) sub], then [sub src ~pos ~len] returns a sequence of type [dst] + containing [len] characters of [src] starting at [pos]. + + [subo] is like [sub], except [pos] and [len] are optional. *) +type ('src, 'dst) sub = 'src -> pos:int -> len:int -> 'dst +type ('src, 'dst) subo + = ?pos : int (** default is [0] *) + -> ?len : int (** default is [length src - pos] *) + -> 'src + -> 'dst + +(*_ These are not implemented less-general-in-terms-of-more-general because odoc produces + unreadable documentation in that case, with or without [inline] on [include]. *) + +module type S = sig + type t + val blit : (t, t) blit + val blito : (t, t) blito + val unsafe_blit : (t, t) blit + val sub : (t, t) sub + val subo : (t, t) subo +end + +module type S1 = sig + type 'a t + val blit : ('a t, 'a t) blit + val blito : ('a t, 'a t) blito + val unsafe_blit : ('a t, 'a t) blit + val sub : ('a t, 'a t) sub + val subo : ('a t, 'a t) subo +end + +module type S_distinct = sig + type src + type dst + val blit : (src, dst) blit + val blito : (src, dst) blito + val unsafe_blit : (src, dst) blit + val sub : (src, dst) sub + val subo : (src, dst) subo +end + +module type S_to_string = sig + type t + val sub : (t, string) sub + val subo : (t, string) subo +end + +(** Users of modules matching the blit signatures [S], [S1], and [S_distinct] only need + to understand the code above. The code below is only for those that need to implement + modules that match those signatures. *) + +module type Sequence = sig + type t + val length : t -> int +end + +type 'a poly = 'a + +module type Sequence1 = sig + type 'a t + + (** [Make1*] guarantees to only call [create_like ~len t] with [len > 0] if [length t > + 0]. 
*) + val create_like : len:int -> 'a t -> 'a t + val length : _ t -> int + + val unsafe_blit : ('a t, 'a t) blit +end + +module type Blit = sig + type nonrec ('src, 'dst) blit = ('src, 'dst) blit + type nonrec ('src, 'dst) blito = ('src, 'dst) blito + type nonrec ('src, 'dst) sub = ('src, 'dst) sub + type nonrec ('src, 'dst) subo = ('src, 'dst) subo + + module type S = S + module type S1 = S1 + module type S_distinct = S_distinct + module type S_to_string = S_to_string + module type Sequence = Sequence + module type Sequence1 = Sequence1 + + (** There are various [Make*] functors that turn an [unsafe_blit] function into a [blit] + function. The functors differ in whether the sequence type is monomorphic or + polymorphic, and whether the src and dst types are distinct or are the same. + + The blit functions make sure the slices are valid and then call [unsafe_blit]. They + guarantee at a call [unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len] that: + + {[ + len > 0 + && src_pos >= 0 + && src_pos + len <= get_src_len src + && dst_pos >= 0 + && dst_pos + len <= get_dst_len dst + ]} + + The [Make*] functors also automatically create unit tests. *) + + (** [Make] is for blitting between two values of the same monomorphic type. *) + module Make + (Sequence : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (t, t) blit + end) + : S with type t := Sequence.t + + (** [Make_distinct] is for blitting between values of distinct monomorphic types. *) + module Make_distinct + (Src : Sequence) + (Dst : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (Src.t, t) blit + end) + : S_distinct + with type src := Src.t + with type dst := Dst.t + + module Make_to_string + (T : sig type t end) + (To_bytes : S_distinct with type src := T.t with type dst := bytes) + : S_to_string with type t := T.t + + (** [Make1] is for blitting between two values of the same polymorphic type. 
*) + module Make1 + (Sequence : Sequence1) + : S1 with type 'a t := 'a Sequence.t + + (** [Make1_generic] is for blitting between two values of the same container type that's + not fully polymorphic (in the sense of Container.Generic). *) + module Make1_generic + (Sequence : Sequence1) + : S1 with type 'a t := 'a Sequence.t +end diff --git a/src/bool.ml b/src/bool.ml new file mode 100644 index 0000000..1fbf23b --- /dev/null +++ b/src/bool.ml @@ -0,0 +1,89 @@ +open! Import + +let invalid_argf = Printf.invalid_argf + +module T = struct + type t = bool [@@deriving_inline compare, enumerate, hash, sexp] + let compare : t -> t -> int = compare_bool + let all : t list = [false; true] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_bool + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_bool in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = bool_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_bool + [@@@end] + + let of_string = function + | "true" -> true + | "false" -> false + | s -> invalid_argf "Bool.of_string: expected true or false but got %s" s () + ;; + + let to_string = Caml.string_of_bool +end + +include T +include Comparator.Make(T) +include Comparable.Validate(T) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Bool" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! 
Bool_replace_polymorphic_compare + + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +(* We use [Obj.magic] here as other implementations generate a conditional jump and the + performance difference is noticeable. *) +let to_int (x : bool) = (Caml.Obj.magic x : int) + +module Non_short_circuiting = struct + (* We don't expose this, since we don't want to break the invariant mentioned below of + (to_int true = 1) and (to_int false = 0). *) + let unsafe_of_int (x : int) = (Caml.Obj.magic x : bool) + + let (||) a b = + unsafe_of_int (to_int a lor to_int b) + + let (&&) a b = + unsafe_of_int (to_int a land to_int b) +end + +(* We do this as a direct assert on the theory that it's a cheap thing to test and a + really core invariant that we never expect to break, and we should be happy for a + program to fail immediately if this is violated. *) +let () = + assert (Poly.(=) (to_int true ) 1 && + Poly.(=) (to_int false) 0); +;; + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Bool_replace_polymorphic_compare diff --git a/src/bool.mli b/src/bool.mli new file mode 100644 index 0000000..163d304 --- /dev/null +++ b/src/bool.mli @@ -0,0 +1,34 @@ +(** Boolean type extended to be enumerable, hashable, sexpable, comparable, and + stringable. *) + +open! 
Import + +type t = bool [@@deriving_inline enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Identifiable.S with type t := t + +(** + - [to_int true = 1] + - [to_int false = 0] *) +val to_int : t -> int + +module Non_short_circuiting : sig + (** Non-short circuiting and branch-free boolean operators. + + The default versions of these infix operators are short circuiting, which + requires branching instructions to implement. The operators below are + instead branch-free, and therefore not short-circuiting. *) + + val (&&) : t -> t -> t + val (||) : t -> t -> t +end diff --git a/src/buffer.ml b/src/buffer.ml new file mode 100644 index 0000000..37338f7 --- /dev/null +++ b/src/buffer.ml @@ -0,0 +1,28 @@ +open! Import + +include Buffer_intf + +include Caml.Buffer + +let contents_bytes = to_bytes + +let add_substring t s ~pos ~len = add_substring t s pos len +let add_subbytes t s ~pos ~len = add_subbytes t s pos len +let sexp_of_t t = sexp_of_string (contents t) + +module To_bytes = + Blit.Make_distinct + (struct + type nonrec t = t + let length = length + end) + (struct + type t = Bytes.t + let create ~len = Bytes.create len + let length = Bytes.length + let unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len = + Caml.Buffer.blit src src_pos dst dst_pos len + end) + +include To_bytes +module To_string = Blit.Make_to_string (Caml.Buffer) (To_bytes) diff --git a/src/buffer.mli b/src/buffer.mli new file mode 100644 index 0000000..a5212a3 --- /dev/null +++ b/src/buffer.mli @@ -0,0 +1,8 @@ +(** Extensible character buffers. + + This module implements character buffers that automatically expand as necessary. 
It + provides cumulative concatenation of strings in quasi-linear time (instead of + quadratic time when strings are concatenated pairwise). +*) + +include Buffer_intf.Buffer (** @inline *) diff --git a/src/buffer_intf.ml b/src/buffer_intf.ml new file mode 100644 index 0000000..37ee8e5 --- /dev/null +++ b/src/buffer_intf.ml @@ -0,0 +1,81 @@ +open! Import + +module type S = sig + (** The abstract type of buffers. *) + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** [create n] returns a fresh buffer, initially empty. The [n] parameter is the + initial size of the internal storage medium that holds the buffer contents. That + storage is automatically reallocated when more than [n] characters are stored in the + buffer, but shrinks back to [n] characters when [reset] is called. + + For best performance, [n] should be of the same order of magnitude as the number of + characters that are expected to be stored in the buffer (for instance, 80 for a + buffer that holds one output line). Nothing bad will happen if the buffer grows + beyond that limit, however. In doubt, take [n = 16] for instance. *) + val create : int -> t + + (** Return a copy of the current contents of the buffer. The buffer itself is + unchanged. *) + val contents : t -> string + val contents_bytes : t -> bytes + + (** [blit ~src ~src_pos ~dst ~dst_pos ~len] copies [len] characters from the current + contents of the buffer [src], starting at offset [src_pos] to bytes [dst], starting + at character [dst_pos]. + + Raises [Invalid_argument] if [src_pos] and [len] do not designate a valid substring + of [src], or if [dst_pos] and [len] do not designate a valid substring of [dst]. *) + + include Blit.S_distinct with type src := t with type dst := bytes + module To_string : Blit.S_to_string with type t := t + + (** Gets the (zero-based) n-th character of the buffer. 
Raises [Invalid_argument] if + index out of bounds. *) + val nth : t -> int -> char + + (** Returns the number of characters currently contained in the buffer. *) + val length : t -> int + + (** Empties the buffer. *) + val clear : t -> unit + + (** Empties the buffer and deallocates the internal storage holding the buffer contents, + replacing it with the initial internal storage of length [n] that was allocated by + [create n]. For long-lived buffers that may have grown a lot, [reset] allows faster + reclamation of the space used by the buffer. *) + val reset : t -> unit + + (** [add_char b c] appends the character [c] at the end of the buffer [b]. *) + val add_char : t -> char -> unit + + (** [add_string b s] appends the string [s] at the end of the buffer [b]. *) + val add_string : t -> string -> unit + + (** [add_substring b s pos len] takes [len] characters from offset [pos] in string [s] + and appends them at the end of the buffer [b]. *) + val add_substring : t -> string -> pos:int -> len:int -> unit + + (** [add_bytes b s] appends the bytes [s] at the end of the buffer [b]. *) + val add_bytes : t -> bytes -> unit + + (** [add_subbytes b s pos len] takes [len] characters from offset [pos] in bytes [s] + and appends them at the end of the buffer [b]. *) + val add_subbytes : t -> bytes -> pos:int -> len:int -> unit + + (** [add_buffer b1 b2] appends the current contents of buffer [b2] at the end of buffer + [b1]. [b2] is not modified. *) + val add_buffer : t -> t -> unit +end + +module type Buffer = sig + module type S = S + + (** Buffers using strings as underlying storage medium: *) + + include S with type t = Caml.Buffer.t (** @open *) +end diff --git a/src/bytes.ml b/src/bytes.ml new file mode 100644 index 0000000..c701427 --- /dev/null +++ b/src/bytes.ml @@ -0,0 +1,125 @@ +open! 
Import + +let stage = Staged.stage + +module T = struct + type t = bytes [@@deriving_inline sexp] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = bytes_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_bytes + [@@@end] + include Bytes0 + + let module_name = "Base.Bytes" + + let pp fmt t = Caml.Format.fprintf fmt "%S" (to_string t) +end + +include T + +module To_bytes = + Blit.Make + (struct + include T + let create ~len = create len + end) +include To_bytes + +include Comparator.Make(T) +include Comparable.Validate(T) + +include Pretty_printer.Register_pp(T) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Bytes_replace_polymorphic_compare + +module To_string = Blit.Make_to_string (T) (To_bytes) + +module From_string = Blit.Make_distinct(struct + type t = string + let length = String.length + end) + (struct + type nonrec t = t + let create ~len = create len + let length = length + let unsafe_blit = unsafe_blit_string + end) + +let init n ~f = + if Int_replace_polymorphic_compare.(<) n 0 then Printf.invalid_argf "Bytes.init %d" n (); + let t = create n in + for i = 0 to n - 1 do + unsafe_set t i (f i) + done; + t + +let of_char_list l = + let t = create (List.length l) in + List.iteri l ~f:(fun i c -> set t i c); + t + +let to_list t = + let rec loop t i acc = + if Int_replace_polymorphic_compare.(<) i 0 + then acc + else loop t (i - 1) (unsafe_get t i :: acc) + in + loop t (length t - 1) [] + +let tr ~target ~replacement s = + for i = 0 to length s - 1 do + if Char.equal (unsafe_get s i) target + then unsafe_set s i replacement + done + +let tr_multi ~target ~replacement = + if Int_replace_polymorphic_compare.(=) (String.length target) 0 + then stage ignore + else if Int_replace_polymorphic_compare.(=) (String.length replacement) 0 + then invalid_arg "tr_multi: 
replacement is the empty string" + else + match Bytes_tr.tr_create_map ~target ~replacement with + | None -> stage ignore + | Some tr_map -> + stage (fun s -> + for i = 0 to length s - 1 do + unsafe_set s i (String.unsafe_get tr_map (Char.to_int (unsafe_get s i))) + done) + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let contains ?pos ?len t char = + let (pos, len) = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length t) + in + let last = pos + len in + let rec loop i = + Int_replace_polymorphic_compare.(<) i last + && (Char.equal (get t i) char || loop (i + 1)) + in + loop pos +;; + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Bytes_replace_polymorphic_compare diff --git a/src/bytes.mli b/src/bytes.mli new file mode 100644 index 0000000..88f2480 --- /dev/null +++ b/src/bytes.mli @@ -0,0 +1,197 @@ +(** OCaml's byte sequence type, semantically similar to a [char array], but + taking less space in memory. + + A byte sequence is a mutable data structure that contains a fixed-length + sequence of bytes (of type [char]). Each byte can be indexed in constant + time for reading or writing. *) + +open! 
Import + +type t = bytes [@@deriving_inline sexp] +include +sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + + +(** {1 Common Interfaces} *) + +include Blit .S with type t := t +include Comparable .S with type t := t +include Stringable .S with type t := t + +(** Note that [pp] allocates in order to preserve the state of the byte + sequence it was initially called with. *) +include Pretty_printer.S with type t := t + +module To_string : sig + val sub : (t, string) Blit.sub + val subo : (t, string) Blit.subo +end + +module From_string : Blit.S_distinct with type src := string and type dst := t + +(** [create len] returns a newly-allocated and uninitialized byte sequence of + length [len]. No guarantees are made about the contents of the return + value. *) +val create : int -> t + +(** [make len c] returns a newly-allocated byte sequence of length [len] filled + with the byte [c]. *) +val make : int -> char -> t + +(** [copy t] returns a newly-allocated byte sequence that contains the same + bytes as [t]. *) +val copy : t -> t + +(** [init len ~f] returns a newly-allocated byte sequence of length [len] with + index [i] in the sequence being initialized with the result of [f i]. *) +val init : int -> f:(int -> char) -> t + +(** [of_char_list l] returns a newly-allocated byte sequence where each byte in + the sequence corresponds to the byte in [l] at the same index. *) +val of_char_list : char list -> t + +(** [length t] returns the number of bytes in [t]. *) +val length : t -> int + +(** [get t i] returns the [i]th byte of [t]. *) +val get : t -> int -> char +external unsafe_get : t -> int -> char = "%bytes_unsafe_get" + +(** [set t i c] sets the [i]th byte of [t] to [c]. *) +val set : t -> int -> char -> unit +external unsafe_set : t -> int -> char -> unit = "%bytes_unsafe_set" + +(** [fill t ~pos ~len c] modifies [t] in place, replacing all the bytes from + [pos] to [pos + len] with [c]. 
*) +val fill : t -> pos:int -> len:int -> char -> unit + +(** [tr ~target ~replacement t] modifies [t] in place, replacing every instance + of [target] in [t] with [replacement]. *) +val tr : target:char -> replacement:char -> t -> unit + +(** [tr_multi ~target ~replacement] returns an in-place function that replaces + every instance of a character in [target] with the corresponding character + in [replacement]. + + If [replacement] is shorter than [target], it is lengthened by repeating + its last character. Empty [replacement] is illegal unless [target] also is. + + If [target] contains multiple copies of the same character, the last + corresponding [replacement] character is used. Note that character ranges + are {b not} supported, so [~target:"a-z"] means the literal characters ['a'], + ['-'], and ['z']. *) +val tr_multi : target:string -> replacement:string -> (t -> unit) Staged.t + +(** [to_list t] returns the bytes in [t] as a list of chars. *) +val to_list : t -> char list + +(** [contains ?pos ?len t c] returns [true] iff [c] appears in [t] between [pos] + and [pos + len]. *) +val contains : ?pos:int -> ?len:int -> t -> char -> bool + +(** Maximum length of a byte sequence, which is architecture-dependent. Attempting to + create a [Bytes] larger than this will raise an exception. *) +val max_length : int + +(** {2:unsafe Unsafe conversions (for advanced users)} + + This section describes unsafe, low-level conversion functions between + [bytes] and [string]. They might not copy the internal data; used + improperly, they can break the immutability invariant on strings provided + by the [-safe-string] option. They are available for expert library + authors, but for most purposes you should use the always-correct + {!Bytes.to_string} and {!Bytes.of_string} instead. +*) + +(** Unsafely convert a byte sequence into a string. + + To reason about the use of [unsafe_to_string], it is convenient to + consider an "ownership" discipline. 
A piece of code that + manipulates some data "owns" it; there are several disjoint ownership + modes, including: + {ul + {- Unique ownership: the data may be accessed and mutated} + {- Shared ownership: the data has several owners, that may only + access it, not mutate it.}} + Unique ownership is linear: passing the data to another piece of + code means giving up ownership (we cannot access the + data again). A unique owner may decide to make the data shared + (giving up mutation rights on it), but shared data may not become + uniquely-owned again. + [unsafe_to_string s] can only be used when the caller owns the byte + sequence [s] -- either uniquely or as shared immutable data. The + caller gives up ownership of [s], and gains (the same mode of) ownership + of the returned string. + There are two valid use-cases that respect this ownership + discipline: + {ol + {- The first is creating a string by initializing and mutating a byte + sequence that is never changed after initialization is performed. + {[ + let string_init len f : string = + let s = Bytes.create len in + for i = 0 to len - 1 do Bytes.set s i (f i) done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:s + ]} + This function is safe because the byte sequence [s] will never be + accessed or mutated after [unsafe_to_string] is called. The + [string_init] code gives up ownership of [s], and returns the + ownership of the resulting string to its caller. + + Note that it would be unsafe if [s] was passed as an additional + parameter to the function [f] as it could escape this way and be + mutated in the future -- [string_init] would give up ownership of + [s] to pass it to [f], and could not call [unsafe_to_string] + safely. + + We have provided the {!String.init}, {!String.map} and + {!String.mapi} functions to cover most cases of building + new strings. 
You should prefer those over [to_string] or + [unsafe_to_string] whenever applicable.} + {- The second is temporarily giving ownership of a byte sequence to + a function that expects a uniquely owned string and returns ownership + back, so that we can mutate the sequence again after the call ended. + {[ + let bytes_length (s : bytes) = + String.length + (Bytes.unsafe_to_string ~no_mutation_while_string_reachable:s) + ]} + In this use-case, we do not promise that [s] will never be mutated + after the call to [bytes_length s]. The {!String.length} function + temporarily borrows unique ownership of the byte sequence + (and sees it as a [string]), but returns this ownership back to + the caller, which may assume that [s] is still a valid byte + sequence after the call. Note that this is only correct because we + know that {!String.length} does not capture its argument -- it could + escape by a side-channel such as a memoization combinator. + The caller may not mutate [s] while the string is borrowed (it has + temporarily given up ownership). This affects concurrent programs, + but also higher-order functions: if {!String.length} returned + a closure to be called later, [s] should not be mutated until this + closure is fully applied and returns ownership.}} +*) +val unsafe_to_string : no_mutation_while_string_reachable:t -> string + +(** Unsafely convert a shared string to a byte sequence that should + not be mutated. + + The same ownership discipline that makes [unsafe_to_string] + correct applies to [unsafe_of_string_promise_no_mutation], + however unique ownership of string values is extremely difficult + to reason about correctly in practice. As such, one should always + assume strings are shared, never uniquely owned (For example, + string literals are implicitly shared by the compiler, so you + never uniquely own them) + + The only case we have reasonable confidence is safe is if the + produced [bytes] is shared -- used as an immutable byte + sequence. 
This is possibly useful for incremental migration of + low-level programs that manipulate immutable sequences of bytes + (for example {!Marshal.from_bytes}) and previously used the + [string] type for this purpose. +*) +val unsafe_of_string_promise_no_mutation : string -> t diff --git a/src/bytes0.ml b/src/bytes0.ml new file mode 100644 index 0000000..40ae910 --- /dev/null +++ b/src/bytes0.ml @@ -0,0 +1,61 @@ +(* [Bytes0] defines bytes functions that are primitives or can be simply + defined in terms of [Caml.Bytes]. [Bytes0] is intended to completely express + the part of [Caml.Bytes] that [Base] uses -- no other file in Base other + than bytes0.ml should use [Caml.Bytes]. [Bytes0] has few dependencies, and + so is available early in Base's build order. + + All Base files that need to use bytes and come before [Base.Bytes] in + build order should do: + + {[ + module Bytes = Bytes0 + ]} + + Defining [module Bytes = Bytes0] is also necessary because it prevents + ocamldep from mistakenly causing a file to depend on [Base.Bytes]. *) + +let blit_string = Caml.Bytes.blit_string + +let sub_string t ~pos ~len = + Caml.Bytes.sub_string t pos len + +open! 
Import0 + +module Sys = Sys0 + +module Primitives = struct + external get : bytes -> int -> char = "%bytes_safe_get" + external length : bytes -> int = "%bytes_length" + external unsafe_get : bytes -> int -> char = "%bytes_unsafe_get" + include Bytes_set_primitives + + (* [unsafe_blit_string] is not exported in the [stdlib] so we export it here *) + external unsafe_blit_string + : src:string -> src_pos:int -> dst:bytes -> dst_pos:int -> len:int -> unit + = "caml_blit_string" [@@noalloc] +end + +include Primitives + +let max_length = Sys.max_string_length + +let blit = Caml.Bytes.blit +let compare = Caml.Bytes.compare +let copy = Caml.Bytes.copy +let create = Caml.Bytes.create +let fill = Caml.Bytes.fill +let make = Caml.Bytes.make +let sub = Caml.Bytes.sub +let unsafe_blit = Caml.Bytes.unsafe_blit + +let to_string = Caml.Bytes.to_string +let of_string = Caml.Bytes.of_string + +let unsafe_to_string ~no_mutation_while_string_reachable:s = + Caml.Bytes.unsafe_to_string s +let unsafe_of_string_promise_no_mutation = Caml.Bytes.unsafe_of_string + +(* These are eta expanded in order to label arguments, following the + Base conventions. *) +let blit_string ~src ~src_pos ~dst ~dst_pos ~len = + blit_string src src_pos dst dst_pos len diff --git a/src/bytes_tr.ml b/src/bytes_tr.ml new file mode 100644 index 0000000..b8f6e89 --- /dev/null +++ b/src/bytes_tr.ml @@ -0,0 +1,40 @@ +open! Import0.Int_replace_polymorphic_compare + +module Bytes = Bytes0 +module String = String0 + +(* Construct a byte string of length 256, mapping every input character code to + its corresponding output character. + + Benchmarks indicate that this is faster than the lambda (including cost of + this function), even if target/replacement are just 2 characters each. + + Return None if the translation map is equivalent to just the identity. 
*) +let tr_create_map ~target ~replacement = + let tr_map = Bytes.create 256 in + for i = 0 to 255 do + Bytes.unsafe_set tr_map i (Char.of_int_exn i) + done; + for i = 0 to (min (String.length target) (String.length replacement)) - 1 do + let index = Char.to_int (String.unsafe_get target i) in + Bytes.unsafe_set tr_map index (String.unsafe_get replacement i) + done; + let last_replacement = String.unsafe_get replacement (String.length replacement - 1) in + for i = min (String.length target) (String.length replacement) to String.length target - 1 do + let index = Char.to_int (String.unsafe_get target i) in + Bytes.unsafe_set tr_map index last_replacement + done; + + let rec have_any_different tr_map i = + if i = 256 + then false + else if Char.(<>) (Bytes0.unsafe_get tr_map i) (Char.of_int_exn i) + then true + else have_any_different tr_map (i + 1) + in + (* quick check on the first target character which will 99% be true *) + let first_target = String.get target 0 in + if Char.(<>) (Bytes0.unsafe_get tr_map (Char.to_int first_target)) first_target + || have_any_different tr_map 0 + then Some (Bytes0.unsafe_to_string ~no_mutation_while_string_reachable:tr_map) + else None diff --git a/src/char.ml b/src/char.ml new file mode 100644 index 0000000..ffa06b1 --- /dev/null +++ b/src/char.ml @@ -0,0 +1,105 @@ +open! 
Import + +module Array = Array0 +module String = String0 + +include Char0 + +module T = struct + type t = char [@@deriving_inline compare, hash, sexp] + let compare : t -> t -> int = compare_char + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_char + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_char in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = char_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_char + [@@@end] + + let to_string t = String.make 1 t + + let of_string s = + match String.length s with + | 1 -> String.get s 0 + | _ -> failwithf "Char.of_string: %S" s () +end + +include T + +include Identifiable.Make (struct + include T + let module_name = "Base.Char" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Char_replace_polymorphic_compare + +let all = + Array.init 256 ~f:unsafe_of_int + |> Array.to_list + +let is_lowercase = function + | 'a' .. 'z' -> true + | _ -> false + +let is_uppercase = function + | 'A' .. 'Z' -> true + | _ -> false + +let is_print = function + | ' ' .. '~' -> true + | _ -> false + +let is_whitespace = function + | '\t' + | '\n' + | '\011' (* vertical tab *) + | '\012' (* form feed *) + | '\r' + | ' ' + -> true + | _ + -> false +;; + +let is_digit = function + | '0' .. '9' -> true + | _ -> false + +let is_alpha = function + | 'a' .. 'z' | 'A' .. 'Z' -> true + | _ -> false + +(* Writing these out, instead of calling [is_alpha] and [is_digit], reduces + runtime by approx. 30% *) +let is_alphanum = function + | 'a' .. 'z' | 'A' .. 'Z' | '0' .. 
'9' -> true + | _ -> false + +let get_digit_unsafe t = to_int t - to_int '0' + +let get_digit_exn t = + if is_digit t + then get_digit_unsafe t + else failwithf "Char.get_digit_exn %C: not a digit" t () +;; + +let get_digit t = if is_digit t then Some (get_digit_unsafe t) else None + +module O = struct + let ( >= ) = ( >= ) + let ( <= ) = ( <= ) + let ( = ) = ( = ) + let ( > ) = ( > ) + let ( < ) = ( < ) + let ( <> ) = ( <> ) +end + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Char_replace_polymorphic_compare diff --git a/src/char.mli b/src/char.mli new file mode 100644 index 0000000..adfd960 --- /dev/null +++ b/src/char.mli @@ -0,0 +1,73 @@ +(** A type for 8-bit characters. *) + +open! Import + +(** An alias for the type of characters. *) +type t = char [@@deriving_inline enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Identifiable.S with type t := t + +module O : Comparisons.Infix with type t := t + +(** Returns the ASCII code of the argument. *) +val to_int : t -> int + +(** Returns the character with the given ASCII code or [None] if the argument is outside + the range 0 to 255. *) +val of_int : int -> t option + +(** Returns the character with the given ASCII code. Raises [Failure] if the argument is + outside the range 0 to 255. *) +val of_int_exn : int -> t + +val unsafe_of_int : int -> t + +(** Returns a string representing the given character, with special characters escaped + following the lexical conventions of OCaml. 
*) +val escaped : t -> string + +(** Converts the given character to its equivalent lowercase character. *) +val lowercase : t -> t + +(** Converts the given character to its equivalent uppercase character. *) +val uppercase : t -> t + +(** '0' - '9' *) +val is_digit : t -> bool + +(** 'a' - 'z' *) +val is_lowercase : t -> bool + +(** 'A' - 'Z' *) +val is_uppercase : t -> bool + +(** 'a' - 'z' or 'A' - 'Z' *) +val is_alpha : t -> bool + +(** 'a' - 'z' or 'A' - 'Z' or '0' - '9' *) +val is_alphanum : t -> bool + +(** ' ' - '~' *) +val is_print : t -> bool + +(** ' ' or '\t' or '\r' or '\n' or '\011' (vertical tab) or '\012' (form feed) *) +val is_whitespace : t -> bool + +(** Returns [Some i] if [is_digit c] and [None] otherwise. *) +val get_digit : t -> int option + +(** Returns [i] if [is_digit c] and raises [Failure] otherwise. *) +val get_digit_exn : t -> int + +val min_value : t +val max_value : t diff --git a/src/char0.ml b/src/char0.ml new file mode 100644 index 0000000..1e90989 --- /dev/null +++ b/src/char0.ml @@ -0,0 +1,39 @@ +(* [Char0] defines char functions that are primitives or can be simply defined in terms of + [Caml.Char]. [Char0] is intended to completely express the part of [Caml.Char] that + [Base] uses -- no other file in Base other than char0.ml should use [Caml.Char]. + [Char0] has few dependencies, and so is available early in Base's build order. All + Base files that need to use chars and come before [Base.Char] in build order should do + [module Char = Char0]. Defining [module Char = Char0] is also necessary because it + prevents ocamldep from mistakenly causing a file to depend on [Base.Char]. *) + +open! 
Import0 + +let failwithf = Printf.failwithf + +let escaped = Caml.Char.escaped +let lowercase = Caml.Char.lowercase_ascii +let to_int = Caml.Char.code +let unsafe_of_int = Caml.Char.unsafe_chr +let uppercase = Caml.Char.uppercase_ascii + +(* We use our own range test when converting integers to chars rather than + calling [Caml.Char.chr] because it's simple and it saves us a function call + and the try-with (exceptions cost, especially in the world with backtraces). *) +let int_is_ok i = 0 <= i && i <= 255 + +let min_value = unsafe_of_int 0 +let max_value = unsafe_of_int 255 + +let of_int i = + if int_is_ok i + then Some (unsafe_of_int i) + else None +;; + +let of_int_exn i = + if int_is_ok i + then unsafe_of_int i + else failwithf "Char.of_int_exn got integer out of range: %d" i () +;; + +let equal (t1 : char) t2 = Poly.equal t1 t2 diff --git a/src/comparable.ml b/src/comparable.ml new file mode 100644 index 0000000..97c9d6d --- /dev/null +++ b/src/comparable.ml @@ -0,0 +1,217 @@ +open! 
Import + +include Comparable_intf + +module Validate + (T : sig type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) : Validate with type t := T.t = +struct + + module V = Validate + open Maybe_bound + + let to_string t = Sexp.to_string (T.sexp_of_t t) + + let validate_bound ~min ~max t = + V.bounded ~name:to_string ~lower:min ~upper:max ~compare:T.compare t + ;; + + let validate_lbound ~min t = validate_bound ~min ~max:Unbounded t + let validate_ubound ~max t = validate_bound ~max ~min:Unbounded t +end + +module With_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + include Validate with type t := t + end) = struct + open T + + (* Preallocate the interesting bounds to minimize allocation in the implementations of + [validate_*]. 
*) + let excl_zero = Maybe_bound.Excl zero + let incl_zero = Maybe_bound.Incl zero + + let validate_positive t = validate_lbound ~min:excl_zero t + let validate_non_negative t = validate_lbound ~min:incl_zero t + let validate_negative t = validate_ubound ~max:excl_zero t + let validate_non_positive t = validate_ubound ~max:incl_zero t + let is_positive t = compare t zero > 0 + let is_non_negative t = compare t zero >= 0 + let is_negative t = compare t zero < 0 + let is_non_positive t = compare t zero <= 0 + let sign t = Sign0.of_int (compare t zero) +end + +module Validate_with_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + end) = struct + module V = Validate (T) + include V + include With_zero (struct include T include V end) +end + +module Poly (T : sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + module Replace_polymorphic_compare = struct + type t = T.t [@@deriving_inline sexp_of] + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = T.sexp_of_t + [@@@end] + include Poly + end + include Poly + + let between t ~low ~high = low <= t && t <= high + + let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + + let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + + let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + + module C = struct + include T + include Comparator.Make (Replace_polymorphic_compare) + end + include C + include Validate (struct type nonrec t = t [@@deriving_inline compare, sexp_of] + let compare : t -> t -> int = 
compare + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_t + [@@@end] end) +end + +module Make_using_comparator (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + end) : S with type t := T.t and type comparator_witness = T.comparator_witness = struct + module T = struct + include T + let compare = comparator.compare + end + include T + + module Replace_polymorphic_compare = struct + module Without_squelch = struct + let compare = T.compare + let (>) a b = compare a b > 0 + let (<) a b = compare a b < 0 + let (>=) a b = compare a b >= 0 + let (<=) a b = compare a b <= 0 + let (=) a b = compare a b = 0 + let (<>) a b = compare a b <> 0 + let equal = (=) + let min t t' = if t <= t' then t else t' + let max t t' = if t >= t' then t else t' + end + include Without_squelch + end + include Replace_polymorphic_compare.Without_squelch + let ascending = compare + let descending t t' = compare t' t + let between t ~low ~high = low <= t && t <= high + + let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + + let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + + let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + + include Validate (T) +end + +module Make (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) = Make_using_comparator(struct + include T + include Comparator.Make (T) + end) + +module Inherit + (C : sig type t [@@deriving_inline compare] + include sig [@@@ocaml.warning "-32"] val compare : t -> t -> int 
end[@@ocaml.doc + "@inline"] + [@@@end] end) + (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val component : t -> C.t + end) = + Make (struct + type t = T.t [@@deriving_inline sexp_of] + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = T.sexp_of_t + [@@@end] + let compare t t' = C.compare (T.component t) (T.component t') + end) + +(* compare [x] and [y] lexicographically using functions in the list [cmps] *) +let lexicographic cmps x y = + let rec loop = function + | cmp :: cmps -> let res = cmp x y in if res = 0 then loop cmps else res + | [] -> 0 + in + loop cmps +;; + +let lift cmp ~f x y = cmp (f x) (f y) diff --git a/src/comparable.mli b/src/comparable.mli new file mode 100644 index 0000000..da677d9 --- /dev/null +++ b/src/comparable.mli @@ -0,0 +1 @@ +include Comparable_intf.Comparable (** @inline *) diff --git a/src/comparable_intf.ml b/src/comparable_intf.ml new file mode 100644 index 0000000..b09359b --- /dev/null +++ b/src/comparable_intf.ml @@ -0,0 +1,217 @@ +open! Import + +module type Infix = Comparisons.Infix +module type Polymorphic_compare = Comparisons.S + +module type Validate = sig + type t + + val validate_lbound : min : t Maybe_bound.t -> t Validate.check + val validate_ubound : max : t Maybe_bound.t -> t Validate.check + val validate_bound + : min : t Maybe_bound.t + -> max : t Maybe_bound.t + -> t Validate.check +end + +module type With_zero = sig + type t + + val validate_positive : t Validate.check + val validate_non_negative : t Validate.check + val validate_negative : t Validate.check + val validate_non_positive : t Validate.check + val is_positive : t -> bool + val is_non_negative : t -> bool + val is_negative : t -> bool + val is_non_positive : t -> bool + + (** Returns [Neg], [Zero], or [Pos] in a way consistent with the above functions. 
*) + val sign : t -> Sign0.t +end + +module type S = sig + include Polymorphic_compare + + (** [ascending] is identical to [compare]. [descending x y = ascending y x]. These are + intended to be mnemonic when used like [List.sort ~compare:ascending] and [List.sort + ~compare:descending], since they cause the list to be sorted in ascending or descending + order, respectively. *) + val ascending : t -> t -> int + val descending : t -> t -> int + + (** [between t ~low ~high] means [low <= t <= high] *) + val between : t -> low:t -> high:t -> bool + + (** [clamp_exn t ~min ~max] returns [t'], the closest value to [t] such that + [between t' ~low:min ~high:max] is true. + + Raises if [not (min <= max)]. *) + val clamp_exn : t -> min:t -> max:t -> t + val clamp : t -> min:t -> max:t -> t Or_error.t + + include Comparator.S with type t := t + + include Validate with type t := t +end + +(** Usage example: + + {[ + module Foo : sig + type t = ... + include Comparable.S with type t := t + end + ]} + + Then use [Comparable.Make] in the struct (see comparable.mli for an example). *) + +module type Comparable = sig + (** Defines functors for making modules comparable. *) + + (** Usage example: + + {[ + module Foo = struct + module T = struct + type t = ... [@@deriving_inline compare, sexp][@@@end] + end + include T + include Comparable.Make (T) + end + ]} + + Then include [Comparable.S] in the signature + + {[ + module Foo : sig + type t = ... + include Comparable.S with type t := t + end + ]} + + To add an [Infix] submodule: + + {[ + module C = Comparable.Make (T) + include C + module Infix = (C : Comparable.Infix with type t := t) + ]} + + A common pattern is to define a module [O] with a restricted signature. It aims to be + (locally) opened to bring useful operators into scope without shadowing unexpected + variable names. E.g., in the [Date] module: + + {[ + module O = struct + include (C : Comparable.Infix with type t := t) + let to_string t = .. 
+ end + ]} + + Opening [Date] would shadow [now], but opening [Date.O] doesn't: + + {[ + let now = .. in + let someday = .. in + Date.O.(now > someday) + ]} *) + + + module type Infix = Infix + module type S = S + module type Polymorphic_compare = Polymorphic_compare + module type Validate = Validate + module type With_zero = With_zero + + (** [lexicographic cmps x y] compares [x] and [y] lexicographically using functions in the + list [cmps]. *) + val lexicographic : ('a -> 'a -> int) list -> 'a -> 'a -> int + + (** [lift cmp ~f x y] compares [x] and [y] by comparing [f x] and [f y] via [cmp]. *) + val lift : ('a -> 'a -> 'int_or_bool) -> f:('b -> 'a) -> ('b -> 'b -> 'int_or_bool) + + (** Inherit comparability from a component. *) + module Inherit + (C : sig type t [@@deriving_inline compare] + include sig [@@@ocaml.warning "-32"] val compare : t -> t -> int end[@@ocaml.doc + "@inline"] + [@@@end] end) + (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val component : t -> C.t + end) : S with type t := T.t + + module Make (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : S with type t := T.t + + module Make_using_comparator (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + end) : S + with type t := T.t + with type comparator_witness := T.comparator_witness + + module Poly (T : sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) : S with type t := T.t + + module Validate (T : sig type t 
[@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) + : Validate with type t := T.t + + module With_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + include Validate with type t := t + end) : With_zero with type t := T.t + + module Validate_with_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + end) + : sig + include Validate with type t := T.t + include With_zero with type t := T.t + end +end diff --git a/src/comparator.ml b/src/comparator.ml new file mode 100644 index 0000000..f843bb4 --- /dev/null +++ b/src/comparator.ml @@ -0,0 +1,127 @@ +open! 
Import + +type ('a, 'witness) t = + { compare : 'a -> 'a -> int + ; sexp_of_t : 'a -> Sexp.t + } + +type ('a, 'b) comparator = ('a, 'b) t + +module type S = sig + type t + type comparator_witness + val comparator : (t, comparator_witness) comparator +end + +module type S1 = sig + type 'a t + type comparator_witness + val comparator : ('a t, comparator_witness) comparator +end + +module type S_fc = sig + type comparable_t + include S with type t := comparable_t +end + +let make (type t) ~compare ~sexp_of_t = + (module struct + type comparable_t = t + type comparator_witness + let comparator = { compare; sexp_of_t } + end : S_fc with type comparable_t = t) + +module S_to_S1 (S : S) = struct + type 'a t = S.t + type comparator_witness = S.comparator_witness + open S + let comparator = comparator +end + +module Make (M : sig type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + include M + type comparator_witness + let comparator = M.({ compare; sexp_of_t }) +end + +module Make1 (M : sig + type 'a t + val compare : 'a t -> 'a t -> int + val sexp_of_t : 'a t -> Sexp.t + end) = struct + type comparator_witness + let comparator = M.({ compare; sexp_of_t }) +end + +module Poly = struct + type 'a t = 'a + include Make1 (struct + type 'a t = 'a + let compare = Poly.compare + let sexp_of_t _ = Sexp.Atom "_" + end) +end + +module type Derived = sig + type 'a t + type 'cmp comparator_witness + + val comparator + : ('a, 'cmp) comparator + -> ('a t, 'cmp comparator_witness) comparator +end + +module Derived (M : sig type 'a t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + type 'cmp 
comparator_witness + + let comparator a = { + compare = M.compare a.compare; + sexp_of_t = M.sexp_of_t a.sexp_of_t; + } +end + +module type Derived2 = sig + type ('a, 'b) t + type ('cmp_a, 'cmp_b) comparator_witness + + val comparator + : ('a, 'cmp_a) comparator + -> ('b, 'cmp_b) comparator + -> (('a, 'b) t, ('cmp_a, 'cmp_b) comparator_witness) comparator +end + +module Derived2 (M : sig type ('a, 'b) t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + type ('cmp_a, 'cmp_b) comparator_witness + + let comparator a b = { + compare = M.compare a.compare b.compare; + sexp_of_t = M.sexp_of_t a.sexp_of_t b.sexp_of_t; + } +end diff --git a/src/comparator.mli b/src/comparator.mli new file mode 100644 index 0000000..c069f62 --- /dev/null +++ b/src/comparator.mli @@ -0,0 +1,121 @@ +(** A type-indexed value that allows one to compare (and for generating error messages, + serialize) values of the type in question. + + One of the type parameters is a phantom parameter used to distinguish comparators + potentially built on different comparison functions. In particular, we want to + distinguish those using polymorphic compare from those using a monomorphic compare. *) + +open! 
Import + +type ('a, 'witness) t = + private + { compare : 'a -> 'a -> int + ; sexp_of_t : 'a -> Sexp.t + } + +type ('a, 'b) comparator = ('a, 'b) t + +module type S = sig + type t + type comparator_witness + val comparator : (t, comparator_witness) comparator +end + +module type S1 = sig + type 'a t + type comparator_witness + val comparator : ('a t, comparator_witness) comparator +end + +module type S_fc = sig + type comparable_t + include S with type t := comparable_t +end + +(** [make] creates a comparator witness for the given comparison. It is intended as a + lightweight alternative to the functors below, to be used like so: + [include (val Comparator.make ~compare ~sexp_of_t)] *) +val make + : compare:('a -> 'a -> int) + -> sexp_of_t:('a -> Sexp.t) + -> (module S_fc with type comparable_t = 'a) + +module Poly : S1 with type 'a t = 'a + +module S_to_S1 (S : S) : S1 + with type 'a t = S.t + with type comparator_witness = S.comparator_witness + +(** [Make] creates a [comparator] value and its phantom [comparator_witness] type for a + nullary type. *) +module Make (M : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : S with type t := M.t + +(** [Make1] creates a [comparator] value and its phantom [comparator_witness] type for a + unary type. It takes a [compare] and [sexp_of_t] that have + non-standard types because the [Comparator.t] type doesn't allow passing in + additional values for the type argument. 
*) +module Make1 (M : sig + type 'a t + val compare : 'a t -> 'a t -> int + val sexp_of_t : _ t -> Sexp.t + end) : S1 with type 'a t := 'a M.t + +module type Derived = sig + type 'a t + type 'cmp comparator_witness + + val comparator + : ('a, 'cmp) comparator + -> ('a t, 'cmp comparator_witness) comparator +end + +(** [Derived] creates a [comparator] function that constructs a comparator for the type + ['a t] given a comparator for the type ['a]. *) +module Derived (M : sig + type 'a t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : Derived with type 'a t := 'a M.t + +module type Derived2 = sig + type ('a, 'b) t + type ('cmp_a, 'cmp_b) comparator_witness + + val comparator + : ('a, 'cmp_a) comparator + -> ('b, 'cmp_b) comparator + -> (('a, 'b) t, ('cmp_a, 'cmp_b) comparator_witness) comparator +end + +(** [Derived2] creates a [comparator] function that constructs a comparator for the type + [('a, 'b) t] given comparators for the type ['a] and ['b]. *) +module Derived2 (M : sig + type ('a, 'b) t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : Derived2 with type ('a, 'b) t := ('a, 'b) M.t diff --git a/src/comparisons.ml b/src/comparisons.ml new file mode 100644 index 0000000..1cbf91c --- /dev/null +++ b/src/comparisons.ml @@ -0,0 +1,29 @@ +(** Interfaces for infix comparison operators and comparison functions. *) + +open! Import + +(** [Infix] lists the typical infix comparison operators. 
These functions are provided by + [.O] modules, i.e., modules that expose monomorphic infix comparisons over some + [.t]. *) +module type Infix = sig + type t + val ( >= ) : t -> t -> bool + val ( <= ) : t -> t -> bool + val ( = ) : t -> t -> bool + val ( > ) : t -> t -> bool + val ( < ) : t -> t -> bool + val ( <> ) : t -> t -> bool +end + +module type S = sig + include Infix + + val equal : t -> t -> bool + + (** [compare t1 t2] returns 0 if [t1] is equal to [t2], a negative integer if [t1] is + less than [t2], and a positive integer if [t1] is greater than [t2]. *) + val compare : t -> t -> int + + val min : t -> t -> t + val max : t -> t -> t +end diff --git a/src/container.ml b/src/container.ml new file mode 100644 index 0000000..7a253cd --- /dev/null +++ b/src/container.ml @@ -0,0 +1,175 @@ +open! Import + +module Array = Array0 +module List = List0 + +include Container_intf + +let with_return = With_return.with_return + +type ('t, 'a, 'accum) fold = 't -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum +type ('t, 'a) iter = 't -> f:('a -> unit) -> unit +type 't length = 't -> int + +let iter ~fold t ~f = fold t ~init:() ~f:(fun () a -> f a) + +let count ~fold t ~f = fold t ~init:0 ~f:(fun n a -> if f a then n + 1 else n) + +let sum (type a) ~fold (module M : Summable with type t = a) t ~f = + fold t ~init:M.zero ~f:(fun n a -> M.(+) n (f a)) +;; + +let fold_result ~fold ~init ~f t = + with_return (fun {return} -> + Result.Ok (fold t ~init ~f:(fun acc item -> + match f acc item with + | Result.Ok x -> x + | Error _ as e -> return e))) +;; + +let fold_until ~fold ~init ~f ~finish t = + with_return (fun {return} -> + finish (fold t ~init ~f:(fun acc item -> + match f acc item with + | Continue_or_stop.Continue x -> x + | Stop x -> return x))) +;; + +let min_elt ~fold t ~compare = + fold t ~init:None ~f:(fun acc elt -> + match acc with + | None -> Some elt + | Some min -> if compare min elt > 0 then Some elt else acc) +;; + +let max_elt ~fold t ~compare = + 
fold t ~init:None ~f:(fun acc elt -> + match acc with + | None -> Some elt + | Some max -> if compare max elt < 0 then Some elt else acc) +;; + +let length ~fold c = fold c ~init:0 ~f:(fun acc _ -> acc + 1) + +let is_empty ~iter c = + with_return (fun r -> + iter c ~f:(fun _ -> r.return false); + true) +;; + +let exists ~iter c ~f = + with_return (fun r -> + iter c ~f:(fun x -> if f x then r.return true); + false) +;; + +let for_all ~iter c ~f = + with_return (fun r -> + iter c ~f:(fun x -> if not (f x) then r.return false); + true) +;; + +let find_map ~iter t ~f = + with_return (fun r -> + iter t ~f:(fun x -> match f x with None -> () | Some _ as res -> r.return res); + None) +;; + +let find ~iter c ~f = + with_return (fun r -> + iter c ~f:(fun x -> if f x then r.return (Some x)); + None) +;; + +let to_list ~fold c = List.rev (fold c ~init:[] ~f:(fun acc x -> x :: acc)) + +let to_array ~length ~iter c = + let array = ref [||] in + let i = ref 0 in + iter c ~f:(fun x -> + if !i = 0 then (array := Array.create ~len:(length c) x); + !array.(!i) <- x; + incr i); + !array +;; + +module Make_gen (T : Make_gen_arg) : sig + include Generic with type 'a t := 'a T.t with type 'a elt := 'a T.elt +end = struct + let fold = T.fold + + let iter = + match T.iter with + | `Custom iter -> iter + | `Define_using_fold -> fun t ~f -> iter ~fold t ~f + ;; + + let length = + match T.length with + | `Custom length -> length + | `Define_using_fold -> fun t -> length ~fold t + ;; + + let is_empty t = is_empty ~iter t + let sum m t = sum ~fold m t + let count t ~f = count ~fold t ~f + let exists t ~f = exists ~iter t ~f + let for_all t ~f = for_all ~iter t ~f + let find_map t ~f = find_map ~iter t ~f + let find t ~f = find ~iter t ~f + let to_list t = to_list ~fold t + let to_array t = to_array ~length ~iter t + + let min_elt t ~compare = min_elt ~fold t ~compare + let max_elt t ~compare = max_elt ~fold t ~compare + + let fold_result t ~init ~f = fold_result t ~fold ~init ~f + let 
fold_until t ~init ~f ~finish = fold_until t ~fold ~init ~f ~finish +end + +module Make (T : Make_arg) = struct + include + Make_gen (struct + include T + type 'a elt = 'a + end) + + let mem t a ~equal = exists t ~f:(equal a) +end + +module Make0 (T : Make0_arg) = struct + include + Make_gen (struct + include (T : Make0_arg with type t := T.t with module Elt := T.Elt) + type 'a t = T.t + type 'a elt = T.Elt.t + end) + + let mem t elt = exists t ~f:(T.Elt.equal elt) +end + +open T + + +(* The following functors exist as a consistency check among all the various [S?] + interfaces. They ensure that each particular [S?] is an instance of a more generic + signature. *) +module Check (T : T1) (Elt : T1) + (M : Generic with type 'a t := 'a T.t with type 'a elt := 'a Elt.t) = struct end + +module Check_S0 (M : S0) = + Check (struct type 'a t = M.t end) (struct type 'a t = M.elt end) (M) + +module Check_S0_phantom (M : S0_phantom) = + Check (struct type 'a t = 'a M.t end) (struct type 'a t = M.elt end) (M) + +module Check_S1 (M : S1) = + Check (struct type 'a t = 'a M.t end) (struct type 'a t = 'a end) (M) + +type phantom + +module Check_S1_phantom (M : S1_phantom) = + Check (struct type 'a t = ('a, phantom) M.t end) (struct type 'a t = 'a end) (M) + +module Check_S1_phantom_invariant (M : S1_phantom_invariant) = + Check (struct type 'a t = ('a, phantom) M.t end) (struct type 'a t = 'a end) (M) diff --git a/src/container.mli b/src/container.mli new file mode 100644 index 0000000..37313b1 --- /dev/null +++ b/src/container.mli @@ -0,0 +1 @@ +include Container_intf.Container (** @inline *) diff --git a/src/container_intf.ml b/src/container_intf.ml new file mode 100644 index 0000000..9c41fe4 --- /dev/null +++ b/src/container_intf.ml @@ -0,0 +1,624 @@ +(** Provides generic signatures for container data structures. + + These signatures include functions ([iter], [fold], [exists], [for_all], ...) that + you would expect to find in any container. 
Used by including [Container.S0] or + [Container.S1] in the signature for every container-like data structure ([Array], + [List], [String], ...) to ensure a consistent interface. *) + +open! Import + +module Export = struct + (** [Continue_or_stop.t] is used by the [f] argument to [fold_until] in order to + indicate whether folding should continue, or stop early. *) + module Continue_or_stop = struct + type ('a, 'b) t = + | Continue of 'a + | Stop of 'b + end +end +include Export + +module type Summable = sig + type t + + (** The result of summing no values. *) + val zero : t + + (** An operation that combines two [t]'s and handles [zero + x] by just returning [x], + as well as in the symmetric case. *) + val (+) : t -> t -> t +end + +(** Signature for monomorphic container, e.g., string. *) +module type S0 = sig + type t + type elt + + (** Checks whether the provided element is there, using equality on [elt]s. *) + val mem : t -> elt -> bool + + val length : t -> int + + val is_empty : t -> bool + + (** [iter] must allow exceptions raised in [f] to escape, terminating the iteration + cleanly. The same holds for all functions below taking an [f]. *) + val iter : t -> f:(elt -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t]. *) + val fold : t -> init:'accum -> f:('accum -> elt -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. 
If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. *) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : t -> f:(elt -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : t -> f:(elt -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : t -> f:(elt -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. *) + val sum + : (module Summable with type t = 'sum) + -> t -> f:(elt -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : t -> f:(elt -> bool) -> elt option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : t -> f:(elt -> 'a option) -> 'a option + + val to_list : t -> elt list + val to_array : t -> elt array + + (** Returns a min (resp. max) element from the collection using the provided [compare] + function. 
In case of a tie, the first element encountered while traversing the + collection is returned. The implementation uses [fold] so it has the same + complexity as [fold]. Returns [None] iff the collection is empty. *) + val min_elt : t -> compare:(elt -> elt -> int) -> elt option + val max_elt : t -> compare:(elt -> elt -> int) -> elt option +end + +module type S0_phantom = sig + type elt + type 'a t + + (** Checks whether the provided element is there, using equality on [elt]s. *) + val mem : _ t -> elt -> bool + + val length : _ t -> int + + val is_empty : _ t -> bool + + val iter : _ t -> f:(elt -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t]. *) + val fold : _ t -> init:'accum -> f:('accum -> elt -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : _ t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. 
*) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : _ t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : _ t -> f:(elt -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : _ t -> f:(elt -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : _ t -> f:(elt -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. The order in which the + elements will be summed is unspecified. *) + val sum + : (module Summable with type t = 'sum) + -> _ t -> f:(elt -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : _ t -> f:(elt -> bool) -> elt option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : _ t -> f:(elt -> 'a option) -> 'a option + + val to_list : _ t -> elt list + val to_array : _ t -> elt array + + (** Returns a min (resp max) element from the collection using the provided [compare] + function, or [None] if the collection is empty. In case of a tie, the first element + encountered while traversing the collection is returned. 
*) + val min_elt : _ t -> compare:(elt -> elt -> int) -> elt option + val max_elt : _ t -> compare:(elt -> elt -> int) -> elt option +end + +(** Signature for polymorphic container, e.g., ['a list] or ['a array]. *) +module type S1 = sig + type 'a t + + (** Checks whether the provided element is there, using [equal]. *) + val mem : 'a t -> 'a -> equal:('a -> 'a -> bool) -> bool + + val length : 'a t -> int + + val is_empty : 'a t -> bool + + val iter : 'a t -> f:('a -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t] *) + val fold : 'a t -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : 'a t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. 
*) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : 'a t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : 'a t -> f:('a -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : 'a t -> f:('a -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : 'a t -> f:('a -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. *) + val sum + : (module Summable with type t = 'sum) + -> 'a t -> f:('a -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : 'a t -> f:('a -> bool) -> 'a option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : 'a t -> f:('a -> 'b option) -> 'b option + + val to_list : 'a t -> 'a list + val to_array : 'a t -> 'a array + + (** Returns a minimum (resp maximum) element from the collection using the provided + [compare] function, or [None] if the collection is empty. In case of a tie, the first + element encountered while traversing the collection is returned. The implementation + uses [fold] so it has the same complexity as [fold]. 
*) + val min_elt : 'a t -> compare:('a -> 'a -> int) -> 'a option + val max_elt : 'a t -> compare:('a -> 'a -> int) -> 'a option +end + +module type S1_phantom_invariant = sig + type ('a, 'phantom) t + + (** Checks whether the provided element is there, using [equal]. *) + val mem : ('a, _) t -> 'a -> equal:('a -> 'a -> bool) -> bool + + val length : (_, _) t -> int + val is_empty : (_, _) t -> bool + val iter : ('a, _) t -> f:('a -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t]. *) + val fold : ('a, _) t -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. 
*) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : ('a, _) t -> f:('a -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : ('a, _) t -> f:('a -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : ('a, _) t -> f:('a -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. *) + val sum + : (module Summable with type t = 'sum) + -> ('a, _) t -> f:('a -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : ('a, _) t -> f:('a -> bool) -> 'a option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : ('a, _) t -> f:('a -> 'b option) -> 'b option + + val to_list : ('a, _) t -> 'a list + val to_array : ('a, _) t -> 'a array + + (** Returns a min (resp max) element from the collection using the provided [compare] + function. In case of a tie, the first element encountered while traversing the + collection is returned. The implementation uses [fold] so it has the same complexity + as [fold]. Returns [None] iff the collection is empty. 
*) + val min_elt : ('a, _) t -> compare:('a -> 'a -> int) -> 'a option + val max_elt : ('a, _) t -> compare:('a -> 'a -> int) -> 'a option +end + +module type S1_phantom = sig + type ('a, +'phantom) t + include S1_phantom_invariant with type ('a, 'phantom) t := ('a, 'phantom) t +end + +module type Generic = sig + type 'a t + type 'a elt + val length : _ t -> int + val is_empty : _ t -> bool + val iter : 'a t -> f:('a elt -> unit) -> unit + val fold : 'a t -> init:'accum -> f:('accum -> 'a elt -> 'accum) -> 'accum + val fold_result + : 'a t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + val fold_until + : 'a t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + val exists : 'a t -> f:('a elt -> bool) -> bool + val for_all : 'a t -> f:('a elt -> bool) -> bool + val count : 'a t -> f:('a elt -> bool) -> int + val sum + : (module Summable with type t = 'sum) + -> 'a t -> f:('a elt -> 'sum) -> 'sum + val find : 'a t -> f:('a elt -> bool) -> 'a elt option + val find_map : 'a t -> f:('a elt -> 'b option) -> 'b option + val to_list : 'a t -> 'a elt list + val to_array : 'a t -> 'a elt array + val min_elt : 'a t -> compare:('a elt -> 'a elt -> int) -> 'a elt option + val max_elt : 'a t -> compare:('a elt -> 'a elt -> int) -> 'a elt option +end + +module type Generic_phantom = sig + type ('a, 'phantom) t + type 'a elt + val length : (_, _) t -> int + val is_empty : (_, _) t -> bool + val iter : ('a, _) t -> f:('a elt -> unit) -> unit + val fold : ('a, _) t -> init:'accum -> f:('accum -> 'a elt -> 'accum) -> 'accum + val fold_result + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + val exists : ('a, _) t -> f:('a elt -> bool) -> bool + 
val for_all : ('a, _) t -> f:('a elt -> bool) -> bool + val count : ('a, _) t -> f:('a elt -> bool) -> int + val sum + : (module Summable with type t = 'sum) + -> ('a, _) t -> f:('a elt -> 'sum) -> 'sum + val find : ('a, _) t -> f:('a elt -> bool) -> 'a elt option + val find_map : ('a, _) t -> f:('a elt -> 'b option) -> 'b option + val to_list : ('a, _) t -> 'a elt list + val to_array : ('a, _) t -> 'a elt array + val min_elt : ('a, _) t -> compare:('a elt -> 'a elt -> int) -> 'a elt option + val max_elt : ('a, _) t -> compare:('a elt -> 'a elt -> int) -> 'a elt option +end + +module type Make_gen_arg = sig + type 'a t + type 'a elt + + val fold : 'a t -> init:'accum -> f:('accum -> 'a elt -> 'accum) -> 'accum + + (** The [iter] argument to [Container.Make] specifies how to implement the + container's [iter] function. [`Define_using_fold] means to define [iter] + via: + + {[ + iter t ~f = Container.iter ~fold t ~f + ]} + + [`Custom] overrides the default implementation, presumably with something more + efficient. Several other functions returned by [Container.Make] are defined in + terms of [iter], so passing in a more efficient [iter] will improve their efficiency + as well. *) + val iter : [ `Define_using_fold + | `Custom of 'a t -> f:('a elt -> unit) -> unit + ] + + (** The [length] argument to [Container.Make] specifies how to implement the + container's [length] function. [`Define_using_fold] means to define + [length] via: + + {[ + length t = Container.length ~fold t + ]} + + [`Custom] overrides the default implementation, presumably with something more + efficient. Several other functions returned by [Container.Make] are defined in + terms of [length], so passing in a more efficient [length] will improve their + efficiency as well.
*) + val length : [ `Define_using_fold + | `Custom of 'a t -> int + ] +end + +module type Make_arg = Make_gen_arg with type 'a elt := 'a Monad.Ident.t + +module type Make0_arg = sig + module Elt : sig + type t + val equal : t -> t -> bool + end + + type t + + val fold : t -> init:'accum -> f:('accum -> Elt.t -> 'accum) -> 'accum + val iter : [ `Define_using_fold + | `Custom of t -> f:(Elt.t -> unit) -> unit + ] + val length : [ `Define_using_fold + | `Custom of t -> int + ] +end + +module type Container = sig + include module type of struct include Export end + + module type S0 = S0 + module type S0_phantom = S0_phantom + module type S1 = S1 + module type S1_phantom_invariant = S1_phantom_invariant + module type S1_phantom = S1_phantom + module type Generic = Generic + module type Generic_phantom = Generic_phantom + + module type Summable = Summable + + (** Generic definitions of container operations in terms of [fold]. + + E.g.: [iter ~fold t ~f = fold t ~init:() ~f:(fun () a -> f a)]. 
*) + + type ('t, 'a, 'accum) fold = 't -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum + type ('t, 'a) iter = 't -> f:('a -> unit) -> unit + type 't length = 't -> int + + val iter : fold:('t, 'a, unit ) fold -> ('t, 'a) iter + val count : fold:('t, 'a, int ) fold -> 't -> f:('a -> bool) -> int + val min_elt : fold:('t, 'a, 'a option) fold -> 't -> compare:('a -> 'a -> int) -> 'a option + val max_elt : fold:('t, 'a, 'a option) fold -> 't -> compare:('a -> 'a -> int) -> 'a option + val length : fold:('t, _, int ) fold -> 't -> int + val to_list : fold:('t, 'a, 'a list ) fold -> 't -> 'a list + val sum + : fold : ('t, 'a, 'sum) fold + -> (module Summable with type t = 'sum) + -> 't -> f:('a -> 'sum) -> 'sum + + val fold_result + : fold:('t, 'a, 'b) fold + -> init:'b + -> f:('b -> 'a -> ('b, 'e) Result.t) + -> 't + -> ('b, 'e) Result.t + + val fold_until + : fold:('t, 'a, 'b) fold + -> init:'b + -> f:('b -> 'a -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 't + -> 'final + + (** Generic definitions of container operations in terms of [iter] and [length]. *) + val is_empty : iter:('t, 'a) iter -> 't -> bool + val exists : iter:('t, 'a) iter -> 't -> f:('a -> bool) -> bool + val for_all : iter:('t, 'a) iter -> 't -> f:('a -> bool) -> bool + val find : iter:('t, 'a) iter -> 't -> f:('a -> bool) -> 'a option + val find_map : iter:('t, 'a) iter -> 't -> f:('a -> 'b option) -> 'b option + val to_array : length:'t length -> iter:('t, 'a) iter -> 't -> 'a array + + (** The idiom for using [Container.Make] is to bind the resulting module and to + explicitly import each of the functions that one wants: + + {[ + module C = Container.Make (struct ... end) + let count = C.count + let exists = C.exists + let find = C.find + (* ... *) + ]} + + This is preferable to: + + {[ + include Container.Make (struct ... end) + ]} + + because the [include] makes it too easy to shadow specialized implementations of + container functions ([length] being a common one). 
+ + [Container.Make0] is like [Container.Make], but for monomorphic containers like + [string]. *) + module Make (T : Make_arg) : S1 with type 'a t := 'a T.t + module Make0 (T : Make0_arg) : S0 with type t := T.t and type elt := T.Elt.t +end diff --git a/src/discover/discover.ml b/src/discover/discover.ml new file mode 100644 index 0000000..f2e0b92 --- /dev/null +++ b/src/discover/discover.ml @@ -0,0 +1,20 @@ +open Configurator.V1 + +let program = + {| +int main(int argc, char ** argv) +{ + return __builtin_popcount(argc); +} +|} +;; + +let () = + let output = ref "" in + main + ~name:"discover" + ~args:[ "-o", Set_string output, "FILENAME output file" ] + (fun c -> + let has_popcnt = c_test c ~c_flags:[ "-mpopcnt" ] program in + Flags.write_sexp !output (if has_popcnt then [ "-mpopcnt" ] else [])) +;; diff --git a/src/discover/discover.mli b/src/discover/discover.mli new file mode 100644 index 0000000..e790aeb --- /dev/null +++ b/src/discover/discover.mli @@ -0,0 +1 @@ +(* empty *) diff --git a/src/discover/dune b/src/discover/dune new file mode 100644 index 0000000..9e74a1e --- /dev/null +++ b/src/discover/dune @@ -0,0 +1,2 @@ +(executables (names discover) (libraries dune.configurator) + (preprocess no_preprocessing)) \ No newline at end of file diff --git a/src/dune b/src/dune new file mode 100644 index 0000000..422e43d --- /dev/null +++ b/src/dune @@ -0,0 +1,30 @@ +(rule (targets int63_backend.ml) + (deps (:first_dep select-int63-backend/select.ml)) + (action + (run %{ocaml} %{first_dep} -portable-int63 + !%{lib-available:base-native-int63} -arch-sixtyfour %{arch_sixtyfour} -o + %{targets}))) + +(rule (targets bytes_set_primitives.ml) + (deps (:first_dep select-bytes-set-primitives/select.ml)) + (action + (run %{ocaml} %{first_dep} -ocaml-version %{ocaml_version} -o %{targets}))) + +(rule (targets pow_overflow_bounds.ml) + (deps (:first_dep ../generate/generate_pow_overflow_bounds.exe)) + (action (run %{first_dep} -atomic -o %{targets})) (mode fallback)) + 
+(library (name base) (public_name base) + (libraries caml sexplib0 shadow_stdlib) (install_c_headers internalhash) + (c_flags :standard -D_LARGEFILE64_SOURCE (:include mpopcnt.sexp)) + (c_names exn_stubs int_math_stubs internalhash_stubs hash_stubs am_testing) + (preprocess no_preprocessing) + (lint + (pps ppx_base ppx_base_lint -check-doc-comments -type-conv-keep-w32=impl + -apply=js_style,base_lint,type_conv)) + (js_of_ocaml (javascript_files runtime.js))) + +(rule (targets mpopcnt.sexp) (deps discover/discover.exe) + (action (run ./discover/discover.exe -o %{targets}))) + +(ocamllex hex_lexer) diff --git a/src/either.ml b/src/either.ml new file mode 100644 index 0000000..3a0165a --- /dev/null +++ b/src/either.ml @@ -0,0 +1,289 @@ +open! Import + +include Either_intf + +module Array = Array0 + +type ('f, 's) t = + | First of 'f + | Second of 's +[@@deriving_inline compare, hash, sexp] +let compare : + 'f 's . + ('f -> 'f -> int) -> ('s -> 's -> int) -> ('f, 's) t -> ('f, 's) t -> int + = + fun _cmp__f -> + fun _cmp__s -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (First _a__003_, First _b__004_) -> _cmp__f _a__003_ _b__004_ + | (First _, _) -> (-1) + | (_, First _) -> 1 + | (Second _a__005_, Second _b__006_) -> + _cmp__s _a__005_ _b__006_) +let hash_fold_t : type f s. + (Ppx_hash_lib.Std.Hash.state -> f -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> s -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> (f, s) t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_f -> + fun _hash_fold_s -> + fun hsv -> + fun arg -> + match arg with + | First _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in _hash_fold_f hsv _a0 + | Second _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in _hash_fold_s hsv _a0 +let t_of_sexp : type f s. 
+ (Ppx_sexp_conv_lib.Sexp.t -> f) -> + (Ppx_sexp_conv_lib.Sexp.t -> s) -> Ppx_sexp_conv_lib.Sexp.t -> (f, s) t + = + let _tp_loc = "src/either.ml.t" in + fun _of_f -> + fun _of_s -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("first"|"First" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_f v0 in First v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("second"|"Second" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_s v0 in Second v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("first"|"First") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("second"|"Second") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : type f s. 
+ (f -> Ppx_sexp_conv_lib.Sexp.t) -> + (s -> Ppx_sexp_conv_lib.Sexp.t) -> (f, s) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_f -> + fun _of_s -> + function + | First v0 -> + let v0 = _of_f v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "First"; v0] + | Second v0 -> + let v0 = _of_s v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Second"; v0] +[@@@end] + +let swap = function + | First x -> Second x + | Second x -> First x +;; + +let is_first = function + | First _ -> true + | Second _ -> false +;; + +let is_second = function + | First _ -> false + | Second _ -> true +;; + +let value (First x | Second x) = x +;; + +let value_map t ~first ~second = + match t with + | First x -> first x + | Second x -> second x +;; + +let iter = value_map +;; + +let map t ~first ~second = + match t with + | First x -> First (first x) + | Second x -> Second (second x) +;; + +let first x = First x +let second x = Second x +;; + +let equal eq1 eq2 t1 t2 = + match t1, t2 with + | First x, First y -> eq1 x y + | Second x, Second y -> eq2 x y + | First _, Second _ + | Second _, First _ -> false +;; + +let invariant f s = function + | First x -> f x + | Second y -> s y +;; + +module Make_focused (M : sig + type (+'a, +'b) t + + val return : 'a -> ('a, _) t + val other : 'b -> (_, 'b) t + + val either : ('a, 'b) t -> return:('a -> 'c) -> other:('b -> 'c) -> 'c + + val combine + : ('a, 'd) t + -> ('b, 'd) t + -> f:('a -> 'b -> 'c) + -> other:('d -> 'd -> 'd) + -> ('c, 'd) t + end) = struct + include M + open With_return + + let map t ~f = either t ~return:(fun x -> return (f x)) ~other + + include Monad.Make2 (struct + type nonrec ('a, 'b) t = ('a, 'b) t + + let return = return + ;; + + let bind t ~f = either t ~return:f ~other + ;; + + let map = `Custom map + end) + + module App = Applicative.Make2 (struct + type nonrec ('a, 'b) t = ('a, 'b) t + + let return = return + ;; + + let apply t1 t2 = + let return f = either t2 ~return:(fun x -> return (f x)) 
~other in + either t1 ~return ~other + ;; + + let map = `Custom map + end) + + include App + + module Args = Applicative.Make_args2 (struct + type nonrec ('a, 'b) t = ('a, 'b) t + include App + end) + [@@warning "-3"] + + let combine_all = + let rec other_loop f acc = function + | [] -> other acc + | t :: ts -> + either t ~return:(fun _ -> other_loop f acc ts) + ~other:(fun o -> other_loop f (f acc o) ts) + in + let rec return_loop f acc = function + | [] -> return (List.rev acc) + | t :: ts -> + either t ~return:(fun x -> return_loop f (x :: acc) ts) + ~other:(fun o -> other_loop f o ts) + in + fun ts ~f -> return_loop f [] ts + ;; + + let combine_all_unit = + let rec other_loop f acc = function + | [] -> other acc + | t :: ts -> + either t ~return:(fun () -> other_loop f acc ts) + ~other:(fun o -> other_loop f (f acc o) ts) + in + let rec return_loop f = function + | [] -> return () + | t :: ts -> + either t ~return:(fun () -> return_loop f ts) + ~other:(fun o -> other_loop f o ts) + in + fun ts ~f -> return_loop f ts + ;; + + let to_option t = either t ~return:Option.some ~other:(fun _ -> None) + ;; + + let value t ~default = either t ~return:Fn.id ~other:(fun _ -> default) + ;; + + let with_return f = + with_return (fun ret -> + other (f (With_return.prepend ret ~f:return))) + ;; + +end + +module First = Make_focused (struct + type nonrec ('a, 'b) t = ('a, 'b) t + + let return = first + let other = second + ;; + + let either t ~return ~other = + match t with + | First x -> return x + | Second y -> other y + ;; + + let combine t1 t2 ~f ~other = + match t1, t2 with + | First x, First y -> First (f x y) + | Second x, Second y -> Second (other x y) + | Second x, _ + | _, Second x -> Second x + end) + +module Second = Make_focused (struct + type nonrec ('a, 'b) t = ('b, 'a) t + + let return = second + let other = first + ;; + + let either t ~return ~other = + match t with + | Second y -> return y + | First x -> other x + ;; + + let combine t1 t2 ~f ~other = + match 
t1, t2 with + | Second x, Second y -> Second (f x y) + | First x, First y -> First (other x y) + | First x, _ + | _, First x -> First x + end) + +module Export = struct + type ('f, 's) _either + = ('f, 's) t + = First of 'f + | Second of 's +end diff --git a/src/either.mli b/src/either.mli new file mode 100644 index 0000000..e3587aa --- /dev/null +++ b/src/either.mli @@ -0,0 +1 @@ +include Either_intf.Either (** @inline *) diff --git a/src/either_intf.ml b/src/either_intf.ml new file mode 100644 index 0000000..05a9d80 --- /dev/null +++ b/src/either_intf.ml @@ -0,0 +1,93 @@ +(** A type that represents values with two possibilities. + + [Either] can be seen as a generic sum type, the dual of [Tuple]. [First] is neither + more important nor less important than [Second]. + + Many functions in [Either] focus on just one constructor. The [Focused] signature + abstracts over which constructor is the focus. To use these functions, use the + [First] or [Second] modules in [Either]. *) + +open! Import + +module type Focused = sig + type (+'focus, +'other) t + + include Monad.S2 with type ('a, 'b) t := ('a, 'b) t + include Applicative.S2 with type ('a, 'b) t := ('a, 'b) t + + module Args : Applicative.Args2 with type ('a, 'e) arg := ('a, 'e) t + [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + + val value : ('a, _) t -> default:'a -> 'a + + val to_option : ('a, _) t -> 'a option + + val with_return : ('a With_return.return -> 'b) -> ('a, 'b) t + + val combine + : ('a, 'd) t + -> ('b, 'd) t + -> f:('a -> 'b -> 'c) + -> other:('d -> 'd -> 'd) + -> ('c, 'd) t + + val combine_all : ('a, 'b) t list -> f:('b -> 'b -> 'b) -> ('a list, 'b) t + + val combine_all_unit : (unit, 'b) t list -> f:('b -> 'b -> 'b) -> (unit, 'b) t +end + +module type Either = sig + + type ('f, 's) t = + | First of 'f + | Second of 's + [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('f -> 'f -> int) -> + ('s -> 's -> int) -> ('f,
's) t -> ('f, 's) t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'f -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> 's -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> + ('f, 's) t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('f,'s) t := ('f, 's) t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Invariant.S2 with type ('a, 'b) t := ('a, 'b) t + + val swap : ('f, 's) t -> ('s, 'f) t + + val value : ('a, 'a) t -> 'a + + val iter : ('a, 'b) t -> first:('a -> unit) -> second:('b -> unit) -> unit + val value_map : ('a, 'b) t -> first:('a -> 'c) -> second:('b -> 'c) -> 'c + val map : ('a, 'b) t -> first:('a -> 'c) -> second:('b -> 'd) -> ('c, 'd) t + + val equal : ('f -> 'f -> bool) -> ('s -> 's -> bool) -> ('f, 's) t -> ('f, 's) t -> bool + + module type Focused = Focused + + module First : Focused with type ('a, 'b) t = ('a, 'b) t + module Second : Focused with type ('a, 'b) t = ('b, 'a) t + + val is_first : (_, _) t -> bool + val is_second : (_, _) t -> bool + + (** [first] and [second] are [First.return] and [Second.return]. *) + val first : 'f -> ('f, _) t + val second : 's -> (_, 's) t + + (**/**) + + module Export : sig + type ('f, 's) _either + = ('f, 's) t + = First of 'f + | Second of 's + end +end diff --git a/src/equal.ml b/src/equal.ml new file mode 100644 index 0000000..df359e6 --- /dev/null +++ b/src/equal.ml @@ -0,0 +1,41 @@ +(** This module defines signatures that are to be included in other signatures to ensure a + consistent interface to [equal] functions. There is a signature ([S], [S1], [S2], + [S3]) for each arity of type. Usage looks like: + + {[ + type t + include Equal.S with type t := t + ]} + + or + + {[ + type 'a t + include Equal.S1 with type 'a t := 'a t + ]} *) + +open! 
Import + +type 'a t = 'a -> 'a -> bool + +type 'a equal = 'a t + +module type S = sig + type t + val equal : t equal +end + +module type S1 = sig + type 'a t + val equal : 'a equal -> 'a t equal +end + +module type S2 = sig + type ('a, 'b) t + val equal : 'a equal -> 'b equal -> ('a, 'b) t equal +end + +module type S3 = sig + type ('a, 'b, 'c) t + val equal : 'a equal -> 'b equal -> 'c equal -> ('a, 'b, 'c) t equal +end diff --git a/src/error.ml b/src/error.ml new file mode 100644 index 0000000..34b6b1a --- /dev/null +++ b/src/error.ml @@ -0,0 +1,20 @@ +(* This module is trying to minimize dependencies on modules in Core, so as to allow + [Error] and [Or_error] to be used in various places. Please avoid adding new + dependencies. *) + +open! Import + +include Info + +let raise t = raise (to_exn t) + +let raise_s sexp = raise (create_s sexp) + +let to_info t = t +let of_info t = t + +include Pretty_printer.Register_pp(struct + type nonrec t = t + let module_name = "Base.Error" + let pp = pp + end) diff --git a/src/error.mli b/src/error.mli new file mode 100644 index 0000000..49d34da --- /dev/null +++ b/src/error.mli @@ -0,0 +1,15 @@ +(** A lazy string, implemented with [Info], but intended specifically for error + messages. *) + +open! Import + +include Info_intf.S with type t = private Info.t (** @open *) + +(** Note that the exception raised by this function maintains a reference to the [t] + passed in. *) +val raise : t -> _ + +val raise_s : Sexp.t -> _ + +val to_info : t -> Info.t +val of_info : Info.t -> t diff --git a/src/exn.ml b/src/exn.ml new file mode 100644 index 0000000..3f22fb6 --- /dev/null +++ b/src/exn.ml @@ -0,0 +1,151 @@ +open! 
Import + +type t = exn [@@deriving_inline sexp_of] +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_exn +[@@@end] + +let exit = Caml.exit + +exception Finally of t * t [@@deriving_inline sexp] +let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add ([%extension_constructor Finally]) + (function + | Finally (v0, v1) -> + let v0 = sexp_of_t v0 + and v1 = sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "src/exn.ml.Finally"; v0; v1] + | _ -> assert false) +[@@@end] +exception Reraised of string * t [@@deriving_inline sexp] +let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Reraised]) + (function + | Reraised (v0, v1) -> + let v0 = sexp_of_string v0 + and v1 = sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "src/exn.ml.Reraised"; v0; v1] + | _ -> assert false) +[@@@end] + +exception Sexp of Sexp.t + +(* We install a custom exn-converter rather than use: + + {[ + exception Sexp of Sexp.t [@@deriving_inline sexp][@@@end] + ]} + + to eliminate the extra wrapping of [(Sexp ...)]. *) +let () = + Sexplib.Conv.Exn_converter.add [%extension_constructor Sexp] + (function + | Sexp t -> t + | _ -> + (* Reaching this branch indicates a bug in sexplib. 
*) + assert false) +;; + +(* [create_s sexp] wraps [sexp] in the [Sexp] exception. The converter registered above + maps [Sexp t] back to [t], so converting the exception to a sexp yields [sexp] itself + with no extra wrapping. *) +let create_s sexp = Sexp sexp + +(* [reraise exc str] raises [Reraised (str, exc)], i.e. [exc] tagged with the message + [str]. *) +let reraise exc str = + raise (Reraised (str, exc)) + +(* Like [reraise], but the message is built from a format string. The formatted result is + a closure, so nothing is raised until it is applied to the final [()]. *) +let reraisef exc format = + Printf.ksprintf (fun str () -> reraise exc str) format + +(* Both printers go through [sexp_of_exn]: [to_string] renders the human-readable form + with an indent of 2, [to_string_mach] renders the machine form. *) +let to_string exc = Sexp.to_string_hum ~indent:2 (sexp_of_exn exc) +let to_string_mach exc = Sexp.to_string_mach (sexp_of_exn exc) + +(* Re-binds [sexp_of_t] to [sexp_of_exn], the same value it is bound to at the top of + [exn.ml]. *) +let sexp_of_t = sexp_of_exn + +(* [protectx ~f x ~finally] computes [f x] and runs [finally x] exactly once afterwards. + - [f] returns: [finally x] runs, then the result of [f x] is returned; an exception + raised by [finally] propagates unchanged. + - [f] raises [exn]: [finally x] runs, then [exn] is raised again. + - [f] raises [exn] and [finally] raises [final_exn]: [Finally (exn, final_exn)] is + raised, so neither exception is lost. *) +let protectx ~f x ~(finally : _ -> unit) = + let res = + try f x + with exn -> + (* [f] failed: still run [finally]; if that fails as well, report both exceptions *) + (try finally x with final_exn -> raise (Finally (exn, final_exn))); + raise exn + in + finally x; + res +;; + +(* [protect] is [protectx] specialised to a [unit] argument. *) +let protect ~f ~finally = protectx ~f () ~finally + +(* [does_raise f] is [true] iff [f ()] raises any exception. The value returned by [f] is + discarded; the locally abstract type [a] only serves the type annotation on + [ignore]. *) +let does_raise (type a) (f : unit -> a) = + try + ignore (f () : a); + false + with _ -> + true +;; + +(* Instantiates [Pretty_printer.Register_pp] for [exn] under the module name Base.Exn. + [pp] prints the sexp form when [sexp_of_exn_opt] provides one and falls back to + [Caml.Printexc.to_string] otherwise. *) +include Pretty_printer.Register_pp (struct + type t = exn + let pp ppf t = + match sexp_of_exn_opt t with + | Some sexp -> Sexp.pp_hum ppf sexp + | None -> Caml.Format.pp_print_string ppf (Caml.Printexc.to_string t) + ;; + let module_name = "Base.Exn" + end) + +(* Prints the uncaught-exception banner followed by [exc] on stderr, then the raw + backtrace if [Caml.Printexc.backtrace_status ()] is true, and finally flushes + stderr. *) +let print_with_backtrace exc raw_backtrace = + Caml.Format.eprintf "@[<2>Uncaught exception:@\n@\n@[%a@]@]@\n@." pp exc; + if Caml.Printexc.backtrace_status () + then Caml.Printexc.print_raw_backtrace Caml.stderr raw_backtrace; + Caml.flush Caml.stderr; +;; + +(* Installs [print_with_backtrace] through + [Caml.Printexc.set_uncaught_exception_handler]. *) +let set_uncaught_exception_handler () = + Caml.Printexc.set_uncaught_exception_handler print_with_backtrace +;; + +(* Runs [f ()]. If it raises, the raw backtrace is captured first, the at-exit handlers + are run when [do_at_exit] is set, the exception is printed on a best-effort basis and + finally the caller-supplied [exit] function is applied to [1]. *) +let handle_uncaught_aux ~do_at_exit ~exit f = + try f () + with exc -> + let raw_backtrace = Caml.Printexc.get_raw_backtrace () in + (* One reason to run [do_at_exit] handlers before printing out the error message is + that it helps curses applications bring the terminal in a good state, otherwise the + error message might get corrupted. Also, the OCaml top-level uncaught exception + handler does the same. 
*) + if do_at_exit then (try Caml.do_at_exit () with _ -> ()); + begin + try + print_with_backtrace exc raw_backtrace + with _ -> + try + Caml.Printf.eprintf "Exn.handle_uncaught could not print; exiting anyway\n%!"; + with _ -> () + end; + exit 1 +;; + +let handle_uncaught_and_exit f = handle_uncaught_aux f ~exit ~do_at_exit:true + +let handle_uncaught ~exit:must_exit f = + handle_uncaught_aux f ~exit:(if must_exit then exit else ignore) + ~do_at_exit:must_exit + +let reraise_uncaught str func = + try func () with + | exn -> raise (Reraised (str, exn)) + +external clear_backtrace : unit -> unit = "Base_clear_caml_backtrace_pos" [@@noalloc] + +let raise_without_backtrace e = + (* We clear the backtrace to reduce confusion, so that people don't think whatever + is stored corresponds to this raise. *) + clear_backtrace (); + Caml.raise_notrace e +;; + +let initialize_module () = + set_uncaught_exception_handler (); +;; + +module Private = struct + let clear_backtrace = clear_backtrace +end diff --git a/src/exn.mli b/src/exn.mli new file mode 100644 index 0000000..bba935c --- /dev/null +++ b/src/exn.mli @@ -0,0 +1,95 @@ +(** Exceptions. + + [sexp_of_t] uses a global table of sexp converters. To register a converter for a new + exception, add [[@@deriving_inline sexp][@@@end]] to its definition. If no suitable converter is + found, the standard converter in [Printexc] will be used to generate an atomic + S-expression. *) + +open! Import + +type t = exn [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Pretty_printer.S with type t := t + +(** Raised when finalization after an exception failed, too. + The first exception argument is the one raised by the initial + function, the second exception the one raised by the finalizer. 
*) +exception Finally of t * t + +exception Reraised of string * t + +(** [create_s sexp] returns an exception [t] such that [phys_equal (sexp_of_t t) sexp]. + This is useful when one wants to create an exception that serves as a message and the + particular exn constructor doesn't matter. *) +val create_s : Sexp.t -> t + +(** Same as [raise], except that the backtrace is not recorded. *) +val raise_without_backtrace : t -> _ + +val reraise : t -> string -> _ + +(** Types with [format4] are hard to read, so here's an example. + + {[ + let foobar str = + try + ... + with exn -> + Exn.reraisef exn "Foobar is buggy on: %s" str () + ]} *) +val reraisef : t -> ('a, unit, string, unit -> _) format4 -> 'a + +val to_string : t -> string (** Human-readable, multi-line. *) + +val to_string_mach : t -> string (** Machine format, single-line. *) + +(** Executes [f] and afterwards executes [finally], whether [f] throws an exception or + not. *) +val protectx : f:('a -> 'b) -> 'a -> finally:('a -> unit) -> 'b + +val protect : f:(unit -> 'a) -> finally:(unit -> unit) -> 'a + +(** [handle_uncaught ~exit f] catches an exception escaping [f] and prints an error + message to stderr. Exits with return code 1 if [exit] is [true], and returns unit + otherwise. + + Note that since OCaml 4.02.0, you don't need to use this at the entry point of your + program, as the OCaml runtime will do better than this function. *) +val handle_uncaught : exit:bool -> (unit -> unit) -> unit + +(** [handle_uncaught_and_exit f] returns [f ()], unless that raises, in which case it + prints the exception and exits nonzero. *) +val handle_uncaught_and_exit : (unit -> 'a) -> 'a + +(** Traces exceptions passing through. Useful because in practice, backtraces still don't + seem to work. 
+ + Example: + {[ + let rogue_function () = if Random.bool () then failwith "foo" else 3 + let traced_function () = Exn.reraise_uncaught "rogue_function" rogue_function + traced_function ();; + ]} + {v : Program died with Reraised("rogue_function", Failure "foo") v} *) +val reraise_uncaught : string -> (unit -> 'a) -> 'a + +(** [does_raise f] returns [true] iff [f ()] raises, which is often useful in unit + tests. *) +val does_raise : (unit -> _) -> bool + +(** User code never calls this. It is called in [std_kernel.ml] as a top-level side + effect to change the display of exceptions and install an uncaught-exception + printer. *) +val initialize_module : unit -> unit + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + val clear_backtrace : unit -> unit +end diff --git a/src/exn_stubs.c b/src/exn_stubs.c new file mode 100644 index 0000000..e6f905a --- /dev/null +++ b/src/exn_stubs.c @@ -0,0 +1,8 @@ +/* Provides value, CAMLprim and Val_unit used below. */ +#include <caml/mlvalues.h> + +/* Defined outside this file (presumably by the OCaml runtime). */ +extern int caml_backtrace_pos; + +/* Resets caml_backtrace_pos to 0 and returns unit. Bound in src/exn.ml as the + external [clear_backtrace]; it performs no OCaml allocation. */ +CAMLprim value Base_clear_caml_backtrace_pos () { + caml_backtrace_pos = 0; + return Val_unit; +} diff --git a/src/field.ml b/src/field.ml new file mode 100644 index 0000000..a4b2406 --- /dev/null +++ b/src/field.ml @@ -0,0 +1,67 @@ +(* The type [t] should be abstract to make the fset and set functions unavailable + for private types at the level of types (and not by putting None in the field). + Unfortunately, making the type abstract means that when creating fields (through + a [create] function) value restriction kicks in. This is worked around by instead + not making the type abstract, but forcing anyone breaking the abstraction to use + the [For_generated_code] module, making it obvious to any reader that something ugly + is going on. + t_with_perm (and derivatives) is the type that users really use. It is a constructor + because: + 1. 
it makes type errors more readable (less aliasing) + 2. the typer in ocaml 4.01 allows this: + + {[ + module A = struct + type t = {a : int} + end + type t = A.t + let f (x : t) = x.a + ]} + + (although with Warning 40: a is used out of scope) + which means that if [t_with_perm] was really an alias on [For_generated_code.t], + people could say [t.setter] and break the abstraction with no indication that + something ugly is going on in the source code. + The warning is (I think) for people who want to make their code compatible with + previous versions of ocaml, so we may very well turn it off. + + The type t_with_perm could also have been a [unit -> For_generated_code.t] to work + around value restriction and then [For_generated_code.t] would have been a proper + abstract type, but it looks like it could impact performance (for example, a fold on a + record type with 40 fields would actually allocate the 40 [For_generated_code.t]'s at + every single fold.) *) + +module For_generated_code = struct + type ('perm, 'record, 'field) t = { + force_variance : 'perm -> unit; + (* force [t] to be contravariant in ['perm], because phantom type variables on + concrete types don't work that well otherwise (using :> can remove them easily) *) + name : string; + setter : ('record -> 'field -> unit) option; + getter : ('record -> 'field); + fset : ('record -> 'field -> 'record); + } +end + +type ('perm, 'record, 'field) t_with_perm = + | Field of ('perm, 'record, 'field) For_generated_code.t +type ('record, 'field) t = ([ `Read | `Set_and_create], 'record, 'field) t_with_perm +type ('record, 'field) readonly_t = ([ `Read ], 'record, 'field) t_with_perm + +let name (Field field) = field.name + +let get (Field field) r = field.getter r + +let fset (Field field) r v = field.fset r v + +let setter (Field field) = field.setter + +type ('perm, 'record, 'result) user = + { f : 'field. 
('perm, 'record, 'field) t_with_perm -> 'result } + +(* [map field r ~f] reads the field from [r], applies [f] to it and hands the result to + the field's [fset] function; the record returned by [fset] is the result. *) +let map (Field field) r ~f = field.fset r (f (field.getter r)) + +(* [updater field] is [None] when the field has no [setter]. Otherwise it is + [Some update], where [update r ~f] reads the field from [r], applies [f] and writes + the result back through the setter. *) +let updater (Field field) = + match field.setter with + | None -> None + | Some setter -> Some (fun r ~f -> setter r (f (field.getter r))) diff --git a/src/field.mli b/src/field.mli new file mode 100644 index 0000000..f4f211c --- /dev/null +++ b/src/field.mli @@ -0,0 +1,39 @@ +(** OCaml record field. *) + +(**/**) +module For_generated_code : sig + (*_ don't use this by hand, it is only meant for ppx_fields_conv *) + type ('perm, 'record, 'field) t = { + force_variance : 'perm -> unit; + name : string; + setter : ('record -> 'field -> unit) option; + getter : ('record -> 'field); + fset : ('record -> 'field -> 'record); + } +end +(**/**) + +(** ['record] is the type of the record. ['field] is the type of the + values stored in the record field with name [name]. ['perm] is a way + of restricting the operations that can be used. *) +type ('perm, 'record, 'field) t_with_perm = + | Field of ('perm, 'record, 'field) For_generated_code.t + +(** A record field with no restrictions. *) +type ('record, 'field) t = ([ `Read | `Set_and_create], 'record, 'field) t_with_perm + +(** A record field that can only be read, because it belongs to a private type. *) +type ('record, 'field) readonly_t = ([ `Read ], 'record, 'field) t_with_perm + +(** [name field] is the name stored in the field descriptor. *) +val name : (_, _, _) t_with_perm -> string + +(** [get field r] applies the field's getter to [r]. *) +val get : (_, 'r, 'a) t_with_perm -> 'r -> 'a + +(** [fset field r v] applies the field's [fset] function to [r] and [v] and returns the + resulting record (presumably [r] with the field replaced by [v]; the function itself + is supplied by generated code). Requires the [`Set_and_create] permission. *) +val fset : ([> `Set_and_create], 'r, 'a) t_with_perm -> 'r -> 'a -> 'r + +(** [setter field] is the optional in-place setter exactly as stored in the field + descriptor. Requires the [`Set_and_create] permission. *) +val setter : ([> `Set_and_create], 'r, 'a) t_with_perm -> ('r -> 'a -> unit) option + +(** [map field r ~f] is [fset field r (f (get field r))]. *) +val map : ([> `Set_and_create], 'r, 'a) t_with_perm -> 'r -> f:('a -> 'a) -> 'r + +(** [updater field] is [None] when the field has no setter; otherwise [Some update], where + [update r ~f] stores [f (get field r)] into [r] through the setter. *) +val updater + : ([> `Set_and_create], 'r, 'a) t_with_perm + -> ('r -> f:('a -> 'a) -> unit) option + +(** A function usable on any field of ['record]: [f] is polymorphic in the field type + ['field] and always produces a ['result]. *) +type ('perm, 'record, 'result) user = + { f : 'field. 
('perm, 'record, 'field) t_with_perm -> 'result } diff --git a/src/fieldslib.ml b/src/fieldslib.ml new file mode 100644 index 0000000..dd41b03 --- /dev/null +++ b/src/fieldslib.ml @@ -0,0 +1,3 @@ +(** This module is for use by ppx_fields_conv, and is thus not in the interface of + Base. *) +module Field = Field diff --git a/src/float.ml b/src/float.ml new file mode 100644 index 0000000..41e277b --- /dev/null +++ b/src/float.ml @@ -0,0 +1,1099 @@ +open! Import +open! Printf + +module Bytes = Bytes0 +include Float0 + +let ceil = Caml.ceil +let floor = Caml.floor +let mod_float = Caml.mod_float +let modf = Caml.modf + +let raise_s = Error.raise_s + +module T = struct + type t = float [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_float + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_float in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = float_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_float + [@@@end] + let compare = Float_replace_polymorphic_compare.compare +end + +include T +include Comparator.Make(T) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open Float_replace_polymorphic_compare + +let to_float x = x +let of_float x = x + +let of_string s = + try Caml.float_of_string s with + | _ -> invalid_argf "Float.of_string %s" s () +;; + +external format_float : string -> float -> string = "caml_format_float" + +(* Stolen from [pervasives.ml]. Adds a "." at the end if needed. It is in + [pervasives.mli], but it also says not to use it directly, so we copy and paste the + code. It makes the assumption on the string passed in argument that it was returned by + [format_float]. 
*) +let valid_float_lexem s = + let l = String.length s in + let rec loop i = + if Int_replace_polymorphic_compare.(>=) i l then s ^ "." else + match s.[i] with + | '0' .. '9' | '-' -> loop (i + 1) + | _ -> s + in + loop 0 +;; + +(* Let [y] be a power of 2. Then the next representable float is: + [z = y * (1 + 2 ** -52)] + and the previous one is + [x = y * (1 - 2 ** -53)] + + In general, every two adjacent floats are within a factor of between [1 + 2**-53] + and [1 + 2**-52] from each other, that is within [1 + 1.1e-16] and [1 + 2.3e-16]. + + So if the decimal representation of a float starts with "1", then its adjacent floats + will usually differ from it by 1, and sometimes by 2, at the 17th significant digit + (counting from 1). + + On the other hand, if the decimal representation starts with "9", then the adjacent + floats will be off by no more than 23 at the 16th and 17th significant digits. + + E.g.: + + {v + # sprintf "%.17g" (1024. *. (1. -. 2.** (-53.)));; + 11111111 + 1234 5678901234567 + - : string = "1023.9999999999999" + v} + Printing a couple of extra digits reveals that the difference indeed is roughly 11 at + digits 17th and 18th (that is, 13th and 14th after "."): + + {v + # sprintf "%.19g" (1024. *. (1. -. 2.** (-53.)));; + 1111111111 + 1234 567890123456789 + - : string = "1023.999999999999886" + v} + + The ulp (the difference between adjacent floats) is twice as big on the other side of + 1024.: + + {v + # sprintf "%.19g" (1024. *. (1. +. 2.** (-52.)));; + 1111111111 + 1234 567890123456789 + - : string = "1024.000000000000227" + v} + + Now take a power of 2 which starts with 99: + + {v + # 2.**93. ;; + 1111111111 + 1 23456789012345678 + - : float = 9.9035203142830422e+27 + + # 2.**93. *. (1. +. 2.** (-52.));; + - : float = 9.9035203142830444e+27 + + # 2.**93. *. (1. -. 
2.** (-53.));; + - : float = 9.9035203142830411e+27 + v} + + The difference between 2**93 and its two neighbors is slightly more than, respectively, + 1 and 2 at significant digit 16. + + Those examples show that: + - 17 significant digits is always sufficient to represent a float without ambiguity + - 15th significant digit can always be represented accurately + - converting a decimal number with 16 significant digits to its nearest float and back + can change the last decimal digit by no more than 1 + + To make sure that floats obtained by conversion from decimal fractions (e.g. "3.14") + are printed without trailing non-zero digits, one should choose the first among the + '%.15g', '%.16g', and '%.17g' representations which does round-trip: + + {v + # sprintf "%.15g" 3.14;; + - : string = "3.14" (* pick this one *) + # sprintf "%.16g" 3.14;; + - : string = "3.14" + # sprintf "%.17g" 3.14;; + - : string = "3.1400000000000001" (* do not pick this one *) + + # sprintf "%.15g" 8.000000000000002;; + - : string = "8" (* do not pick this one--does not round-trip *) + # sprintf "%.16g" 8.000000000000002;; + - : string = "8.000000000000002" (* prefer this one *) + # sprintf "%.17g" 8.000000000000002;; + - : string = "8.0000000000000018" (* this one has one digit of junk at the end *) + v} + + Skipping the '%.16g' in the above procedure saves us some time, but it means that, as + seen in the second example above, occasionally numbers with exactly 16 significant + digits will have an error introduced at the 17th digit. That is probably OK for + typical use, because a number with 16 significant digits is "ugly" already. Adding one + more doesn't make it much worse for a human reader. 
+ + On the other hand, we cannot skip '%.15g' and only look at '%.16g' and '%.17g', since + the inaccuracy at the 16th digit might introduce the noise we want to avoid: + + {v + # sprintf "%.15g" 9.992;; + - : string = "9.992" (* pick this one *) + # sprintf "%.16g" 9.992;; + - : string = "9.992000000000001" (* do not pick this one--junk at the end *) + # sprintf "%.17g" 9.992;; + - : string = "9.9920000000000009" + v} +*) +let to_string x = + valid_float_lexem ( + let y = format_float "%.15g" x in + if float_of_string y = x then + y + else + format_float "%.17g" x) +;; + +let nan = Caml.nan + +let infinity = Caml.infinity +let neg_infinity = Caml.neg_infinity + +let max_value = infinity +let min_value = neg_infinity + +let max_finite_value = Caml.max_float + +let min_positive_subnormal_value = 2. ** -1074. +let min_positive_normal_value = 2. ** -1022. + +let zero = 0. +let one = 1. +let minus_one = -1. + +let pi = 0x3.243F6A8885A308D313198A2E037073 +let sqrt_pi = 0x1.C5BF891B4EF6AA79C3B0520D5DB938 +let sqrt_2pi = 0x2.81B263FEC4E0B2CAF9483F5CE459DC +let euler = 0x0.93C467E37DB0C7A4D1BE3F810152CB + +(* The bits of INRIA's [Pervasives] that we just want to expose in + [Float]. Most are already deprecated in [Pervasives], and + eventually all of them should be. 
*) +include (Caml : sig + external frexp : float -> float * int = "caml_frexp_float" + external ldexp : (float [@unboxed]) -> (int [@untagged]) -> (float [@unboxed]) = "caml_ldexp_float" "caml_ldexp_float_unboxed" [@@noalloc] + external log10 : float -> float = "caml_log10_float" "log10" + [@@unboxed] [@@noalloc] + external expm1 : float -> float = "caml_expm1_float" "caml_expm1" + [@@unboxed] [@@noalloc] + external log1p : float -> float = "caml_log1p_float" "caml_log1p" + [@@unboxed] [@@noalloc] + external copysign : float -> float -> float = "caml_copysign_float" "caml_copysign" + [@@unboxed] [@@noalloc] + external cos : float -> float = "caml_cos_float" "cos" + [@@unboxed] [@@noalloc] + external sin : float -> float = "caml_sin_float" "sin" + [@@unboxed] [@@noalloc] + external tan : float -> float = "caml_tan_float" "tan" + [@@unboxed] [@@noalloc] + external acos : float -> float = "caml_acos_float" "acos" + [@@unboxed] [@@noalloc] + external asin : float -> float = "caml_asin_float" "asin" + [@@unboxed] [@@noalloc] + external atan : float -> float = "caml_atan_float" "atan" + [@@unboxed] [@@noalloc] + external atan2 : float -> float -> float = "caml_atan2_float" "atan2" + [@@unboxed] [@@noalloc] + external hypot : float -> float -> float = "caml_hypot_float" "caml_hypot" + [@@unboxed] [@@noalloc] + external cosh : float -> float = "caml_cosh_float" "cosh" + [@@unboxed] [@@noalloc] + external sinh : float -> float = "caml_sinh_float" "sinh" + [@@unboxed] [@@noalloc] + external tanh : float -> float = "caml_tanh_float" "tanh" + [@@unboxed] [@@noalloc] + external sqrt : float -> float = "caml_sqrt_float" "sqrt" + [@@unboxed] [@@noalloc] + external exp : float -> float = "caml_exp_float" "exp" + [@@unboxed] [@@noalloc] + external log : float -> float = "caml_log_float" "log" + [@@unboxed] [@@noalloc] + end) + +(* We need this indirection because these are exposed as "val" instead of "external" *) +let frexp = frexp +let ldexp = ldexp + +let epsilon_float = 
Caml.epsilon_float + +let of_int = Int.to_float +let to_int = Int.of_float + +let of_int63 i = Int63.to_float i + +let of_int64 i = Caml.Int64.to_float i + +let to_int64 = Caml.Int64.of_float + +let iround_lbound = lower_bound_for_int Int.num_bits +let iround_ubound = upper_bound_for_int Int.num_bits + +(* The performance of the "exn" rounding functions is important, so they are written + out separately, and tuned individually. (We could have the option versions call + the "exn" versions, but that imposes arguably gratuitous overhead---especially + in the case where the capture of backtraces is enabled upon "with"---and that seems + not worth it when compared to the relatively small amount of code duplication.) *) + +(* Error reporting below is very carefully arranged so that, e.g., [iround_nearest_exn] + itself can be inlined into callers such that they don't need to allocate a box for the + [float] argument. This is done with a box [box] function carefully chosen to allow the + compiler to create a separate box for the float only in error cases. See, e.g., + [../../zero/test/price_test.ml] for a mechanical test of this property when building + with [X_LIBRARY_INLINING=true]. 
*) + +let iround_up t = + if t > 0.0 then begin + let t' = ceil t in + if t' <= iround_ubound then + Some (Int.of_float_unchecked t') + else + None + end + else begin + if t >= iround_lbound then + Some (Int.of_float_unchecked t) + else + None + end + +let iround_up_exn t = + if t > 0.0 then begin + let t' = ceil t in + if t' <= iround_ubound then + Int.of_float_unchecked t' + else + invalid_argf "Float.iround_up_exn: argument (%f) is too large" (box t) () + end + else begin + if t >= iround_lbound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_up_exn: argument (%f) is too small or NaN" (box t) () + end +[@@ocaml.inline always] + +let iround_down t = + if t >= 0.0 then begin + if t <= iround_ubound then + Some (Int.of_float_unchecked t) + else + None + end + else begin + let t' = floor t in + if t' >= iround_lbound then + Some (Int.of_float_unchecked t') + else + None + end + +let iround_down_exn t = + if t >= 0.0 then begin + if t <= iround_ubound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_down_exn: argument (%f) is too large" (box t) () + end + else begin + let t' = floor t in + if t' >= iround_lbound then + Int.of_float_unchecked t' + else + invalid_argf "Float.iround_down_exn: argument (%f) is too small or NaN" (box t) () + end +[@@ocaml.inline always] + +let iround_towards_zero t = + if t >= iround_lbound && t <= iround_ubound then + Some (Int.of_float_unchecked t) + else + None + +let iround_towards_zero_exn t = + if t >= iround_lbound && t <= iround_ubound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_towards_zero_exn: argument (%f) is out of range or NaN" + (box t) + () +[@@ocaml.inline always] + +(* Outside of the range (round_nearest_lb..round_nearest_ub), all representable doubles + are integers in the mathematical sense, and [round_nearest] should be identity. 
+ + However, for odd numbers with the absolute value between 2**52 and 2**53, the formula + [round_nearest x = floor (x + 0.5)] does not hold: + + {v + # let naive_round_nearest x = floor (x +. 0.5);; + # let x = 2. ** 52. +. 1.;; + val x : float = 4503599627370497. + # naive_round_nearest x;; + - : float = 4503599627370498. + v} +*) + +let round_nearest_lb = -.(2. ** 52.) +let round_nearest_ub = 2. ** 52. + +(* For [x = one_ulp `Down 0.5], the formula [floor (x +. 0.5)] for rounding to nearest + does not work, because the exact result is halfway between [one_ulp `Down 1.] and [1.], + and it gets rounded up to [1.] due to the round-ties-to-even rule. *) +let one_ulp_less_than_half = one_ulp `Down 0.5 +let add_half_for_round_nearest t = + t +. (if t = one_ulp_less_than_half then + one_ulp_less_than_half (* since t < 0.5, make sure the result is < 1.0 *) + else + 0.5) + +let iround_nearest_32 t = + if t >= 0. then + let t' = add_half_for_round_nearest t in + if t' <= iround_ubound then + Some (Int.of_float_unchecked t') + else + None + else + let t' = floor (t +. 0.5) in + if t' >= iround_lbound then + Some (Int.of_float_unchecked t') + else + None + +let iround_nearest_64 t = + if t >= 0. then + if t < round_nearest_ub then + Some (Int.of_float_unchecked (add_half_for_round_nearest t)) + else + if t <= iround_ubound then + Some (Int.of_float_unchecked t) + else + None + else + if t > round_nearest_lb then + Some (Int.of_float_unchecked (floor (t +. 0.5))) + else + if t >= iround_lbound then + Some (Int.of_float_unchecked t) + else + None + +let iround_nearest = + match Word_size.word_size with + | W64 -> iround_nearest_64 + | W32 -> iround_nearest_32 + +let iround_nearest_exn_32 t = + if t >= 0. then + let t' = add_half_for_round_nearest t in + if t' <= iround_ubound then + Int.of_float_unchecked t' + else + invalid_argf "Float.iround_nearest_exn: argument (%f) is too large" (box t) () + else + let t' = floor (t +. 
0.5) in + if t' >= iround_lbound then + Int.of_float_unchecked t' + else + (* NaN also ends up here, because every comparison above is false for NaN; the + message says so, matching [iround_nearest_exn_64] *) + invalid_argf "Float.iround_nearest_exn: argument (%f) is too small or NaN" (box t) () + +(* Round-to-nearest for 64-bit [int]. Strictly inside + [(round_nearest_lb, round_nearest_ub)] the argument is rounded with the add-half + formulas. Outside that range every representable float is already an integer, so it + is converted directly provided it lies within [[iround_lbound, iround_ubound]]. + NaN fails every comparison and therefore reaches the last error branch. *) +let iround_nearest_exn_64 t = + if t >= 0. then + if t < round_nearest_ub then + Int.of_float_unchecked (add_half_for_round_nearest t) + else + if t <= iround_ubound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_nearest_exn: argument (%f) is too large" (box t) () + else + if t > round_nearest_lb then + Int.of_float_unchecked (floor (t +. 0.5)) + else + if t >= iround_lbound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_nearest_exn: argument (%f) is too small or NaN" (box t) () +[@@ocaml.inline always] + +(* Selects the implementation matching the word size once, at module initialisation. *) +let iround_nearest_exn = + match Word_size.word_size with + | W64 -> iround_nearest_exn_64 + | W32 -> iround_nearest_exn_32 + +(* [iround_exn] and [iround] dispatch on [dir] to the specialised functions above, so + they cost one extra match compared to calling those functions directly. + Their equivalence to those functions is tested in the unit tests below. *) + +(* [iround_exn ?dir t] rounds [t] to an [int] in direction [dir] (default [`Nearest]) and + raises via [invalid_argf] when [t] is out of range or NaN. *) +let iround_exn ?(dir=`Nearest) t = + match dir with + | `Zero -> iround_towards_zero_exn t + | `Nearest -> iround_nearest_exn t + | `Up -> iround_up_exn t + | `Down -> iround_down_exn t +[@@inline] + +(* [iround ?dir t] is the non-raising counterpart of [iround_exn]: [None] when [t] is out + of range or NaN. It calls the option-returning variants directly, which test exactly + the same bounds as their [_exn] twins; this avoids using an exception for control + flow and a catch-all handler that would also hide unrelated exceptions. *) +let iround ?(dir=`Nearest) t = + match dir with + | `Zero -> iround_towards_zero t + | `Nearest -> iround_nearest t + | `Up -> iround_up t + | `Down -> iround_down t + +(* [is_inf x] is [true] iff [x] classifies as [FP_infinite], i.e. for both infinities. *) +let is_inf x = + match Caml.classify_float x with + | FP_infinite -> true + | _ -> false + +(* [min_inan x y] is the smaller argument, ignoring NaN: when exactly one argument is NaN + the other one is returned, and NaN is returned only when both are NaN. *) +let min_inan (x : t) y = + if is_nan y then x + else if is_nan x then y + else if x < y then x else y + +(* [max_inan x y] is the larger argument, with the same NaN handling as [min_inan]. *) +let max_inan (x : t) y = + if is_nan y then x + else if is_nan x then y + else if x > y then x else y + +(* Named aliases for the float arithmetic operators. *) +let add = (+.) +let sub = (-.) +let neg = (~-.) +let abs = Caml.abs_float +let scale = ( *. ) + +let square x = x *. 
x + +module Parts : sig + type t + + val fractional : t -> float + val integral : t -> float + val modf : float -> t +end = struct + type t = float * float + + let fractional t = fst t + let integral t = snd t + let modf = modf +end +let modf = Parts.modf + +let round_down = floor + +let round_up = ceil + +let round_towards_zero t = + if t >= 0. + then round_down t + else round_up t + +(* see the comment above [round_nearest_lb] and [round_nearest_ub] for an explanation *) +let round_nearest t = + if t > round_nearest_lb && t < round_nearest_ub then + floor (add_half_for_round_nearest t) + else + t +. 0. + +let round_nearest_half_to_even t = + if t <= round_nearest_lb || t >= round_nearest_ub then + t +. 0. + else + let floor = floor t in + (* [ceil_or_succ = if t is an integer then t +. 1. else ceil t]. Faster than [ceil]. *) + let ceil_or_succ = floor +. 1. in + let diff_floor = t -. floor in + let diff_ceil = ceil_or_succ -. t in + if diff_floor < diff_ceil then + floor + else + if diff_floor > diff_ceil then + ceil_or_succ + else + (* exact tie, pick the even *) + if mod_float floor 2. = 0. 
then + floor + else + ceil_or_succ + +let int63_round_lbound = lower_bound_for_int Int63.num_bits +let int63_round_ubound = upper_bound_for_int Int63.num_bits + +let int63_round_up_exn t = + if t > 0.0 then begin + let t' = ceil t in + if t' <= int63_round_ubound then + Int63.of_float_unchecked t' + else + invalid_argf "Float.int63_round_up_exn: argument (%f) is too large" (Float0.box t) () + end + else begin + if t >= int63_round_lbound then + Int63.of_float_unchecked t + else + invalid_argf "Float.int63_round_up_exn: argument (%f) is too small or NaN" + (Float0.box t) () + end + +let int63_round_down_exn t = + if t >= 0.0 then begin + if t <= int63_round_ubound then + Int63.of_float_unchecked t + else + invalid_argf "Float.int63_round_down_exn: argument (%f) is too large" + (Float0.box t) () + end + else begin + let t' = floor t in + if t' >= int63_round_lbound then + Int63.of_float_unchecked t' + else + invalid_argf "Float.int63_round_down_exn: argument (%f) is too small or NaN" + (Float0.box t) () + end + +let int63_round_nearest_portable_alloc_exn t0 = + let t = round_nearest t0 in + if t > 0. 
+ then begin + if t <= int63_round_ubound + then Int63.of_float_unchecked t + else invalid_argf + "Float.int63_round_nearest_portable_alloc_exn: argument (%f) is too large" + (box t0) + () + end + else begin + if t >= int63_round_lbound + then Int63.of_float_unchecked t + else invalid_argf + "Float.int63_round_nearest_portable_alloc_exn: argument (%f) is too small or NaN" + (box t0) + () + end + +let int63_round_nearest_arch64_noalloc_exn f = Int63.of_int (iround_nearest_exn f) + +let int63_round_nearest_exn = + match Word_size.word_size with + | W64 -> int63_round_nearest_arch64_noalloc_exn + | W32 -> int63_round_nearest_portable_alloc_exn + +let round ?(dir=`Nearest) t = + match dir with + | `Nearest -> round_nearest t + | `Down -> round_down t + | `Up -> round_up t + | `Zero -> round_towards_zero t + +module Class = struct + type t = + | Infinite + | Nan + | Normal + | Subnormal + | Zero + [@@deriving_inline compare, enumerate, sexp] + let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare + let all : t list = [Infinite; Nan; Normal; Subnormal; Zero] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/float.ml.Class.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("infinite"|"Infinite") -> Infinite + | Ppx_sexp_conv_lib.Sexp.Atom ("nan"|"Nan") -> Nan + | Ppx_sexp_conv_lib.Sexp.Atom ("normal"|"Normal") -> Normal + | Ppx_sexp_conv_lib.Sexp.Atom ("subnormal"|"Subnormal") -> Subnormal + | Ppx_sexp_conv_lib.Sexp.Atom ("zero"|"Zero") -> Zero + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("infinite"|"Infinite"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("nan"|"Nan"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("normal"|"Normal"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List 
((Ppx_sexp_conv_lib.Sexp.Atom + ("subnormal"|"Subnormal"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("zero"|"Zero"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Infinite -> Ppx_sexp_conv_lib.Sexp.Atom "Infinite" + | Nan -> Ppx_sexp_conv_lib.Sexp.Atom "Nan" + | Normal -> Ppx_sexp_conv_lib.Sexp.Atom "Normal" + | Subnormal -> Ppx_sexp_conv_lib.Sexp.Atom "Subnormal" + | Zero -> Ppx_sexp_conv_lib.Sexp.Atom "Zero" + [@@@end] + + let to_string t = string_of_sexp (sexp_of_t t) + let of_string s = t_of_sexp (sexp_of_string s) +end + +let classify t = + let module C = Class in + match Caml.classify_float t with + | FP_normal -> C.Normal + | FP_subnormal -> C.Subnormal + | FP_zero -> C.Zero + | FP_infinite -> C.Infinite + | FP_nan -> C.Nan +;; + +let is_finite t = + not (t = infinity || t = neg_infinity || is_nan t) +;; + +let insert_underscores ?(delimiter='_') ?(strip_zero=false) string = + match String.lsplit2 string ~on:'.' with + | None -> + Int_conversions.insert_delimiter string ~delimiter + | Some (left, right) -> + let left = Int_conversions.insert_delimiter left ~delimiter in + let right = + if strip_zero + then String.rstrip right ~drop:(fun c -> Char.(=) c '0') + else right + in + match right with + | "" -> left + | _ -> left ^ "." 
^ right +;; + +let to_string_hum ?delimiter ?(decimals=3) ?strip_zero f = + if Int_replace_polymorphic_compare.(<) decimals 0 then + invalid_argf "to_string_hum: invalid argument ~decimals=%d" decimals (); + match classify f with + | Class.Infinite -> if f > 0. then "inf" else "-inf" + | Class.Nan -> "nan" + | Class.Normal + | Class.Subnormal + | Class.Zero -> insert_underscores (sprintf "%.*f" decimals f) ?delimiter ?strip_zero +;; + +let sexp_of_t t = + let sexp = sexp_of_t t in + match !Sexp.of_float_style with + | `No_underscores -> sexp + | `Underscores -> + match sexp with + | List _ -> raise_s (Sexp.message "[sexp_of_float] produced strange sexp" + ["sexp", Sexp.sexp_of_t sexp]) + | Atom string -> + if String.contains string 'E' + then sexp + else Atom (insert_underscores string) +;; + +let to_padded_compact_string t = + + (* Round a ratio toward the nearest integer, resolving ties toward the nearest even + number. For sane inputs (in particular, when [denominator] is an integer and + [abs numerator < 2e52]) this should be accurate. Otherwise, the result might be a + little bit off, but we don't really use that case. *) + let iround_ratio_exn ~numerator ~denominator = + let k = floor (numerator /. denominator) in + (* if [abs k < 2e53], then both [k] and [k +. 1.] are accurately represented, and in + particular [k +. 1. > k]. If [denominator] is also an integer, and + [abs (denominator *. (k +. 1)) < 2e53] (and in some other cases, too), then [lower] + and [higher] are actually both accurate. Since (roughly) + [numerator = denominator *. k] then for [abs numerator < 2e52] we should be + fine. *) + let lower = denominator *. k in + let higher = denominator *. (k +. 1.) in + (* Subtracting numbers within a factor of two from each other is accurate. + So either the two subtractions below are accurate, or k = 0, or k = -1. + In case of a tie, round to even. *) + let diff_right = higher -. numerator in + let diff_left = numerator -. 
lower in + let k = iround_nearest_exn k in + if diff_right < diff_left then + k + 1 + else if diff_right > diff_left then + k + else + (* a tie *) + if Int_replace_polymorphic_compare.(=) (k mod 2) 0 then k else k + 1 + in + + match classify t with + | Class.Infinite -> if t < 0.0 then "-inf " else "inf " + | Class.Nan -> "nan " + | Class.Subnormal | Class.Normal | Class.Zero -> + let go t = + let conv_one t = + assert (0. <= t && t < 999.95); + let x = format_float "%.1f" t in + (* Fix the ".0" suffix *) + if String.is_suffix x ~suffix:".0" + then begin + let x = Bytes.of_string x in + let n = Bytes.length x in + Bytes.set x (n - 1) ' '; + Bytes.set x (n - 2) ' '; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:x + end + else x + in + let conv mag t denominator = + assert (denominator = 100. && t >= 999.95 + || denominator >= 100_000. && t >= round_nearest (denominator *. 9.999_5)); + assert (t < round_nearest (denominator *. 9_999.5)); + let i, d = + let k = iround_ratio_exn ~numerator:t ~denominator in + (* [mod] is okay here because we know i >= 0. *) + k / 10, k mod 10 + in + let open Int_replace_polymorphic_compare in + assert (0 <= i && i < 1000); + assert (0 <= d && d < 10); + if d = 0 then + sprintf "%d%c " i mag + else + sprintf "%d%c%d" i mag d + in + (* While the standard metric prefixes (e.g. capital "M" rather than "m", [1]) are + nominally more correct, this hinders readability in our case. E.g., 10G6 and + 1066 look too similar. That's an extreme example, but in general k,m,g,t,p + probably stand out better than K,M,G,T,P when interspersed with digits. + + [1] http://en.wikipedia.org/wiki/Metric_prefix *) + (* The trick here is that: + - the first boundary (999.95) as a float is slightly over-represented (so it is + better approximated as "1k" than as "999.9"), + - the other boundaries are accurately represented, because they are integers. + That's why the strict inequalities below do exactly what we want.
*) + if t < 999.95E0 then conv_one t + else if t < 999.95E3 then conv 'k' t 100. + else if t < 999.95E6 then conv 'm' t 100_000. + else if t < 999.95E9 then conv 'g' t 100_000_000. + else if t < 999.95E12 then conv 't' t 100_000_000_000. + else if t < 999.95E15 then conv 'p' t 100_000_000_000_000. + else sprintf "%.1e" t + in + if t >= 0. + then go t + else "-" ^ (go ~-.t) + +(* Performance note: Initializing the accumulator to 1 results in one extra + multiply; e.g., to compute x ** 4, we in principle only need 2 multiplies, + but this function will have 3 multiplies. However, attempts to avoid this + (like decrementing n and initializing accum to be x, or handling small + exponents as a special case) have not yielded anything that is a net + improvement. +*) +let int_pow x n = + let open Int_replace_polymorphic_compare in + if n = 0 then + 1. + else begin + (* Using [x +. (-0.)] on the following line convinces the compiler to avoid a certain + boxing (that would result in allocation in each iteration). Soon, the compiler + shouldn't need this "hint" to avoid the boxing. The reason we add -0 rather than 0 + is that [x +. (-0.)] is apparently always the same as [x], whereas [x +. 0.] is + not, in that it sends [-0.] to [0.]. This makes a difference because we want + [int_pow (-0.) (-1)] to return neg_infinity just like [-0. ** -1.] would. *) + let x = ref (x +. (-0.)) in + let n = ref n in + let accum = ref 1. in + if !n < 0 then begin + (* x ** n = (1/x) ** -n *) + x := 1. /. !x; + n := ~- !n; + if !n < 0 then begin + (* n must have been min_int, so it is now so big that it has wrapped around. + We decrement it so that it looks positive again, but accordingly have + to put an extra factor of x in the accumulator. + *) + accum := !x; + decr n + end + end; + (* Letting [a] denote (the original value of) [x ** n], we maintain + the invariant that [(x ** n) *. accum = a]. *) + while !n > 1 do + if !n land 1 <> 0 then accum := !x *. !accum; + x := !x *. 
!x; + n := !n lsr 1 + done; + (* n is necessarily 1 at this point, so there is one additional + multiplication by x. *) + !x *. !accum + end + +let round_gen x ~how = + if x = 0. then 0. + else if not (is_finite x) then x + else begin + (* Significant digits and decimal digits. *) + let sd, dd = + match how with + | `significant_digits sd -> + let dd = sd - to_int (round_up (log10 (abs x))) in + sd, dd + | `decimal_digits dd -> + let sd = dd + to_int (round_up (log10 (abs x))) in + sd, dd + in + let open Int_replace_polymorphic_compare in + if sd < 0 + then 0. + else if sd >= 17 + then x + else + (* Choose the order that is exactly representable as a float. Small positive + integers are, but their inverses in most cases are not. *) + let abs_dd = Int.abs dd in + if abs_dd > 22 || sd >= 16 + (* 10**22 is exactly representable as a float, but 10**23 is not, so use the slow + path. Similarly, if we need 16 significant digits in the result, then the integer + [round_nearest (x order)] might not be exactly representable as a float, since + for some ranges we only have 15 digits of precision guaranteed. + + That said, we are still rounding twice here: + + 1) first time when rounding [x *. order] or [x /. order] to the nearest float + (just the normal way floating-point multiplication or division works), + + 2) second time when applying [round_nearest_half_to_even] to the result of the + above operation + + So for arguments within an ulp from a tie we might still produce an off-by-one + result. *) + then of_string (sprintf "%.*g" sd x) + else + let order = int_pow 10. abs_dd in + if dd >= 0 + then round_nearest_half_to_even (x *. order) /. order + else round_nearest_half_to_even (x /. order) *. 
order + end + +let round_significant x ~significant_digits = + if Int_replace_polymorphic_compare.(<=) significant_digits 0 then + raise (Invalid_argument + ("Float.round_significant: invalid argument significant_digits:" + ^ Int.to_string significant_digits)) + else + round_gen x ~how:(`significant_digits significant_digits) + +let round_decimal x ~decimal_digits = + round_gen x ~how:(`decimal_digits decimal_digits) + +let between t ~low ~high = low <= t && t <= high + +let clamp_exn t ~min ~max = + (* Also fails if [min] or [max] is nan *) + assert (min <= max); + (* clamp_unchecked is in float0.ml *) + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + (* Also fails if [min] or [max] is nan *) + if min <= max then + Ok (clamp_unchecked t ~min ~max) + else + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + +let ( + ) = ( +. ) +let ( - ) = ( -. ) +let ( * ) = ( *. ) +let ( ** ) = Caml.( ** ) +let ( / ) = ( /. ) +let ( ~- ) = ( ~-. ) + +let sign_exn t : Sign.t = + if t > 0. + then Pos + else if t < 0. + then Neg + else if t = 0. + then Zero + else Error.raise_s (Sexp.message "Float.sign_exn of NAN" + ["", sexp_of_t t]) + +let sign_or_nan t : Sign_or_nan.t = + if t > 0. + then Pos + else if t < 0. + then Neg + else if t = 0. 
+ then Zero + else Nan + +let ieee_negative t = + let bits = Caml.Int64.bits_of_float t in + Poly.(bits < Caml.Int64.zero) + +let exponent_bits = 11 +let mantissa_bits = 52 + +let exponent_mask64 = Int64.((shift_left one exponent_bits) - one) +let exponent_mask = Int64.to_int_exn exponent_mask64 +let mantissa_mask = Int63.((shift_left one mantissa_bits) - one) +let mantissa_mask64 = Int63.to_int64 mantissa_mask + +let ieee_exponent t = + let bits = Caml.Int64.bits_of_float t in + Int64.((bit_and (shift_right_logical bits mantissa_bits) exponent_mask64)) + |> Caml.Int64.to_int + +let ieee_mantissa t = + let bits = Caml.Int64.bits_of_float t in + Int63.of_int64_exn Caml.Int64.(logand bits mantissa_mask64) + +let create_ieee_exn ~negative ~exponent ~mantissa = + if Int.(bit_and exponent exponent_mask <> exponent) + then failwithf "exponent %d out of range [0, %d]" + exponent exponent_mask () + else if Int63.(bit_and mantissa mantissa_mask <> mantissa) + then failwithf "mantissa %s out of range [0, %s]" + (Int63.to_string mantissa) (Int63.to_string mantissa_mask) () + else + let sign_bits = if negative then Caml.Int64.min_int else Caml.Int64.zero in + let expt_bits = Caml.Int64.shift_left (Caml.Int64.of_int exponent) mantissa_bits in + let mant_bits = Int63.to_int64 mantissa in + let bits = Caml.Int64.(logor sign_bits (logor expt_bits mant_bits)) in + Caml.Int64.float_of_bits bits + +let create_ieee ~negative ~exponent ~mantissa = + Or_error.try_with (fun () -> create_ieee_exn ~negative ~exponent ~mantissa) + +module Terse = struct + type nonrec t = t + let t_of_sexp = t_of_sexp + + let to_string x = Printf.sprintf "%.8G" x + let sexp_of_t x = Sexp.Atom (to_string x) + let of_string x = of_string x +end + +let validate_ordinary t = + Validate.of_error_opt ( + let module C = Class in + match classify t with + | C.Normal | C.Subnormal | C.Zero -> None + | C.Infinite -> Some "value is infinite" + | C.Nan -> Some "value is NaN") +;; + +module V = struct + module ZZ = 
Comparable.Validate (T) + + let validate_bound ~min ~max t = + Validate.first_failure (validate_ordinary t) (ZZ.validate_bound t ~min ~max) + ;; + + let validate_lbound ~min t = + Validate.first_failure (validate_ordinary t) (ZZ.validate_lbound t ~min) + ;; + + let validate_ubound ~max t = + Validate.first_failure (validate_ordinary t) (ZZ.validate_ubound t ~max) + ;; +end + +include V + +include Comparable.With_zero (struct + include T + let zero = zero + include V + end) + +(* These are partly here as a performance hack to avoid some boxing we're getting with + the versions we get from [With_zero]. They also make [Float.is_negative nan] and + [Float.is_non_positive nan] return [false]; the versions we get from [With_zero] return + [true]. *) +let is_positive t = t > 0. +let is_non_negative t = t >= 0. +let is_negative t = t < 0. +let is_non_positive t = t <= 0. + +include Pretty_printer.Register(struct + include T + let module_name = "Base.Float" + let to_string = to_string + end) + +module O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = Caml.( ** ) + include (Float_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int = of_int + let of_float x = x +end + +module O_dot = struct + let ( *. ) = ( * ) + let ( +. ) = ( + ) + let ( -. ) = ( - ) + let ( /. ) = ( / ) + let ( ~-. ) = ( ~- ) + let ( **. 
) = Caml.( ** ) +end + +module Private = struct + let lower_bound_for_int = lower_bound_for_int + let upper_bound_for_int = upper_bound_for_int + let specialized_hash = hash_float + let one_ulp_less_than_half = one_ulp_less_than_half + let int63_round_nearest_portable_alloc_exn = int63_round_nearest_portable_alloc_exn + let int63_round_nearest_arch64_noalloc_exn = int63_round_nearest_arch64_noalloc_exn + let iround_nearest_exn_64 = iround_nearest_exn_64 +end + + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Float_replace_polymorphic_compare + +(* These functions specifically replace defaults in replace_polymorphic_compare *) +let min (x : t) y = + if is_nan x || is_nan y then nan + else if x < y then x else y + +let max (x : t) y = + if is_nan x || is_nan y then nan + else if x > y then x else y diff --git a/src/float.mli b/src/float.mli new file mode 100644 index 0000000..db3185e --- /dev/null +++ b/src/float.mli @@ -0,0 +1,587 @@ +(** Floating-point representation and utilities. + + If using 32-bit OCaml, you cannot quite assume operations act as you'd expect for IEEE + 64-bit floats. E.g., one can have [let x = ~-. (2. ** 62.) in x = x -. 1.] evaluate + to [false] while [let x = ~-. (2. ** 62.) in let y = x -. 1 in x = y] evaluates to + [true]. This is related to 80-bit registers being used for calculations; you can + force representation as a 64-bit value by let-binding. *) + +open! Import + +type t = float [@@deriving_inline hash] +include +sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +include Floatable.S with type t := t + +(** [max] and [min] will return nan if either argument is nan. 
+ + The [validate_*] functions always fail if class is [Nan] or [Infinite]. *) +include Identifiable.S with type t := t +include Comparable.With_zero with type t := t + +(** [validate_ordinary] fails if class is [Nan] or [Infinite]. *) +val validate_ordinary : t Validate.check + +val nan : t + +val infinity : t +val neg_infinity : t + +val max_value : t (** Equal to [infinity]. *) + +val min_value : t (** Equal to [neg_infinity]. *) + +val zero : t +val one : t +val minus_one : t + +val pi : t (** The constant pi. *) + +val sqrt_pi : t (** The constant sqrt(pi). *) + +val sqrt_2pi : t (** The constant sqrt(2 * pi). *) + +val euler : t (** Euler-Mascheroni constant (γ). *) + +(** The difference between 1.0 and the smallest exactly representable floating-point + number greater than 1.0. That is: + + [epsilon_float = (one_ulp `Up 1.0) -. 1.0] + + This gives the relative accuracy of type [t], in the sense that for numbers on the + order of [x], the roundoff error is on the order of [x *. epsilon_float]. + + See also: {{:http://en.wikipedia.org/wiki/Machine_epsilon} Machine epsilon}. +*) +val epsilon_float : t + +val max_finite_value : t + +(** + - [min_positive_subnormal_value = 2 ** -1074] + - [min_positive_normal_value = 2 ** -1022] *) + +val min_positive_subnormal_value : t +val min_positive_normal_value : t + +(** An order-preserving bijection between all floats except for nans, and all int64s with + absolute value smaller than or equal to [2**63 - 2**52]. Note both 0. and -0. map to + 0L. *) +val to_int64_preserve_order : t -> int64 option +val to_int64_preserve_order_exn : t -> int64 + +(** Returns [nan] if the absolute value of the argument is too large. *) +val of_int64_preserve_order : int64 -> t + +(** The next or previous representable float. ULP stands for "unit of least precision", + and is the spacing between floating point numbers. Both [one_ulp `Up infinity] and + [one_ulp `Down neg_infinity] return a nan.
*) +val one_ulp : [`Up | `Down] -> t -> t + +(** Note that this doesn't round trip in either direction. For example, [Float.to_int + (Float.of_int max_int) <> max_int]. *) +val of_int : int -> t +val to_int : t -> int + +val of_int63 : Int63.t -> t + +val of_int64 : int64 -> t +val to_int64 : t -> int64 + +(** [round] rounds a float to an integer float. [iround{,_exn}] rounds a float to an + int. Both round according to a direction [dir], with default [dir] being [`Nearest]. + + {v + | `Down | rounds toward Float.neg_infinity | + | `Up | rounds toward Float.infinity | + | `Nearest | rounds to the nearest int ("round half-integers up") | + | `Zero | rounds toward zero | + v} + + [iround_exn] raises when trying to handle nan or trying to handle a float outside the + range \[float min_int, float max_int). + + + Here are some examples for [round] for each direction: + + {v + | `Down | [-2.,-1.) to -2. | [-1.,0.) to -1. | [0.,1.) to 0., [1.,2.) to 1. | + | `Up | (-2.,-1.] to -1. | (-1.,0.] to -0. | (0.,1.] to 1., (1.,2.] to 2. | + | `Zero | (-2.,-1.] to -1. | (-1.,1.) to 0. | [1.,2.) to 1. | + | `Nearest | [-1.5,-0.5) to -1. | [-0.5,0.5) to 0. | [0.5,1.5) to 1. | + v} + + For convenience, versions of these functions with the [dir] argument hard-coded are + provided. If you are writing performance-critical code you should use the + versions with the hard-coded arguments (e.g. [iround_down_exn]). The [_exn] ones + are the fastest. + + The following properties hold: + + - [of_int (iround_*_exn i) = i] for any float [i] that is an integer with + [min_int <= i <= max_int]. + + - [round_* i = i] for any float [i] that is an integer. + + - [iround_*_exn (of_int i) = i] for any int [i] with [-2**52 <= i <= 2**52]. 
*) +val round : ?dir:[`Zero|`Nearest|`Up|`Down] -> t -> t +val iround : ?dir:[`Zero|`Nearest|`Up|`Down] -> t -> int option +val iround_exn : ?dir:[`Zero|`Nearest|`Up|`Down] -> t -> int + +val round_towards_zero : t -> t +val round_down : t -> t +val round_up : t -> t +val round_nearest : t -> t (** Rounds half integers up. *) + +val round_nearest_half_to_even : t -> t (** Rounds half integers to the even integer. *) + +val iround_towards_zero : t -> int option +val iround_down : t -> int option +val iround_up : t -> int option +val iround_nearest : t -> int option + +val iround_towards_zero_exn : t -> int +val iround_down_exn : t -> int +val iround_up_exn : t -> int +val iround_nearest_exn : t -> int + +val int63_round_down_exn : t -> Int63.t +val int63_round_up_exn : t -> Int63.t +val int63_round_nearest_exn : t -> Int63.t + +(** If [f <= iround_lbound || f >= iround_ubound], then [iround*] functions will refuse + to round [f], returning [None] or raising as appropriate. *) +val iround_lbound : t +val iround_ubound : t + +(** [round_significant x ~significant_digits:n] rounds to the nearest number with [n] + significant digits. More precisely: it returns the representable float closest to [x + rounded to n significant digits]. It is meant to be equivalent to [sprintf "%.*g" n x + |> Float.of_string] but faster (10x-15x). Exact ties are resolved as round-to-even. + + However, it might in rare cases break the contract above. + + + It might in some cases appear as if it violates the round-to-even rule: + + {[ + let x = 4.36083208835;; + let z = 4.3608320883;; + assert (z = round_significant x ~significant_digits:11) + ]} + + But in this case so does sprintf, since [x] as a float is slightly + under-represented: + + {[ + sprintf "%.11g" x = "4.3608320883";; + sprintf "%.30g" x = "4.36083208834999958014577714493" + ]} + + More importantly, [round_significant] might sometimes give a different + result than [sprintf ...
|> Float.of_string] because it round-trips through an + integer. For example, the decimal fraction 0.009375 is slightly under-represented as + a float: + + {[ sprintf "%.17g" 0.009375 = "0.0093749999999999997" ]} + + But: + + {[ 0.009375 *. 1e5 = 937.5 ]} + + Therefore: + + {[ round_significant 0.009375 ~significant_digits:3 = 0.00938 ]} + + whereas: + + {[ sprintf "%.3g" 0.009375 = "0.00937" ]} + + + In general we believe (and have tested on numerous examples) that the following + holds for all x: + + {[ + let s = sprintf "%.*g" significant_digits x |> Float.of_string in + s = round_significant ~significant_digits x + || s = round_significant ~significant_digits (one_ulp `Up x) + || s = round_significant ~significant_digits (one_ulp `Down x) + ]} + + Also, for float representations of decimal fractions (like 0.009375), + [round_significant] is more likely to give the "desired" result than [sprintf ... |> + of_string] (that is, the result of rounding the decimal fraction, rather than its + float representation). But it's not guaranteed either--see the [4.36083208835] + example above. + +*) +val round_significant : float -> significant_digits:int -> float + +(** [round_decimal x ~decimal_digits:n] rounds [x] to the nearest [10**(-n)]. For positive + [n] it is meant to be equivalent to [sprintf "%.*f" n x |> Float.of_string], but + faster. + + All the considerations mentioned in [round_significant] apply (both functions use the + same code path). +*) +val round_decimal : float -> decimal_digits:int -> float + + +val is_nan : t -> bool + +(** Includes positive and negative [Float.infinity]. *) +val is_inf : t -> bool + +(** [min_inan] and [max_inan] return, respectively, the min and max of the two given + values, except when one of the values is a [nan], in which case the other is + returned. (Returns [nan] if both arguments are [nan].) 
*) + +val min_inan : t -> t -> t +val max_inan : t -> t -> t + +val ( + ) : t -> t -> t +val ( - ) : t -> t -> t +val ( / ) : t -> t -> t +val ( * ) : t -> t -> t +val ( ** ) : t -> t -> t + +val ( ~- ) : t -> t + +(** Returns the fractional part and the whole (i.e., integer) part. For example, [modf + (-3.14)] returns [{ fractional = -0.14; integral = -3.; }]! *) +module Parts : sig + type outer + type t + val fractional : t -> outer + val integral : t -> outer +end with type outer := t +val modf : t -> Parts.t + +(** [mod_float x y] returns a result with the same sign as [x]. It returns [nan] if [y] + is [0]. It is basically + + {[ let mod_float x y = x -. float(truncate(x/.y)) *. y]} + + not + + {[ let mod_float x y = x -. floor(x/.y) *. y ]} + + and therefore resembles [mod] on integers more than [%]. *) +val mod_float : t -> t -> t + +(** {6 Ordinary functions for arithmetic operations} + + These are for modules that inherit from [t], since the infix operators are more + convenient. *) +val add : t -> t -> t +val sub : t -> t -> t +val neg : t -> t +val scale : t -> t -> t +val abs : t -> t + + +(** A sub-module designed to be opened to make working with floats more convenient. *) +module O : sig + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + val ( / ) : t -> t -> t + val ( ** ) : t -> t -> t + val ( ~- ) : t -> t + include Comparisons.Infix with type t := t + + val abs : t -> t + val neg : t -> t + val zero : t + val of_int : int -> t + val of_float : float -> t +end + +(** Similar to [O], except that operators are suffixed with a dot, allowing one to have + both int and float operators in scope simultaneously. *) +module O_dot : sig + val ( +. ) : t -> t -> t + val ( -. ) : t -> t -> t + val ( *. ) : t -> t -> t + val ( /. ) : t -> t -> t + val ( **. ) : t -> t -> t + val ( ~-. 
) : t -> t +end + +(** [to_string x] builds a string [s] representing the float [x] that guarantees the round + trip, that is such that [Float.equal x (Float.of_string s)]. + + It usually yields as few significant digits as possible. That is, it won't print + [3.14] as [3.1400000000000001243]. The only exception is that occasionally it will + output 17 significant digits when the number can be represented with just 16 (but not + 15 or less) of them. *) +val to_string : t -> string + +(** Pretty print float, for example [to_string_hum ~decimals:3 1234.1999 = "1_234.200"] + [to_string_hum ~decimals:3 ~strip_zero:true 1234.1999 = "1_234.2" ]. No delimiters + are inserted to the right of the decimal. *) +val to_string_hum + : ?delimiter:char (** defaults to ['_'] *) + -> ?decimals:int (** defaults to [3] *) + -> ?strip_zero:bool (** defaults to [false] *) + -> t + -> string + +(** Produce a lossy compact string representation of the float. The float is scaled by + an appropriate power of 1000 and rendered with one digit after the decimal point, + except that the decimal point is written as '.', 'k', 'm', 'g', 't', or 'p' to + indicate the scale factor. (However, if the digit after the "decimal" point is 0, + it is suppressed.) + + The smallest scale factor that allows the number to be rendered with at most 3 digits + to the left of the decimal is used. If the number is too large for this format (i.e., + the absolute value is at least 999.95e15), scientific notation is used instead. E.g.: + + - [to_padded_compact_string (-0.01) = "-0 "] + - [to_padded_compact_string 1.89 = "1.9"] + - [to_padded_compact_string 999_949.99 = "999k9"] + - [to_padded_compact_string 999_950. 
= "1m "] + + In the case where the digit after the "decimal", or the "decimal" itself is omitted, + the numbers are padded on the right with spaces to ensure the last two columns of the + string always correspond to the decimal and the digit afterward (except in the case of + scientific notation, where the exponent is the right-most element in the string and + could take up to four characters). + + - [to_padded_compact_string 1. = "1 "] + - [to_padded_compact_string 1.e6 = "1m "] + - [to_padded_compact_string 1.e18 = "1.0e+18"] + - [to_padded_compact_string max_finite_value = "1.8e+308"] + + Numbers in the range -.05 < x < .05 are rendered as "0 " or "-0 ". + + Other cases: + + - [to_padded_compact_string nan = "nan "] + - [to_padded_compact_string infinity = "inf "] + - [to_padded_compact_string neg_infinity = "-inf "] + + Exact ties are resolved to even in the decimal: + + - [to_padded_compact_string 3.25 = "3.2"] + - [to_padded_compact_string 3.75 = "3.8"] + - [to_padded_compact_string 33_250. = "33k2"] + - [to_padded_compact_string 33_350. = "33k4"] *) +val to_padded_compact_string : t -> string + +(** [int_pow x n] computes [x ** float n] via repeated squaring. It is generally much + faster than [**]. + + Note that [int_pow x 0] always returns [1.], even if [x = nan]. This + coincides with [x ** 0.] and is intentional. + + For [n >= 0] the result is identical to an n-fold product of [x] with itself under + [*.], with a certain placement of parentheses. For [n < 0] the result is identical + to [int_pow (1. /. x) (-n)]. + + The error will be on the order of [|n|] ulps, essentially the same as if you + perturbed [x] by up to a ulp and then exponentiated exactly. + + Benchmarks show a factor of 5-10 speedup (relative to [**]) for exponents up to about + 1000 (approximately 10ns vs. 70ns). For larger exponents the advantage is smaller but + persists into the trillions. For a recent or more detailed comparison, run the + benchmarks.
+ + Depending on context, calling this function might or might not allocate 2 minor words. + Even if called in a way that causes allocation, it still appears to be faster than + [**]. *) +val int_pow : t -> int -> t + +(** [square x] returns [x *. x]. *) +val square : t -> t + +(** [ldexp x n] returns [x *. 2 ** n] *) +val ldexp : t -> int -> t + +(** [frexp f] returns the pair of the significand and the exponent of [f]. When [f] is + zero, the significand [x] and the exponent [n] of [f] are equal to zero. When [f] is + non-zero, they are defined by [f = x *. 2 ** n] and [0.5 <= x < 1.0]. *) +val frexp : t -> t * int + +(** Base 10 logarithm. *) +external log10 : t -> t = "caml_log10_float" "log10" +[@@unboxed] [@@noalloc] + +(** [expm1 x] computes [exp x -. 1.0], giving numerically-accurate results even if [x] is + close to [0.0]. *) +external expm1 : t -> t = "caml_expm1_float" "caml_expm1" +[@@unboxed] [@@noalloc] + +(** [log1p x] computes [log(1.0 +. x)] (natural logarithm), giving numerically-accurate + results even if [x] is close to [0.0]. *) +external log1p : t -> t = "caml_log1p_float" "caml_log1p" +[@@unboxed] [@@noalloc] + +(** [copysign x y] returns a float whose absolute value is that of [x] and whose sign is + that of [y]. If [x] is [nan], returns [nan]. If [y] is [nan], returns either [x] or + [-. x], but it is not specified which. *) +external copysign : t -> t -> t = "caml_copysign_float" "caml_copysign" +[@@unboxed] [@@noalloc] + +(** Cosine. Argument is in radians. *) +external cos : t -> t = "caml_cos_float" "cos" +[@@unboxed] [@@noalloc] + +(** Sine. Argument is in radians. *) +external sin : t -> t = "caml_sin_float" "sin" +[@@unboxed] [@@noalloc] + +(** Tangent. Argument is in radians. *) +external tan : t -> t = "caml_tan_float" "tan" +[@@unboxed] [@@noalloc] + +(** Arc cosine. The argument must fall within the range [[-1.0, 1.0]]. Result is in + radians and is between [0.0] and [pi].
*) +external acos : t -> t = "caml_acos_float" "acos" +[@@unboxed] [@@noalloc] + +(** Arc sine. The argument must fall within the range [[-1.0, 1.0]]. Result is in + radians and is between [-pi/2] and [pi/2]. *) +external asin : t -> t = "caml_asin_float" "asin" +[@@unboxed] [@@noalloc] + +(** Arc tangent. Result is in radians and is between [-pi/2] and [pi/2]. *) +external atan : t -> t = "caml_atan_float" "atan" +[@@unboxed] [@@noalloc] + +(** [atan2 y x] returns the arc tangent of [y /. x]. The signs of [x] and [y] are used to + determine the quadrant of the result. Result is in radians and is between [-pi] and + [pi]. *) +external atan2 : t -> t -> t = "caml_atan2_float" "atan2" +[@@unboxed] [@@noalloc] + +(** [hypot x y] returns [sqrt(x *. x + y *. y)], that is, the length of the hypotenuse of + a right-angled triangle with sides of length [x] and [y], or, equivalently, the + distance of the point [(x,y)] to origin. *) +external hypot : t -> t -> t = "caml_hypot_float" "caml_hypot" +[@@unboxed] [@@noalloc] + +(** Hyperbolic cosine. Argument is in radians. *) +external cosh : t -> t = "caml_cosh_float" "cosh" +[@@unboxed] [@@noalloc] + +(** Hyperbolic sine. Argument is in radians. *) +external sinh : t -> t = "caml_sinh_float" "sinh" +[@@unboxed] [@@noalloc] + +(** Hyperbolic tangent. Argument is in radians. *) +external tanh : t -> t = "caml_tanh_float" "tanh" +[@@unboxed] [@@noalloc] + +(** Square root. *) +external sqrt : t -> t = "caml_sqrt_float" "sqrt" +[@@unboxed] [@@noalloc] + +(** Exponential. *) +external exp : t -> t = "caml_exp_float" "exp" +[@@unboxed] [@@noalloc] + +(** Natural logarithm. *) +external log : t -> t = "caml_log_float" "log" +[@@unboxed] [@@noalloc] + + +(** Excluding nan the floating-point "number line" looks like: + {v + t Class.t example + ^ neg_infinity Infinite neg_infinity + | neg normals Normal -3.14 + | neg subnormals Subnormal -.2. ** -1023. + | (-/+) zero Zero 0. + | pos subnormals Subnormal 2. ** -1023. 
+ | pos normals Normal 3.14 + v infinity Infinite infinity + v} *) +module Class : sig + type t = + | Infinite + | Nan + | Normal + | Subnormal + | Zero + [@@deriving_inline compare, enumerate, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val all : t list + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Stringable.S with type t := t +end + +val classify : t -> Class.t + +(** [is_finite t] returns [true] iff [classify t] is in [Normal; Subnormal; Zero;]. *) +val is_finite : t -> bool + +(*_ Caution: If we remove this sig item, [sign] will still be present from + [Comparable.With_zero]. *) +val sign : t -> Sign.t +[@@deprecated "[since 2016-01] Replace [sign] with [robust_sign] or [sign_exn]"] + +(** The sign of a float. Both [-0.] and [0.] map to [Zero]. Raises on nan. All other + values map to [Neg] or [Pos]. *) +val sign_exn : t -> Sign.t + +(** The sign of a float, with support for NaN. Both [-0.] and [0.] map to [Zero]. All NaN + values map to [Nan]. All other values map to [Neg] or [Pos]. *) +val sign_or_nan : t -> Sign_or_nan.t + +(** These functions construct and destruct 64-bit floating point numbers based on their + IEEE representation with a sign bit, an 11-bit non-negative (biased) exponent, and a + 52-bit non-negative mantissa (or significand). See + {{:http://en.wikipedia.org/wiki/Double-precision_floating-point_format} Wikipedia} for + details of the encoding. + + In particular, if 1 <= exponent <= 2046, then: + + {[ + create_ieee_exn ~negative:false ~exponent ~mantissa + = 2 ** (exponent - 1023) * (1 + (2 ** -52) * mantissa) + ]} *) +val create_ieee : negative:bool -> exponent:int -> mantissa:Int63.t -> t Or_error.t +val create_ieee_exn : negative:bool -> exponent:int -> mantissa:Int63.t -> t +val ieee_negative : t -> bool +val ieee_exponent : t -> int +val ieee_mantissa : t -> Int63.t + +(** S-expressions contain at most 8 significant digits. 
*) +module Terse : sig + type nonrec t = t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t +end + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + val lower_bound_for_int : int -> t + val upper_bound_for_int : int -> t + val specialized_hash : t -> int + val one_ulp_less_than_half : t + val int63_round_nearest_portable_alloc_exn : t -> Int63.t + val int63_round_nearest_arch64_noalloc_exn : t -> Int63.t + val iround_nearest_exn_64 : t -> int +end diff --git a/src/float0.ml b/src/float0.ml new file mode 100644 index 0000000..4a405f6 --- /dev/null +++ b/src/float0.ml @@ -0,0 +1,125 @@ +open! Import + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Float_replace_polymorphic_compare + +let is_nan x = (x : float) <> x + +(* An order-preserving bijection between all floats except for NaNs, and 99.95% of + int64s. + + Note we don't distinguish 0. and -0. as separate values here, they both map to 0L, which + maps back to 0. + + This should work both on little-endian and big-endian CPUs. Wikipedia says: "on + modern standard computers (i.e., implementing IEEE 754), one may in practice safely + assume that the endianness is the same for floating point numbers as for integers" + (http://en.wikipedia.org/wiki/Endianness#Floating-point_and_endianness). +*) +let to_int64_preserve_order t = + if is_nan t then + None + else + if t = 0. then (* also includes -0. *) + Some 0L + else + if t > 0. then + Some (Caml.Int64.bits_of_float t) + else + Some (Caml.Int64.neg (Caml.Int64.bits_of_float (~-. 
t))) +;; + +let to_int64_preserve_order_exn x = + Option.value_exn (to_int64_preserve_order x) +;; + +let of_int64_preserve_order x = + if Int64_replace_polymorphic_compare.(>=) x 0L then + Caml.Int64.float_of_bits x + else + ~-. (Caml.Int64.float_of_bits (Caml.Int64.neg x)) +;; + +let one_ulp dir t = + match to_int64_preserve_order t with + | None -> Caml.nan + | Some x -> + of_int64_preserve_order (Caml.Int64.add x (match dir with `Up -> 1L | `Down -> -1L)) +;; + +(* [upper_bound_for_int] and [lower_bound_for_int] are for calculating the max/min float + that fits in a given-size integer when rounded towards 0 (using [int_of_float]). + + max_int/min_int depend on [num_bits], e.g. +/- 2^30, +/- 2^62 if 31-bit, 63-bit + (respectively) while float is IEEE standard for double (52 significant bits). + + In all cases, we want to guarantee that + [lower_bound_for_int <= x <= upper_bound_for_int] + iff [int_of_float x] fits in an int with [num_bits] bits. + + [2 ** (num_bits - 1)] is the first float greater than max_int, we use the preceding + float as upper bound. + + [- (2 ** (num_bits - 1))] is equal to min_int. + For lower bound we look for the smallest float [f] satisfying [f > min_int - 1] so that + [f] rounds toward zero to [min_int] + + So in particular we will have: + [lower_bound_for_int x <= - (2 ** (x-1))] + [upper_bound_for_int x < 2 ** (x-1) ] +*) +let upper_bound_for_int num_bits = + let exp = Caml.float_of_int ( num_bits - 1 ) in + one_ulp `Down (2. ** exp) + +let is_x_minus_one_exact x = + (* [x = x -. 1.] does not work with x87 floating point arithmetic backend (which is used + on 32-bit ocaml) because of 80-bit register precision of intermediate computations. + + An alternative way of computing this: [x -. one_ulp `Down x <= 1.] is also prone to + the same precision issues: you need to make sure [x] is 64-bit. + *) + let open Int64_replace_polymorphic_compare in + not (Caml.Int64.bits_of_float x = Caml.Int64.bits_of_float (x -. 
1.)) + +let lower_bound_for_int num_bits = + let exp = Caml.float_of_int ( num_bits - 1 ) in + let min_int_as_float = ~-. (2. ** exp) in + let open Int_replace_polymorphic_compare in + if num_bits - 1 < 53 (* 53 = #bits in the float's mantissa with sign included *) + then + begin + (* The smallest float that rounds towards zero to [min_int] is + [min_int - 1 + epsilon] *) + assert (is_x_minus_one_exact min_int_as_float); + one_ulp `Up (min_int_as_float -. 1.) + end + else + begin + (* [min_int_as_float] is already the smallest float [f] satisfying [f > min_int - 1]. *) + assert (not (is_x_minus_one_exact min_int_as_float)); + min_int_as_float + end + + +(* Float clamping is structured slightly differently than clamping for other types, so + that we get the behavior of [clamp_unchecked nan ~min ~max = nan] (for any [min] and + [max]) for free. +*) +let clamp_unchecked (t : float) ~min ~max = + if t < min then min + else if max < t then max + else t + +let box = + (* Prevent potential constant folding of [+. 0.] in the near ocamlopt future. *) + let x = if Random.bool () then 0. else 0. in + (fun f -> f +. x) + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Float_replace_polymorphic_compare diff --git a/src/floatable.ml b/src/floatable.ml new file mode 100644 index 0000000..1c77d93 --- /dev/null +++ b/src/floatable.ml @@ -0,0 +1,10 @@ +(** Functor that adds float conversion functions to a module. *) + +open! Import + +module type S = sig + type t + + val of_float : float -> t + val to_float : t -> float +end diff --git a/src/fn.ml b/src/fn.ml new file mode 100644 index 0000000..d7a2b32 --- /dev/null +++ b/src/fn.ml @@ -0,0 +1,30 @@ +open! 
Import + +let const c _ = c + +external ignore : _ -> unit = "%ignore" (* this has the same behavior as [Caml.ignore] *) + +let non f x = not (f x) + +let forever f = + let rec forever () = + f (); + forever () + in + try forever () + with e -> e + +external id : 'a -> 'a = "%identity" + +external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + +(* The typical use case for these functions is to pass in functional arguments and get + functions as a result. *) +let compose f g x = f (g x) + +let flip f x y = f y x + +let rec apply_n_times ~n f x = + if n <= 0 + then x + else apply_n_times ~n:(n - 1) f (f x) diff --git a/src/fn.mli b/src/fn.mli new file mode 100644 index 0000000..8ed42fa --- /dev/null +++ b/src/fn.mli @@ -0,0 +1,32 @@ +(** Various combinators for functions. *) + +open! Import + +(** A "pipe" operator. *) +external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + +(** Produces a function that just returns its first argument. *) +val const : 'a -> _ -> 'a + +(** [ignore] is the same as [Caml.ignore]. It is useful to have here so that code + that rebinds [ignore] can still refer to [Fn.ignore]. *) +external ignore : _ -> unit = "%ignore" + +(** Negates a function. *) +val non : ('a -> bool) -> 'a -> bool + +(** [forever f] runs [f ()] until it throws an exception and returns the + exception. This function is useful for read_line loops, etc. *) +val forever : (unit -> unit) -> exn + +(** [apply_n_times ~n f x] is the [n]-fold application of [f] to [x]. *) +val apply_n_times : n:int -> ('a -> 'a) -> ('a -> 'a) + +(** The identity function. Also see [Sys.opaque_identity]. *) +external id : 'a -> 'a = "%identity" + +(** [compose f g x] is [f (g x)]. *) +val compose : ('b -> 'c) -> ('a -> 'b) -> ('a -> 'c) + +(** Reverses the order of arguments for a binary function. 
*) +val flip : ('a -> 'b -> 'c) -> ('b -> 'a -> 'c) diff --git a/src/formatter.ml b/src/formatter.ml new file mode 100644 index 0000000..0f4b939 --- /dev/null +++ b/src/formatter.ml @@ -0,0 +1 @@ +type t = Caml.Format.formatter diff --git a/src/formatter.mli b/src/formatter.mli new file mode 100644 index 0000000..fdd1daf --- /dev/null +++ b/src/formatter.mli @@ -0,0 +1,9 @@ +(** The [Format.formatter] type from OCaml's standard library, exported here + for convenience and compatibility with other libraries. + + The [Format] module itself is deprecated in Base. You may refer to it + explicitly through [Caml.Format], though you may wish to search for other + alternatives for constructing pretty-printers using the [Format.formatter] + type. *) + +type t = Caml.Format.formatter diff --git a/src/hash.ml b/src/hash.ml new file mode 100644 index 0000000..fcd56db --- /dev/null +++ b/src/hash.ml @@ -0,0 +1,231 @@ +(* + This is the interface to the runtime support for [ppx_hash]. + + The [ppx_hash] syntax extension supports: [@@deriving_inline hash][@@@end] and [%hash_fold: TYPE] and + [%hash: TYPE] + + For type [t] a function [hash_fold_t] of type [Hash.state -> t -> Hash.state] is + generated. + + The generated [hash_fold_<T>] function is compositional, following the structure of the + type; allowing user overrides at every level. This is in contrast to ocaml's builtin + polymorphic hashing [Hashtbl.hash] which ignores user overrides. + + The generator also provides a direct hash-function [hash] (named [hash_<T>] when <T> != + "t") of type: [t -> Hash.hash_value]. + + The folding hash function can be accessed as [%hash_fold: TYPE] + The direct hash function can be accessed as [%hash: TYPE] +*) + +open! 
Import0 + +module Array = Array0 +module Char = Char0 +module Int = Int0 +module List = List0 + +include Hash_intf + +(** Builtin folding-style hash functions, abstracted over [Hash_intf.S] *) +module Folding (Hash : Hash_intf.S) + : Hash_intf.Builtin_intf + with type state = Hash.state + and type hash_value = Hash.hash_value += struct + + type state = Hash.state + type hash_value = Hash.hash_value + type 'a folder = state -> 'a -> state + + let hash_fold_unit s () = s + + let hash_fold_int = Hash.fold_int + let hash_fold_int64 = Hash.fold_int64 + let hash_fold_float = Hash.fold_float + let hash_fold_string = Hash.fold_string + + let as_int f s x = hash_fold_int s (f x) + + (* This ignores the sign bit on 32-bit architectures, but it's unlikely to lead to + frequent collisions (min_value colliding with 0 is the most likely one). *) + let hash_fold_int32 = as_int Caml.Int32.to_int + + let hash_fold_char = as_int Char.to_int + let hash_fold_bool = as_int (function true -> 1 | false -> 0) + + let hash_fold_nativeint s x = hash_fold_int64 s (Caml.Int64.of_nativeint x) + + let hash_fold_option hash_fold_elem s = function + | None -> hash_fold_int s 0 + | Some x -> hash_fold_elem (hash_fold_int s 1) x + + let rec hash_fold_list_body hash_fold_elem s list = + match list with + | [] -> s + | x::xs -> hash_fold_list_body hash_fold_elem (hash_fold_elem s x) xs + + let hash_fold_list hash_fold_elem s list = + (* The [length] of the list must be incorporated into the hash-state so values of + types such as [unit list] - ([], [()], [();()],..) are hashed differently. *) + (* The [length] must come before the elements to avoid a violation of the rule + enforced by Perfect_hash. 
*) + let s = hash_fold_int s (List.length list) in + let s = hash_fold_list_body hash_fold_elem s list in + s + + let hash_fold_lazy_t hash_fold_elem s x = + hash_fold_elem s (Caml.Lazy.force x) + + let hash_fold_ref_frozen hash_fold_elem s x = hash_fold_elem s (!x) + + let rec hash_fold_array_frozen_i hash_fold_elem s array i = + if i = Array.length array + then s + else + let e = Array.unsafe_get array i in + hash_fold_array_frozen_i hash_fold_elem (hash_fold_elem s e) array (i + 1) + + let hash_fold_array_frozen hash_fold_elem s array = + hash_fold_array_frozen_i + (* [length] must be incorporated for arrays, as it is for lists. See comment above *) + hash_fold_elem (hash_fold_int s (Array.length array)) array 0 + + (* the duplication here is because we think + ocaml can't eliminate indirect function calls otherwise. *) + let hash_nativeint x = + Hash.get_hash_value (hash_fold_nativeint (Hash.reset (Hash.alloc ())) x) + let hash_int64 x = + Hash.get_hash_value (hash_fold_int64 (Hash.reset (Hash.alloc ())) x) + let hash_int32 x = + Hash.get_hash_value (hash_fold_int32 (Hash.reset (Hash.alloc ())) x) + let hash_char x = + Hash.get_hash_value (hash_fold_char (Hash.reset (Hash.alloc ())) x) + let hash_int x = + Hash.get_hash_value (hash_fold_int (Hash.reset (Hash.alloc ())) x) + let hash_bool x = + Hash.get_hash_value (hash_fold_bool (Hash.reset (Hash.alloc ())) x) + let hash_string x = + Hash.get_hash_value (hash_fold_string (Hash.reset (Hash.alloc ())) x) + let hash_float x = + Hash.get_hash_value (hash_fold_float (Hash.reset (Hash.alloc ())) x) + let hash_unit x = + Hash.get_hash_value (hash_fold_unit (Hash.reset (Hash.alloc ())) x) + +end + +module F (Hash : Hash_intf.S) : + Hash_intf.Full + with type hash_value = Hash.hash_value + and type state = Hash.state + and type seed = Hash.seed += struct + + include Hash + + type 'a folder = state -> 'a -> state + + let create ?seed () = reset ?seed (alloc ()) + + let of_fold hash_fold_t = (fun t -> get_hash_value 
(hash_fold_t (create ()) t)) + + module Builtin = Folding(Hash) + + let run ?seed folder x = + Hash.get_hash_value (folder (Hash.reset ?seed (Hash.alloc ())) x) + +end + +module Internalhash : sig + include Hash_intf.S + with type state = private int (* allow optimizations for immediate type *) + and type seed = int + and type hash_value = int + + external fold_int64 : state -> int64 -> state = "Base_internalhash_fold_int64" [@@noalloc] + external fold_int : state -> int -> state = "Base_internalhash_fold_int" [@@noalloc] + external fold_float : state -> float -> state = "Base_internalhash_fold_float" [@@noalloc] + external fold_string : state -> string -> state = "Base_internalhash_fold_string" [@@noalloc] + external get_hash_value : state -> hash_value = "Base_internalhash_get_hash_value" [@@noalloc] +end = struct + let description = "internalhash" + + type state = int + type hash_value = int + type seed = int + + external create_seeded : seed -> state = "%identity" [@@noalloc] + external fold_int64 : state -> int64 -> state = "Base_internalhash_fold_int64" [@@noalloc] + external fold_int : state -> int -> state = "Base_internalhash_fold_int" [@@noalloc] + external fold_float : state -> float -> state = "Base_internalhash_fold_float" [@@noalloc] + external fold_string : state -> string -> state = "Base_internalhash_fold_string" [@@noalloc] + external get_hash_value : state -> hash_value = "Base_internalhash_get_hash_value" [@@noalloc] + + let alloc () = create_seeded 0 + + let reset ?(seed=0) _t = create_seeded seed + + module For_tests = struct + let compare_state = compare + let state_to_string = Int.to_string + end +end + +module T = struct + include Internalhash + type 'a folder = state -> 'a -> state + + let create ?seed () = reset ?seed (alloc ()) + + let run ?seed folder x = + get_hash_value (folder (reset ?seed (alloc ())) x) + + let of_fold hash_fold_t = (fun t -> get_hash_value (hash_fold_t (create ()) t)) + + module Builtin = struct + module Folding = 
Folding(Internalhash) + include + (Folding : Hash_intf.Builtin_hash_fold_intf + with type state := state + and type 'a folder := 'a folder) + + let hash_nativeint = Folding.hash_nativeint + let hash_int64 = Folding.hash_int64 + let hash_int32 = Folding.hash_int32 + let hash_string = Folding.hash_string + + (* [Folding] provides some default implementations for the [hash_*] functions below, + but they are inefficient for some use-cases because of the use of the [hash_fold] + functions. At this point, the [hash_value] type has been fixed to [int], so this + module can provide specialized implementations. *) + + let hash_char = Char0.to_int + + (* This hash was chosen from here: https://gist.github.com/badboy/6267743 + + It attempts to fulfill the primary goals of a non-cryptographic hash function: + + - a bit change in the input should change ~1/2 of the output bits + - the output should be uniformly distributed across the output range + - inputs that are close to each other shouldn't lead to outputs that are close to + each other. + - all bits of the input are used in generating the output + + In our case we also want it to be fast, non-allocating, and inlinable. 
*) + let [@inline always] hash_int (t : int) = + let t = (lnot t) + (t lsl 21) in + let t = t lxor (t lsr 24) in + let t = (t + (t lsl 3)) + (t lsl 8) in + let t = t lxor (t lsr 14) in + let t = (t + (t lsl 2)) + (t lsl 4) in + let t = t lxor (t lsr 28) in + t + (t lsl 31) + ;; + + let hash_bool x = if x then 1 else 0 + external hash_float : float -> int = "Base_hash_double" [@@noalloc] + let hash_unit () = 0 + end +end + +include T diff --git a/src/hash.mli b/src/hash.mli new file mode 100644 index 0000000..1b586b4 --- /dev/null +++ b/src/hash.mli @@ -0,0 +1 @@ +include Hash_intf.Hash (** @inline *) diff --git a/src/hash_intf.ml b/src/hash_intf.ml new file mode 100644 index 0000000..f0f9536 --- /dev/null +++ b/src/hash_intf.ml @@ -0,0 +1,196 @@ +(** [Hash_intf.S] is the interface which a hash function must support. + + The functions of [Hash_intf.S] are only allowed to be used in specific sequence: + + [alloc], [reset ?seed], [fold_..*], [get_hash_value], [reset ?seed], [fold_..*], + [get_hash_value], ... + + (The optional [seed]s passed to each reset may differ.) + + The chain of applications from [reset] to [get_hash_value] must be done in a + single-threaded manner (you can't use [fold_*] on a state that's been used + before). More precisely, [alloc ()] creates a new family of states. All functions that + take [t] and produce [t] return a new state from the same family. + + At any point in time, at most one state in the family is "valid". The other states are + "invalid". + + - The state returned by [alloc] is invalid. + - The state returned by [reset] is valid (all of the other states become invalid). + - The [fold_*] family of functions requires a valid state and produces a valid state + (thereby making the input state invalid). + - [get_hash_value] requires a valid state and makes it invalid. + + These requirements are currently formally encoded in the [Check_initialized_correctly] + module in bench/bench.ml. *) + +open! 
Import0 + +module type S = sig + + (** Name of the hash-function, e.g., "internalhash", "siphash" *) + val description : string + + (** [state] is the internal hash-state used by the hash function. *) + type state + + (** [fold_<T> state v] incorporates a value [v] of type <T> into the hash-state, + returning a modified hash-state. Implementations of the [fold_<T>] functions may + mutate the [state] argument in place, and return a reference to it. Implementations + of the fold_<T> functions should not allocate. *) + val fold_int : state -> int -> state + val fold_int64 : state -> int64 -> state + val fold_float : state -> float -> state + val fold_string : state -> string -> state + + (** [seed] is the type used to seed the initial hash-state. *) + type seed + + (** [alloc ()] returns a fresh uninitialized hash-state. May allocate. *) + val alloc : unit -> state + + (** [reset ?seed state] initializes/resets a hash-state with the given [seed], or else a + default-seed. Argument [state] may be mutated. Should not allocate. *) + val reset : ?seed:seed -> state -> state + + (** [hash_value] The type of hash values, returned by [get_hash_value]. *) + type hash_value + + (** [get_hash_value] extracts a hash-value from the hash-state. 
*) + val get_hash_value : state -> hash_value + + module For_tests : sig + val compare_state : state -> state -> int + val state_to_string : state -> string + end +end + +module type Builtin_hash_fold_intf = sig + type state + type 'a folder = state -> 'a -> state + + val hash_fold_nativeint : nativeint folder + val hash_fold_int64 : int64 folder + val hash_fold_int32 : int32 folder + val hash_fold_char : char folder + val hash_fold_int : int folder + val hash_fold_bool : bool folder + val hash_fold_string : string folder + val hash_fold_float : float folder + val hash_fold_unit : unit folder + + val hash_fold_option : 'a folder -> 'a option folder + val hash_fold_list : 'a folder -> 'a list folder + val hash_fold_lazy_t : 'a folder -> 'a lazy_t folder + + (** Hash support for [array] and [ref] is provided, but is potentially DANGEROUS, since + it incorporates the current contents of the array/ref into the hash value. Because + of this we add a [_frozen] suffix to the function name. + + Hash support for [string] is also potentially DANGEROUS, but strings are mutated + less often, so we don't append [_frozen] to it. + + Also note that we don't support [bytes]. *) + val hash_fold_ref_frozen : 'a folder -> 'a ref folder + val hash_fold_array_frozen : 'a folder -> 'a array folder + +end + +module type Builtin_hash_intf = sig + type hash_value + + val hash_nativeint : nativeint -> hash_value + val hash_int64 : int64 -> hash_value + val hash_int32 : int32 -> hash_value + val hash_char : char -> hash_value + val hash_int : int -> hash_value + val hash_bool : bool -> hash_value + val hash_string : string -> hash_value + val hash_float : float -> hash_value + val hash_unit : unit -> hash_value + +end + +module type Builtin_intf = sig + include Builtin_hash_fold_intf + include Builtin_hash_intf +end + +module type Full = sig + + include S (** @inline *) + + type 'a folder = state -> 'a -> state + + (** [create ?seed ()] is a convenience. 
Equivalent to [reset ?seed (alloc ())]. *) + val create : ?seed:seed -> unit -> state + + (** [of_fold fold] constructs a standard hash function from an existing fold + function. *) + val of_fold : (state -> 'a -> state) -> ('a -> hash_value) + + module Builtin : Builtin_intf + with type state := state + and type 'a folder := 'a folder + and type hash_value := hash_value + + (** [run ?seed folder x] runs [folder] on [x] in a newly allocated hash-state, + initialized using optional [seed] or a default-seed. + + The following identity exists: [run [%hash_fold: T]] == [[%hash: T]] + + [run] can be used if we wish to run a hash-folder with a non-default seed. *) + val run : ?seed:seed -> 'a folder -> 'a -> hash_value + +end + +module type Hash = sig + module type Full = Full + module type S = S + + module F (Hash : S) : Full + with type hash_value = Hash.hash_value + and type state = Hash.state + and type seed = Hash.seed + + (** The code of [ppx_hash] is agnostic to the choice of hash algorithm that is + used. However, it is not currently possible to mix various choices of hash algorithms + in a given code base. + + We experimented with: + - (a) custom hash algorithms implemented in OCaml and + - (b) in C; + - (c) OCaml's internal hash function (which is a custom version of Murmur3, + implemented in C); + - (d) siphash, a modern hash function implemented in C. + + Our findings were as follows: + + - Implementing our own custom hash algorithms in OCaml and C yielded very little + performance improvement over the (c) proposal, without providing the benefit of being + a peer-reviewed, widely used hash function. + + - Siphash (a modern hash function with an internal state of 32 bytes) has a worse + performance profile than (a,b,c) above (hashing takes more time). Since its internal + state is bigger than an OCaml immediate value, one must either manage allocation of + such state explicitly, or pay the cost of allocation each time a hash is computed. 
+ While being a supposedly good hash function (with good hash quality), this quality was + not translated in measurable improvements in our macro benchmarks. (Also, based on + the data available at the time of writing, it's unclear that other hash algorithms in + this class would be more than marginally faster.) + + - By contrast, using the internal combinators of OCaml hash function means that we do + not allocate (the internal state of this hash function is 32 bit) and have the same + quality and performance as Hashtbl.hash. + + Hence, we are here making the choice of using this Internalhash (that is, Murmur3, the + OCaml hash algorithm as of 4.03) as our hash algorithm. It means that the state of the + hash function does not need to be preallocated, and makes for simpler use in hash + tables and other structures. *) + + include Full + with type state = private int + and type seed = int + + and type hash_value = int (** @open *) +end diff --git a/src/hash_set.ml b/src/hash_set.ml new file mode 100644 index 0000000..8c416df --- /dev/null +++ b/src/hash_set.ml @@ -0,0 +1,220 @@ +open! 
Import + +include Hash_set_intf + +let hashable_s = Hashtbl.hashable_s +let hashable = Hashtbl.Private.hashable +let poly_hashable = Hashtbl.Poly.hashable + +let with_return = With_return.with_return + +type 'a t = ('a, unit) Hashtbl.t +type 'a hash_set = 'a t +type 'a elt = 'a + +module Accessors = struct + + let hashable = hashable + let clear = Hashtbl.clear + let length = Hashtbl.length + let mem = Hashtbl.mem + + let is_empty t = Hashtbl.is_empty t + + let find_map t ~f = + with_return (fun r -> + Hashtbl.iter_keys t ~f:(fun elt -> + match f elt with + | None -> () + | Some _ as o -> r.return o); + None) + ;; + + let find t ~f = find_map t ~f:(fun a -> if f a then Some a else None) + + let add t k = Hashtbl.set t ~key:k ~data:() + + let strict_add t k = + if mem t k then Or_error.error_string "element already exists" + else begin + Hashtbl.set t ~key:k ~data:(); + Result.Ok () + end + ;; + + let strict_add_exn t k = Or_error.ok_exn (strict_add t k) + + let remove = Hashtbl.remove + + let strict_remove t k = + if mem t k then begin + remove t k; + Result.Ok () + end else + Or_error.error "element not in set" k (Hashtbl.sexp_of_key t) + ;; + + let strict_remove_exn t k = Or_error.ok_exn (strict_remove t k) + + let fold t ~init ~f = Hashtbl.fold t ~init ~f:(fun ~key ~data:() acc -> f acc key) + let iter t ~f = Hashtbl.iter_keys t ~f + + let count t ~f = Container.count ~fold t ~f + let sum m t ~f = Container.sum ~fold m t ~f + let min_elt t ~compare = Container.min_elt ~fold t ~compare + let max_elt t ~compare = Container.max_elt ~fold t ~compare + let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t + let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t + + let to_list = Hashtbl.keys + + let sexp_of_t sexp_of_e t = + sexp_of_list sexp_of_e (to_list t |> List.sort ~compare:(hashable t).compare) + + let to_array t = + let len = length t in + let index = ref (len - 1) in + fold t ~init:[||] ~f:(fun acc key -> + if Array.length acc = 0 
then Array.create ~len key + else begin + index := !index - 1; + Array.set acc (!index) key; + acc + end) + + let exists t ~f = Hashtbl.existsi t ~f:(fun ~key ~data:() -> f key) + let for_all t ~f = not (Hashtbl.existsi t ~f:(fun ~key ~data:() -> not (f key))) + + let equal t1 t2 = Hashtbl.equal t1 t2 (fun () () -> true) + + let copy t = Hashtbl.copy t + + let filter t ~f = Hashtbl.filteri t ~f:(fun ~key ~data:() -> f key) + + let diff t1 t2 = filter t1 ~f:(fun key -> not (Hashtbl.mem t2 key)) + + let inter t1 t2 = + let smaller, larger = if length t1 > length t2 then (t2, t1) else (t1, t2) in + Hashtbl.filteri smaller ~f:(fun ~key ~data:() -> Hashtbl.mem larger key) + + let filter_inplace t ~f = + let to_remove = + fold t ~init:[] ~f:(fun ac x -> + if f x then ac else x :: ac) + in + List.iter to_remove ~f:(fun x -> remove t x) + ;; + + let of_hashtbl_keys hashtbl = Hashtbl.map hashtbl ~f:ignore + + let to_hashtbl t ~f = Hashtbl.mapi t ~f:(fun ~key ~data:() -> f key) +end + +include Accessors + +let create ?growth_allowed ?size m = + Hashtbl.create ?growth_allowed ?size m +;; + +let of_list ?growth_allowed ?size m l = + let size = match size with Some x -> x | None -> List.length l in + let t = Hashtbl.create ?growth_allowed ~size m in + List.iter l ~f:(fun k -> add t k); + t +;; + +let t_of_sexp m e_of_sexp sexp = + match sexp with + | Sexp.Atom _ -> + raise (Of_sexp_error (Failure "Hash_set.t_of_sexp requires a list", sexp)) + | Sexp.List list -> + let t = create m ~size:(List.length list) in + List.iter list ~f:(fun sexp -> + let e = e_of_sexp sexp in + match strict_add t e with + | Ok () -> () + | Error _ -> + raise (Of_sexp_error + (Error.to_exn + (Error.create "Hash_set.t_of_sexp got a duplicate element" + sexp Fn.id), + sexp))); + t +;; + +module Creators (Elt : sig + type 'a t + + val hashable : 'a t Hashable.t + end) : sig + + type 'a t_ = 'a Elt.t t + + val t_of_sexp : (Sexp.t -> 'a Elt.t) -> Sexp.t -> 'a t_ + + include Creators_generic + with type 'a t 
:= 'a t_ + with type 'a elt := 'a Elt.t + with type ('elt, 'z) create_options := ('elt, 'z) create_options_without_first_class_module + +end = struct + + type 'a t_ = 'a Elt.t t + + let create ?growth_allowed ?size () = + create ?growth_allowed ?size (Hashable.to_key Elt.hashable) + + let of_list ?growth_allowed ?size l = + of_list ?growth_allowed ?size (Hashable.to_key Elt.hashable) l + + let t_of_sexp e_of_sexp sexp = + t_of_sexp (Hashable.to_key Elt.hashable) e_of_sexp sexp +end + +module Poly = struct + + type 'a t = 'a hash_set + + type 'a elt = 'a + + let hashable = poly_hashable + + include Creators (struct + type 'a t = 'a + let hashable = hashable + end) + + include Accessors + + let sexp_of_t = sexp_of_t + +end + +module M (Elt : T.T) = struct + type nonrec t = Elt.t t +end +module type Sexp_of_m = sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] + include Hashtbl_intf.Key with type t := t +end + +let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t = + sexp_of_t Elt.sexp_of_t t + +let m__t_of_sexp (type elt) (module Elt : M_of_sexp with type t = elt) sexp = + t_of_sexp (module Elt) Elt.t_of_sexp sexp + +module Private = struct + let hashable = Hashtbl.Private.hashable +end diff --git a/src/hash_set.mli b/src/hash_set.mli new file mode 100644 index 0000000..f69262f --- /dev/null +++ b/src/hash_set.mli @@ -0,0 +1 @@ +include Hash_set_intf.Hash_set (** @inline *) diff --git a/src/hash_set_intf.ml b/src/hash_set_intf.ml new file mode 100644 index 0000000..cbda416 --- /dev/null +++ b/src/hash_set_intf.ml @@ -0,0 +1,204 @@ +open! 
Import + +module type Key = Hashtbl_intf.Key + +module type Accessors = sig + include Container.Generic + + val mem : 'a t -> 'a -> bool (** override [Container.Generic.mem] *) + + val copy : 'a t -> 'a t (** preserves the equality function *) + + val add : 'a t -> 'a -> unit + + (** [strict_add t x] returns [Ok ()] if [x] was not in [t], or an [Error] if it + was. *) + val strict_add : 'a t -> 'a -> unit Or_error.t + val strict_add_exn : 'a t -> 'a -> unit + + val remove : 'a t -> 'a -> unit + + (** [strict_remove t x] returns [Ok ()] if [x] was in [t], or an [Error] if it + was not. *) + val strict_remove : 'a t -> 'a -> unit Or_error.t + val strict_remove_exn : 'a t -> 'a -> unit + + val clear : 'a t -> unit + val equal : 'a t -> 'a t -> bool + val filter : 'a t -> f:('a -> bool) -> 'a t + val filter_inplace : 'a t -> f:('a -> bool) -> unit + + (** [inter t1 t2] computes the set intersection of [t1] and [t2]. Runs in O(min(length + t1, length t2)). Behavior is undefined if [t1] and [t2] don't have the same + equality function. 
*) + val inter : 'key t -> 'key t -> 'key t + val diff : 'a t -> 'a t -> 'a t + + val of_hashtbl_keys : ('a, _) Hashtbl.t -> 'a t + val to_hashtbl : 'key t -> f:('key -> 'data) -> ('key, 'data) Hashtbl.t +end + +type ('key, 'z) create_options = + ('key, unit, 'z) Hashtbl_intf.create_options + +type ('key, 'z) create_options_without_first_class_module = + ('key, unit, 'z) Hashtbl_intf.create_options_without_first_class_module + +module type Creators = sig + type 'a t + + val create + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> 'a t + val of_list + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> 'a list + -> 'a t +end + +module type Creators_generic = sig + type 'a t + type 'a elt + type ('a, 'z) create_options + + val create : ('a, unit -> 'a t) create_options + val of_list : ('a, 'a elt list -> 'a t) create_options +end + +module Check = struct + module Make_creators_check (Type : T.T1) (Elt : T.T1) (Options : T.T2) + (M : Creators_generic + with type 'a t := 'a Type.t + with type 'a elt := 'a Elt.t + with type ('a, 'z) create_options := ('a, 'z) Options.t) + = struct end + + module Check_creators_is_specialization_of_creators_generic (M : Creators) = + Make_creators_check + (struct type 'a t = 'a M.t end) + (struct type 'a t = 'a end) + (struct type ('a, 'z) t = ('a, 'z) create_options end) + (struct + include M + + let create ?growth_allowed ?size m () = + create ?growth_allowed ?size m + end) +end + +module type Hash_set = sig + + type 'a t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** We use [[@@deriving_inline sexp_of][@@@end]] but not [[@@deriving sexp]] because we want people to be + explicit about the hash and 
comparison functions used when creating hash sets. One + can use [Hash_set.Poly.t], which does have [[@@deriving_inline sexp][@@@end]], to use polymorphic + comparison and hashing. *) + + module type Creators = Creators + module type Creators_generic = Creators_generic + + type nonrec ('key, 'z) create_options = + ('key, 'z) create_options + + include Creators + with type 'a t := 'a t (** @open *) + + module type Accessors = Accessors + + include Accessors with type 'a t := 'a t with type 'a elt = 'a (** @open *) + + val hashable_s : 'key t -> (module Hashtbl_intf.Key with type t = 'key) + + type nonrec ('key, 'z) create_options_without_first_class_module = + ('key, 'z) create_options_without_first_class_module + + (** A hash set that uses polymorphic comparison *) + module Poly : sig + + type nonrec 'a t = 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Creators_generic + with type 'a t := 'a t + with type 'a elt = 'a + with type ('key, 'z) create_options := + ('key, 'z) create_options_without_first_class_module + + include Accessors with type 'a t := 'a t with type 'a elt := 'a elt + + end + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_hash_set = Hash_set.M(String).t + ]} + + which stands for: + + {[ + type string_hash_set = String.t Hash_set.t + ]} + + The point is that [Hash_set.M(String).t] supports deriving, whereas the second + syntax doesn't (because [t_of_sexp] doesn't know what comparison/hash function to + use). 
*) + module M (Elt : T.T) : sig + type nonrec t = Elt.t t + end + module type Sexp_of_m = sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] + include Hashtbl_intf.Key with type t := t + end + val sexp_of_m__t : (module Sexp_of_m with type t = 'elt) -> 'elt t -> Sexp.t + val m__t_of_sexp : (module M_of_sexp with type t = 'elt) -> Sexp.t -> 'elt t + + module Creators (Elt : sig + type 'a t + val hashable : 'a t Hashable.t + end) : sig + type 'a t_ = 'a Elt.t t + val t_of_sexp : (Sexp.t -> 'a Elt.t) -> Sexp.t -> 'a t_ + include Creators_generic + with type 'a t := 'a t_ + with type 'a elt := 'a Elt.t + with type ('elt, 'z) create_options := + ('elt, 'z) create_options_without_first_class_module + end + + (**/**) + (*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) + module Private : sig + val hashable : 'a t -> 'a Hashable.t + end +end diff --git a/src/hash_stubs.c b/src/hash_stubs.c new file mode 100644 index 0000000..3aefec5 --- /dev/null +++ b/src/hash_stubs.c @@ -0,0 +1,26 @@ +#include +#include +#include + +/* Final mix and return from the hash.c implementation from INRIA */ +#define FINAL_MIX_AND_RETURN(h) \ + h ^= h >> 16; \ + h *= 0x85ebca6b; \ + h ^= h >> 13; \ + h *= 0xc2b2ae35; \ + h ^= h >> 16; \ + return Val_int(h & 0x3FFFFFFFU); + +CAMLprim value Base_hash_string (value string) +{ + uint32_t h; + h = caml_hash_mix_string (0, string); + FINAL_MIX_AND_RETURN(h) +} + +CAMLprim value Base_hash_double (value d) +{ + uint32_t h; + h = caml_hash_mix_double (0, Double_val(d)); + FINAL_MIX_AND_RETURN (h); +} diff --git a/src/hashable.ml 
b/src/hashable.ml new file mode 100644 index 0000000..3b7b275 --- /dev/null +++ b/src/hashable.ml @@ -0,0 +1,4 @@ +open! Import + + +include Hashable_intf diff --git a/src/hashable.mli b/src/hashable.mli new file mode 100644 index 0000000..5678db4 --- /dev/null +++ b/src/hashable.mli @@ -0,0 +1,5 @@ +open! Import + +module type Key = Hashable_intf.Key +module type Hashable = Hashable_intf.Hashable +include Hashable (** @inline *) diff --git a/src/hashable_intf.ml b/src/hashable_intf.ml new file mode 100644 index 0000000..cca6dfe --- /dev/null +++ b/src/hashable_intf.ml @@ -0,0 +1,93 @@ +open! Import + +module type Key = sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** Values returned by [hash] must be non-negative. An exception will be raised in the + case that [hash] returns a negative value. *) + val hash : t -> int +end + +module Hashable = struct + type 'a t = + { hash : 'a -> int; + compare : 'a -> 'a -> int; + sexp_of_t : 'a -> Sexp.t; + } + + (** This function is sound but not complete, meaning that if it returns [true] then it's + safe to use the two interchangeably. If it's [false], you have no guarantees. For + example: + + {[ + > utop + open Core;; + let equal (a : 'a Hashtbl_intf.Hashable.t) b = + phys_equal a b + || (phys_equal a.hash b.hash + && phys_equal a.compare b.compare + && phys_equal a.sexp_of_t b.sexp_of_t) + ;; + let a = Hashtbl_intf.Hashable.{ hash; compare; sexp_of_t = Int.sexp_of_t };; + let b = Hashtbl_intf.Hashable.{ hash; compare; sexp_of_t = Int.sexp_of_t };; + equal a b;; (* false?! 
*) + ]} + *) + let equal a b = + phys_equal a b + || (phys_equal a.hash b.hash + && phys_equal a.compare b.compare + && phys_equal a.sexp_of_t b.sexp_of_t) + ;; + + let hash_param = Caml.Hashtbl.hash_param + let hash = Caml.Hashtbl.hash + + let poly = { hash; + compare = Poly.compare; + sexp_of_t = (fun _ -> Sexp.Atom "_"); + } + + let of_key (type a) (module Key : Key with type t = a) = + { hash = Key.hash; + compare = Key.compare; + sexp_of_t = Key.sexp_of_t; + } + ;; + + let to_key (type a) { hash; compare; sexp_of_t } = + (module struct + type t = a + let hash = hash + let compare = compare + let sexp_of_t = sexp_of_t + end : Key with type t = a) + ;; +end +include Hashable + +module type Hashable = sig + type 'a t = 'a Hashable.t = + { hash : 'a -> int; + compare : 'a -> 'a -> int; + sexp_of_t : 'a -> Sexp.t; + } + + val equal : 'a t -> 'a t -> bool + + val poly : 'a t + + val of_key : (module Key with type t = 'a) -> 'a t + val to_key : 'a t -> (module Key with type t = 'a) + + val hash_param : int -> int -> 'a -> int + + val hash : 'a -> int +end diff --git a/src/hasher.ml b/src/hasher.ml new file mode 100644 index 0000000..06e04ef --- /dev/null +++ b/src/hasher.ml @@ -0,0 +1,56 @@ +open! Import + +(** Signatures required of types which can be used in [[@@deriving_inline hash][@@@end]]. *) +(*_ JS-only: For a more in-depth discussion, see documentation of ppx_hash, available in + ppx/ppx_hash/README.md and ppx/ppx_hash/doc/design.notes. *) + +module type S = sig + + (** The type that is hashed. *) + type t + + (** [hash_fold_t state x] mixes the content of [x] into the [state]. + + By default, all our [hash_fold_t] functions (derived or not) should satisfy the + following properties. + + 1. [hash_fold_t state x] should mix all the information present in [x] in the state. 
+ That is, by default, [hash_fold_t] will traverse the full term [x] (this is a + significant change for Hashtbl.hash which by default stops traversing the term after + considering a small number of "significant values"). [hash_fold_t] must not + discard the [state]. + + 2. [hash_fold_t] must be compatible with the associated [compare] function: that is, + for all [x] [y] and [s], [compare x y = 0] must imply [hash_fold_t s x = hash_fold_t + s y]. + + 3. To avoid systematic collisions, [hash_fold_t] should expand to different + sequences of built-in mixing functions for different values of [x]. No such sequence + is allowed to be a prefix of another. + + A common mistake is to implement [hash_fold_t] of a collection by just folding all + the elements. This makes the folding sequence of [a] be a prefix of [a @ b], thereby + violating the requirement. This creates large families of collisions: all of the + following collections would hash the same: + + {v + [[]; [1;2;3]] + [[1]; [2;3]] + [[1; 2]; [3]] + [[1; 2; 3]; []] + [[1]; [2]; []; [3];] + ... + v} + + A good way to avoid this is to mix in the size of the collection to the beginning + ([fold ~init:(hash_fold_int state length) ~f:hash_fold_elem]). The default in our + libraries is to mix the length of the structure before folding. To prevent the + aforementioned collisions, one should respect this ordering. + *) + val hash_fold_t : Hash.state -> t -> Hash.state +end + +module type S1 = sig + type 'a t + val hash_fold_t : (Hash.state -> 'a -> Hash.state) -> Hash.state -> 'a t -> Hash.state +end diff --git a/src/hashtbl.ml b/src/hashtbl.ml new file mode 100644 index 0000000..dcc09ac --- /dev/null +++ b/src/hashtbl.ml @@ -0,0 +1,880 @@ +open! 
Import + +include Hashtbl_intf + +let with_return = With_return.with_return + +let hash_param = Hashable.hash_param +let hash = Hashable.hash + + +type ('k, 'v) t = + { mutable table : ('k, 'v) Avltree.t array + ; mutable length : int + (* [recently_added] is the reference passed to [Avltree.add]. We put it in the hash + table to avoid allocating it at every [set]. *) + ; recently_added : bool ref + ; growth_allowed : bool + ; hashable : 'k Hashable.t + ; mutable mutation_allowed : bool (* Set during all iteration operations *) + } + +type 'a key = 'a + +let sexp_of_key t = t.hashable.Hashable.sexp_of_t +let compare_key t = t.hashable.Hashable.compare + +let ensure_mutation_allowed t = + if not t.mutation_allowed then failwith "Hashtbl: mutation not allowed during iteration" +;; + +let without_mutating t f = + if t.mutation_allowed + then + begin + t.mutation_allowed <- false; + match f () with + | x -> t.mutation_allowed <- true; x + | exception exn -> t.mutation_allowed <- true; raise exn + end + else + f () +;; + +(** Internally use a maximum size that is a power of 2. Reverses the above to find the + floor power of 2 below the system max array length *) +let max_table_length = Int.floor_pow2 Array.max_length ;; + +let create ?(growth_allowed = true) ?(size = 128) ~hashable () = + let size = Int.min (Int.max 1 size) max_table_length in + let size = Int.ceil_pow2 size in + { table = Array.create ~len:size Avltree.empty + ; length = 0 + ; growth_allowed = growth_allowed + ; recently_added = ref false + ; hashable + ; mutation_allowed = true + } +;; + +(** Supplemental hash. This may not be necessary, it is intended as a defense against poor + hash functions, for which the power of 2 sized table will be especially sensitive. + With some testing we may choose to add it, but this table is designed to be robust to + collisions, and in most of my testing this degrades performance. 
*) +let _supplemental_hash h = + let h = h lxor ((h lsr 20) lxor (h lsr 12)) in + h lxor (h lsr 7) lxor (h lsr 4) +;; + +let slot t key = + let hash = t.hashable.Hashable.hash key in + (* this is always non-negative because we do [land] with a non-negative number *) + hash land ((Array.length t.table) - 1) +;; + +let add_worker t ~replace ~key ~data = + let i = slot t key in + let root = t.table.(i) in + let added = t.recently_added in + added := false; + let new_root = + (* The avl tree might replace the value [replace=true] or do nothing [replace=false] + to the entry, in that case the table did not get bigger, so we should not + increment length, we pass in the bool ref t.recently_added so that it can tell us whether + it added or replaced. We do it this way to avoid extra allocation. Since the bool + is an immediate it does not go through the write barrier. *) + Avltree.add ~replace root ~compare:(compare_key t) ~added ~key ~data + in + if !added then + t.length <- t.length + 1; + (* This little optimization saves a caml_modify when the tree + hasn't been rebalanced. 
*) + if not (phys_equal new_root root) then + t.table.(i) <- new_root +;; + +let maybe_resize_table t = + let len = Array.length t.table in + let should_grow = t.length > len in + if should_grow && t.growth_allowed then begin + let new_array_length = Int.min (len * 2) max_table_length in + if new_array_length > len then begin + let new_table = + Array.create ~len:new_array_length Avltree.empty + in + let old_table = t.table in + t.table <- new_table; + t.length <- 0; + let f ~key ~data = add_worker ~replace:true t ~key ~data in + for i = 0 to Array.length old_table - 1 do + Avltree.iter old_table.(i) ~f + done + end + end +;; + +let set t ~key ~data = + ensure_mutation_allowed t; + add_worker ~replace:true t ~key ~data; + maybe_resize_table t +;; + +let replace = set + +let add t ~key ~data = + ensure_mutation_allowed t; + add_worker ~replace:false t ~key ~data; + if !(t.recently_added) then begin + maybe_resize_table t; + `Ok + end else + `Duplicate +;; + +let add_exn t ~key ~data = + match add t ~key ~data with + | `Ok -> () + | `Duplicate -> + let sexp_of_key = sexp_of_key t in + let error = Error.create "Hashtbl.add_exn got key already present" key sexp_of_key in + Error.raise error +;; + +let clear t = + ensure_mutation_allowed t; + for i = 0 to Array.length t.table - 1 do + t.table.(i) <- Avltree.empty; + done; + t.length <- 0 +;; + +let find_and_call t key ~if_found ~if_not_found = + (* with a good hash function these first two cases will be the overwhelming majority, + and Avltree.find is recursive, so it can't be inlined, so doing this avoids a + function call in most cases. 
*) + match t.table.(slot t key) with + | Avltree.Empty -> if_not_found key + | Avltree.Leaf { key = k; value = v } -> + if compare_key t k key = 0 then if_found v + else if_not_found key + | tree -> + Avltree.find_and_call tree ~compare:(compare_key t) key ~if_found ~if_not_found +;; + +let findi_and_call t key ~if_found ~if_not_found = + (* with a good hash function these first two cases will be the overwhelming majority, + and Avltree.find is recursive, so it can't be inlined, so doing this avoids a + function call in most cases. *) + match t.table.(slot t key) with + | Avltree.Empty -> if_not_found key + | Avltree.Leaf { key = k; value = v } -> + if compare_key t k key = 0 then if_found ~key:k ~data:v + else if_not_found key + | tree -> + Avltree.findi_and_call tree ~compare:(compare_key t) key ~if_found ~if_not_found +;; + +let find = + let if_found v = Some v in + let if_not_found _ = None in + fun t key -> + find_and_call t key ~if_found ~if_not_found +;; + +let mem t key = + match t.table.(slot t key) with + | Avltree.Empty -> false + | Avltree.Leaf { key = k; value = _ } -> compare_key t k key = 0 + | tree -> Avltree.mem tree ~compare:(compare_key t) key +;; + +let remove t key = + ensure_mutation_allowed t; + let i = slot t key in + let root = t.table.(i) in + let added_or_removed = t.recently_added in + added_or_removed := false; + let new_root = + Avltree.remove root + ~removed:added_or_removed ~compare:(compare_key t) key + in + if not (phys_equal root new_root) then + t.table.(i) <- new_root; + if !added_or_removed then + t.length <- t.length - 1 +;; + +let length t = t.length + +let is_empty t = length t = 0 + +let fold t ~init ~f = + if length t = 0 then init + else begin + let n = Array.length t.table in + let acc = ref init in + let m = t.mutation_allowed in + match + t.mutation_allowed <- false; + for i = 0 to n - 1 do + match Array.unsafe_get t.table i with + | Avltree.Empty -> () + | Avltree.Leaf { key; value = data } -> acc := f ~key ~data !acc 
+ | bucket -> acc := Avltree.fold bucket ~init:!acc ~f + done + with + | () -> + t.mutation_allowed <- m; + !acc + | exception exn -> + t.mutation_allowed <- m; + raise exn + end +;; + +let iteri t ~f = + if t.length = 0 then () + else begin + let n = Array.length t.table in + let m = t.mutation_allowed in + match + t.mutation_allowed <- false; + for i = 0 to n - 1 do + match Array.unsafe_get t.table i with + | Avltree.Empty -> () + | Avltree.Leaf { key; value = data } -> f ~key ~data + | bucket -> Avltree.iter bucket ~f + done + with + | () -> + t.mutation_allowed <- m + | exception exn -> + t.mutation_allowed <- m; + raise exn + end +;; + +let iter t ~f = iteri t ~f:(fun ~key:_ ~data -> f data) +let iter_keys t ~f = iteri t ~f:(fun ~key ~data:_ -> f key) + +let invariant invariant_key invariant_data t = + for i = 0 to Array.length t.table - 1 do + Avltree.invariant t.table.(i) ~compare:(compare_key t) + done; + let real_len = + fold t ~init:0 ~f:(fun ~key ~data i -> + invariant_key key; + invariant_data data; + i + 1) + in + assert (real_len = t.length); +;; + +let find_exn = + let if_found v = v in + let if_not_found _ = raise Caml.Not_found in + fun t key -> + find_and_call t key ~if_found ~if_not_found +;; + +let existsi t ~f = + with_return (fun r -> + iteri t ~f:(fun ~key ~data -> if f ~key ~data then r.return true); + false) +;; + +let exists t ~f = existsi t ~f:(fun ~key:_ ~data -> f data) +;; + +let for_alli t ~f = not (existsi t ~f:(fun ~key ~data -> not (f ~key ~data))) +let for_all t ~f = not (existsi t ~f:(fun ~key:_ ~data -> not (f data))) + +let counti t ~f = + fold t ~init:0 ~f:(fun ~key ~data acc -> if f ~key ~data then acc+1 else acc) +let count t ~f = + fold t ~init:0 ~f:(fun ~key:_ ~data acc -> if f data then acc+1 else acc) + +let mapi t ~f = + let new_t = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + iteri t ~f:(fun ~key ~data -> replace new_t ~key ~data:(f ~key ~data)); + new_t + +let map t ~f = 
mapi t ~f:(fun ~key:_ ~data -> f data) + +let copy t = map t ~f:Fn.id + +let filter_mapi t ~f = + let new_t = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + iteri t ~f:(fun ~key ~data -> + match f ~key ~data with + | Some new_data -> replace new_t ~key ~data:new_data + | None -> ()); + new_t + +let filter_map t ~f = filter_mapi t ~f:(fun ~key:_ ~data -> f data) + +let filteri t ~f = + filter_mapi t ~f:(fun ~key ~data -> if f ~key ~data then Some data else None) +;; + +let filter t ~f = filteri t ~f:(fun ~key:_ ~data -> f data) +let filter_keys t ~f = filteri t ~f:(fun ~key ~data:_ -> f key) + +let partition_mapi t ~f = + let t0 = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + let t1 = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + iteri t ~f:(fun ~key ~data -> + match f ~key ~data with + | `Fst new_data -> replace t0 ~key ~data:new_data + | `Snd new_data -> replace t1 ~key ~data:new_data); + (t0, t1) +;; + +let partition_map t ~f = partition_mapi t ~f:(fun ~key:_ ~data -> f data) + +let partitioni_tf t ~f = + partition_mapi t ~f:(fun ~key ~data -> if f ~key ~data then `Fst data else `Snd data) +;; + +let partition_tf t ~f = partitioni_tf t ~f:(fun ~key:_ ~data -> f data) + +let find_or_add t id ~default = + match find t id with + | Some x -> x + | None -> + let default = default () in + replace t ~key:id ~data:default; + default + +let findi_or_add t id ~default = + match find t id with + | Some x -> x + | None -> + let default = default id in + replace t ~key:id ~data:default; + default + +(* Some hashtbl implementations may be able to perform this more efficiently than two + separate lookups *) +let find_and_remove t id = + let result = find t id in + if Option.is_some result then remove t id; + result + + +let change t id ~f = + match f (find t id) with + | None -> remove t id + | Some data -> replace t ~key:id ~data +;; + +let update 
t id ~f = + set t ~key:id ~data:(f (find t id)) +;; + +let incr_by ~remove_if_zero t key by = + if remove_if_zero + then + change t key ~f:(fun opt -> + match by + Option.value opt ~default:0 with + | 0 -> None + | n -> Some n) + else + update t key ~f:(function + | None -> by + | Some i -> by + i) +;; + +let incr ?(by = 1) ?(remove_if_zero = false) t key = incr_by ~remove_if_zero t key by +let decr ?(by = 1) ?(remove_if_zero = false) t key = incr_by ~remove_if_zero t key (-by) +;; + +let add_multi t ~key ~data = + update t key ~f:(function + | None -> [ data ] + | Some l -> data :: l) +;; + +let remove_multi t key = + match find t key with + | None -> () + | Some [] | Some [_] -> remove t key + | Some (_ :: tl) -> replace t ~key ~data:tl + +let find_multi t key = + match find t key with + | None -> [] + | Some l -> l + +let create_mapped ?growth_allowed ?size ~hashable ~get_key ~get_data rows = + let size = match size with Some s -> s | None -> List.length rows in + let res = create ?growth_allowed ~hashable ~size () in + let dupes = ref [] in + List.iter rows ~f:(fun r -> + let key = get_key r in + let data = get_data r in + if mem res key then + dupes := key :: !dupes + else + replace res ~key ~data); + match !dupes with + | [] -> `Ok res + | keys -> `Duplicate_keys (List.dedup_and_sort ~compare:hashable.Hashable.compare keys) +;; + +(* + {[ + let create_mapped_exn ?growth_allowed ?size ~hashable ~get_key ~get_data rows = + let size = match size with Some s -> s | None -> List.length rows in + let res = create ?growth_allowed ~size ~hashable () in + List.iter rows ~f:(fun r -> + let key = get_key r in + let data = get_data r in + if mem res key then + let sexp_of_key = hashable.Hashable.sexp_of_t in + failwiths "Hashtbl.create_mapped_exn: duplicate key" key <:sexp_of< key >> + else + replace res ~key ~data); + res + ;; + ]} *) + +let create_mapped_multi ?growth_allowed ?size ~hashable ~get_key ~get_data rows = + let size = match size with Some s -> s | None -> 
List.length rows in + let res = create ?growth_allowed ~size ~hashable () in + List.iter rows ~f:(fun r -> + let key = get_key r in + let data = get_data r in + add_multi res ~key ~data); + res +;; + +let of_alist ?growth_allowed ?size ~hashable lst = + match create_mapped ?growth_allowed ?size ~hashable ~get_key:fst ~get_data:snd lst with + | `Ok t -> `Ok t + | `Duplicate_keys k -> `Duplicate_key (List.hd_exn k) +;; + +let of_alist_report_all_dups ?growth_allowed ?size ~hashable lst = + create_mapped ?growth_allowed ?size ~hashable ~get_key:fst ~get_data:snd lst +;; + +let of_alist_or_error ?growth_allowed ?size ~hashable lst = + match of_alist ?growth_allowed ?size ~hashable lst with + | `Ok v -> Result.Ok v + | `Duplicate_key key -> + let sexp_of_key = hashable.Hashable.sexp_of_t in + Or_error.error "Hashtbl.of_alist_exn: duplicate key" key sexp_of_key +;; + +let of_alist_exn ?growth_allowed ?size ~hashable lst = + match of_alist_or_error ?growth_allowed ?size ~hashable lst with + | Result.Ok v -> v + | Result.Error e -> Error.raise e +;; + +let of_alist_multi ?growth_allowed ?size ~hashable lst = + create_mapped_multi ?growth_allowed ?size ~hashable ~get_key:fst ~get_data:snd lst +;; + +let to_alist t = fold ~f:(fun ~key ~data list -> (key, data) :: list) ~init:[] t + +let sexp_of_t sexp_of_key sexp_of_data t = + t + |> to_alist + |> List.sort ~compare:(fun (k1, _) (k2, _) -> t.hashable.compare k1 k2) + |> sexp_of_list (sexp_of_pair sexp_of_key sexp_of_data) +;; + +let t_of_sexp ~hashable k_of_sexp d_of_sexp sexp = + let alist = list_of_sexp (pair_of_sexp k_of_sexp d_of_sexp) sexp in + of_alist_exn ~hashable alist ~size:(List.length alist) +;; + +let validate ~name f t = Validate.alist ~name f (to_alist t) + +let keys t = fold t ~init:[] ~f:(fun ~key ~data:_ acc -> key :: acc) + +let data t = fold ~f:(fun ~key:_ ~data list -> data::list) ~init:[] t + +let add_to_groups groups ~get_key ~get_data ~combine ~rows = + List.iter rows ~f:(fun row -> + let key = 
get_key row in + let data = get_data row in + let data = + match find groups key with + | None -> data + | Some old -> combine old data + in + replace groups ~key ~data) +;; + +let group ?growth_allowed ?size ~hashable ~get_key ~get_data ~combine rows = + let res = create ?growth_allowed ?size ~hashable () in + add_to_groups res ~get_key ~get_data ~combine ~rows; + res +;; + +let create_with_key ?growth_allowed ?size ~hashable ~get_key rows = + create_mapped ?growth_allowed ?size ~hashable ~get_key ~get_data:Fn.id rows +;; + +let create_with_key_or_error ?growth_allowed ?size ~hashable ~get_key rows = + match create_with_key ?growth_allowed ?size ~hashable ~get_key rows with + | `Ok t -> Result.Ok t + | `Duplicate_keys keys -> + let sexp_of_key = hashable.Hashable.sexp_of_t in + Or_error.error_s + (Sexp.message "Hashtbl.create_with_key: duplicate keys" + [ "keys", sexp_of_list sexp_of_key keys ]) +;; + +let create_with_key_exn ?growth_allowed ?size ~hashable ~get_key rows = + Or_error.ok_exn (create_with_key_or_error ?growth_allowed ?size ~hashable ~get_key rows) +;; + +let merge = + let maybe_set t ~key ~f d = + match f ~key d with + | None -> () + | Some v -> + set t ~key ~data:v + in + fun t_left t_right ~f -> + if not (Hashable.equal t_left.hashable t_right.hashable) + then invalid_arg "Hashtbl.merge: different 'hashable' values"; + let new_t = + create ~growth_allowed:t_left.growth_allowed + ~hashable:t_left.hashable ~size:t_left.length () + in + without_mutating t_left (fun () -> + without_mutating t_right (fun () -> + iteri t_left ~f:(fun ~key ~data:left -> + match find t_right key with + | None -> + maybe_set new_t ~key ~f (`Left left) + | Some right -> + maybe_set new_t ~key ~f (`Both (left, right)) + ); + iteri t_right ~f:(fun ~key ~data:right -> + match find t_left key with + | None -> + maybe_set new_t ~key ~f (`Right right) + | Some _ -> () (* already done above *) + ))); + new_t +;; + +type 'a merge_into_action = Remove | Set_to of 'a + +let 
merge_into ~src ~dst ~f = + iteri src ~f:(fun ~key ~data -> + let dst_data = find dst key in + let action = without_mutating dst (fun () -> f ~key data dst_data) in + match action with + | Remove -> remove dst key + | Set_to data -> + match dst_data with + | None -> replace dst ~key ~data + | Some dst_data -> + if not (phys_equal dst_data data) + then replace dst ~key ~data) +;; + +let filteri_inplace t ~f = + let to_remove = + fold t ~init:[] ~f:(fun ~key ~data ac -> + if f ~key ~data then ac else key :: ac) + in + List.iter to_remove ~f:(fun key -> remove t key); +;; + +let filter_inplace t ~f = + filteri_inplace t ~f:(fun ~key:_ ~data -> f data) +;; + +let filter_keys_inplace t ~f = + filteri_inplace t ~f:(fun ~key ~data:_ -> f key) +;; + +let filter_mapi_inplace t ~f = + let map_results = + fold t ~init:[] ~f:(fun ~key ~data ac -> (key, f ~key ~data) :: ac) + in + List.iter map_results ~f:(fun (key,result) -> + match result with + | None -> remove t key + | Some data -> set t ~key ~data + ); +;; + +let filter_map_inplace t ~f = + filter_mapi_inplace t ~f:(fun ~key:_ ~data -> f data) + +let mapi_inplace t ~f = + ensure_mutation_allowed t; + without_mutating t (fun () -> Array.iter t.table ~f:(Avltree.mapi_inplace ~f)) + +let map_inplace t ~f = + mapi_inplace t ~f:(fun ~key:_ ~data -> f data) + +let equal t t' equal = + length t = length t' && + with_return (fun r -> + without_mutating t' (fun () -> + iteri t ~f:(fun ~key ~data -> + match find t' key with + | None -> r.return false + | Some data' -> + if not (equal data data') + then r.return false)); + true) +;; + +let similar = equal + +module Accessors = struct + + type nonrec 'a merge_into_action = 'a merge_into_action = Remove | Set_to of 'a + + let invariant = invariant + let clear = clear + let copy = copy + let remove = remove + let set = set + let add = add + let add_exn = add_exn + let change = change + let update = update + let add_multi = add_multi + let remove_multi = remove_multi + let find_multi = 
find_multi + let mem = mem + let iter_keys = iter_keys + let iter = iter + let iteri = iteri + let exists = exists + let existsi = existsi + let for_all = for_all + let for_alli = for_alli + let count = count + let counti = counti + let fold = fold + let length = length + let is_empty = is_empty + let map = map + let mapi = mapi + let filter_map = filter_map + let filter_mapi = filter_mapi + let filter_keys = filter_keys + let filter = filter + let filteri = filteri + let partition_map = partition_map + let partition_mapi = partition_mapi + let partition_tf = partition_tf + let partitioni_tf = partitioni_tf + let find_or_add = find_or_add + let findi_or_add = findi_or_add + let find = find + let find_exn = find_exn + let find_and_call = find_and_call + let findi_and_call = findi_and_call + let find_and_remove = find_and_remove + let to_alist = to_alist + let validate = validate + let merge = merge + let merge_into = merge_into + let keys = keys + let data = data + let filter_keys_inplace = filter_keys_inplace + let filter_inplace = filter_inplace + let filteri_inplace = filteri_inplace + let map_inplace = map_inplace + let mapi_inplace = mapi_inplace + let filter_map_inplace = filter_map_inplace + let filter_mapi_inplace = filter_mapi_inplace + let equal = equal + let similar = similar + let incr = incr + let decr = decr + let sexp_of_key = sexp_of_key +end + +module Creators (Key : sig + type 'a t + + val hashable : 'a t Hashable.t + end) : sig + + type ('a, 'b) t_ = ('a Key.t, 'b) t + + val t_of_sexp : (Sexp.t -> 'a Key.t) -> (Sexp.t -> 'b) -> Sexp.t -> ('a, 'b) t_ + + include Creators_generic + with type ('a, 'b) t := ('a, 'b) t_ + with type 'a key := 'a Key.t + with type ('key, 'data, 'a) create_options := + ('key, 'data, 'a) create_options_without_first_class_module + +end = struct + + let hashable = Key.hashable + + type ('a, 'b) t_ = ('a Key.t, 'b) t + + let create ?growth_allowed ?size () = create ?growth_allowed ?size ~hashable () + + let of_alist 
?growth_allowed ?size l = + of_alist ?growth_allowed ~hashable ?size l + ;; + + let of_alist_report_all_dups ?growth_allowed ?size l = + of_alist_report_all_dups ?growth_allowed ~hashable ?size l + ;; + + let of_alist_or_error ?growth_allowed ?size l = + of_alist_or_error ?growth_allowed ~hashable ?size l + ;; + + let of_alist_exn ?growth_allowed ?size l = + of_alist_exn ?growth_allowed ~hashable ?size l + ;; + + let t_of_sexp k_of_sexp d_of_sexp sexp = + t_of_sexp ~hashable k_of_sexp d_of_sexp sexp + ;; + + let of_alist_multi ?growth_allowed ?size l = + of_alist_multi ?growth_allowed ~hashable ?size l + ;; + + let create_mapped ?growth_allowed ?size ~get_key ~get_data l = + create_mapped ?growth_allowed ~hashable ?size ~get_key ~get_data l + ;; + + let create_with_key ?growth_allowed ?size ~get_key l = + create_with_key ?growth_allowed ~hashable ?size ~get_key l + ;; + + let create_with_key_or_error ?growth_allowed ?size ~get_key l = + create_with_key_or_error ?growth_allowed ~hashable ?size ~get_key l + ;; + + let create_with_key_exn ?growth_allowed ?size ~get_key l = + create_with_key_exn ?growth_allowed ~hashable ?size ~get_key l + ;; + + let group ?growth_allowed ?size ~get_key ~get_data ~combine l = + group ?growth_allowed ~hashable ?size ~get_key ~get_data ~combine l + ;; +end + +module Poly = struct + + type nonrec ('a, 'b) t = ('a, 'b) t + + type 'a key = 'a + + let hashable = Hashable.poly + + include Creators (struct + type 'a t = 'a + let hashable = hashable + end) + + include Accessors + + let sexp_of_t = sexp_of_t +end + +module Private = struct + module type Creators_generic = Creators_generic + module type Hashable = Hashable.Hashable + + type nonrec ('key, 'data, 'z) create_options_without_first_class_module = + ('key, 'data, 'z) create_options_without_first_class_module + + let hashable t = t.hashable +end + +let create ?growth_allowed ?size m = + create ~hashable:(Hashable.of_key m) ?growth_allowed ?size () +let of_alist ?growth_allowed ?size m l 
= + of_alist ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_report_all_dups ?growth_allowed ?size m l = + of_alist_report_all_dups ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_or_error ?growth_allowed ?size m l = + of_alist_or_error ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_exn ?growth_allowed ?size m l = + of_alist_exn ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_multi ?growth_allowed ?size m l = + of_alist_multi ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let create_mapped ?growth_allowed ?size m ~get_key ~get_data l = + create_mapped ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key ~get_data l +let create_with_key ?growth_allowed ?size m ~get_key l = + create_with_key ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key l +let create_with_key_or_error ?growth_allowed ?size m ~get_key l = + create_with_key_or_error ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key l +let create_with_key_exn ?growth_allowed ?size m ~get_key l = + create_with_key_exn ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key l +let group ?growth_allowed ?size m ~get_key ~get_data ~combine l = + group ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key ~get_data ~combine l + +let hashable_s t = Hashable.to_key t.hashable + +module M (K : T.T) = struct + type nonrec 'v t = (K.t, 'v) t +end +module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Key with type t := t +end + +let sexp_of_m__t (type k) (module K : Sexp_of_m with type t = k) sexp_of_v t = + sexp_of_t K.sexp_of_t sexp_of_v t + +let m__t_of_sexp 
(type k) (module K : M_of_sexp with type t = k) v_of_sexp sexp = + t_of_sexp ~hashable:(Hashable.of_key (module K)) K.t_of_sexp v_of_sexp sexp + +(* typechecking this code is a compile-time test that [Creators] is a specialization of + [Creators_generic]. *) +module Check : sig end = struct + module Make_creators_check (Type : T.T2) (Key : T.T1) (Options : T.T3) + (M : Creators_generic + with type ('a, 'b) t := ('a, 'b) Type.t + with type 'a key := 'a Key.t + with type ('a, 'b, 'z) create_options := ('a, 'b, 'z) Options.t) + = struct end + + module Check_creators_is_specialization_of_creators_generic (M : Creators) = + Make_creators_check + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type 'a t = 'a end) + (struct type ('a, 'b, 'z) t = ('a, 'b, 'z) create_options end) + (struct + include M + + let create ?growth_allowed ?size m () = + create ?growth_allowed ?size m + end) +end diff --git a/src/hashtbl.mli b/src/hashtbl.mli new file mode 100644 index 0000000..e011dea --- /dev/null +++ b/src/hashtbl.mli @@ -0,0 +1 @@ +include Hashtbl_intf.Hashtbl (** @inline *) diff --git a/src/hashtbl_intf.ml b/src/hashtbl_intf.ml new file mode 100644 index 0000000..99f7913 --- /dev/null +++ b/src/hashtbl_intf.ml @@ -0,0 +1,769 @@ +open! Import + +module type Key = sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** Two [t]s that [compare] equal must have equal hashes for the hashtable + to behave properly. *) + val hash : t -> int +end + +module type Accessors = sig + (** {2 Accessors} *) + + type ('a, 'b) t + type 'a key + + val sexp_of_key : ('a, _) t -> 'a key -> Sexp.t + val clear : (_, _) t -> unit + val copy : ('a, 'b) t -> ('a, 'b) t + + (** Attempting to modify ([set], [remove], etc.) the hashtable during iteration ([fold], + [iter], [iter_keys], [iteri]) will raise an exception. 
*) + val fold : ('a, 'b) t -> init:'c -> f:(key:'a key -> data:'b -> 'c -> 'c) -> 'c + + val iter_keys : ('a, _) t -> f:( 'a key -> unit) -> unit + val iter : ( _, 'b) t -> f:( 'b -> unit) -> unit + + (** Iterates over both keys and values. + + Example: + + {v + let h = Hashtbl.of_alist_exn (module Int) [(1, 4); (5, 6)] in + Hashtbl.iteri h ~f:(fun ~key ~data -> + print_endline (Printf.sprintf "%d-%d" key data));; + 1-4 + 5-6 + - : unit = () + v} *) + val iteri : ('a, 'b) t -> f:(key:'a key -> data:'b -> unit) -> unit + + val existsi : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> bool + val exists : (_ , 'b) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> bool + val for_all : (_ , 'b) t -> f:( 'b -> bool) -> bool + val counti : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> int + val count : (_ , 'b) t -> f:( 'b -> bool) -> int + + val length : (_, _) t -> int + val is_empty : (_, _) t -> bool + val mem : ('a, _) t -> 'a key -> bool + val remove : ('a, _) t -> 'a key -> unit + + (** Sets the given [key] to [data]. *) + val set : ('a, 'b) t -> key:'a key -> data:'b -> unit + + (** [add] and [add_exn] leave the table unchanged if the key was already present. *) + val add : ('a, 'b) t -> key:'a key -> data:'b -> [ `Ok | `Duplicate ] + val add_exn : ('a, 'b) t -> key:'a key -> data:'b -> unit + + (** [change t key ~f] changes [t]'s value for [key] to be [f (find t key)]. *) + val change : ('a, 'b) t -> 'a key -> f:('b option -> 'b option) -> unit + + (** [update t key ~f] is [change t key ~f:(fun o -> Some (f o))]. *) + val update : ('a, 'b) t -> 'a key -> f:('b option -> 'b) -> unit + + (** [map t f] returns a new table with values replaced by the result of applying [f] + to the current values. 
+ + Example: + + {v + let h = Hashtbl.of_alist_exn (module Int) [(1, 4); (5, 6)] in + let h' = Hashtbl.map h ~f:(fun x -> x * 2) in + Hashtbl.to_alist h';; + - : (int * int) list = [(5, 12); (1, 8)] + v} *) + val map : ('a, 'b) t -> f:('b -> 'c) -> ('a, 'c) t + + (** Like [map], but the function [f] takes both key and data as arguments. *) + val mapi : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'c) -> ('a, 'c) t + + (** Returns a new table by filtering the given table's values by [f]: the keys for which + [f] applied to the current value returns [Some] are kept, and those for which it + returns [None] are discarded. + + Example: + + {v + let h = Hashtbl.of_alist_exn (module Int) [(1, 4); (5, 6)] in + Hashtbl.filter_map h ~f:(fun x -> if x > 5 then Some x else None) + |> Hashtbl.to_alist;; + - : (int * int) list = [(5, 6)] + v} *) + val filter_map : ('a, 'b) t -> f:('b -> 'c option) -> ('a, 'c) t + + (** Like [filter_map], but the function [f] takes both key and data as arguments. *) + val filter_mapi + : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'c option) -> ('a, 'c) t + + val filter_keys : ('a, 'b) t -> f:('a key -> bool) -> ('a, 'b) t + val filter : ('a, 'b) t -> f:('b -> bool) -> ('a, 'b) t + val filteri : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> ('a, 'b) t + + (** Returns new tables with bound values partitioned by [f] applied to the bound + values. *) + val partition_map + : ('a, 'b) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + + (** Like [partition_map], but the function [f] takes both key and data as arguments. *) + val partition_mapi + : ('a, 'b) t + -> f:(key:'a key -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + + (** Returns a pair of tables [(t1, t2)], where [t1] contains all the elements of the + initial table which satisfy the predicate [f], and [t2] contains the rest. 
*) + val partition_tf : ('a, 'b) t -> f:('b -> bool) -> ('a, 'b) t * ('a, 'b) t + + (** Like [partition_tf], but the function [f] takes both key and data as arguments. *) + val partitioni_tf + : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> ('a, 'b) t * ('a, 'b) t + + (** [find_or_add t k ~default] returns the data associated with key [k] if it is in the + table [t], and otherwise assigns [k] the value returned by [default ()]. *) + val find_or_add : ('a, 'b) t -> 'a key -> default:(unit -> 'b) -> 'b + + (** Like [find_or_add] but [default] takes the key as an argument. *) + val findi_or_add : ('a, 'b) t -> 'a key -> default:('a key -> 'b) -> 'b + + (** [find t k] returns [Some] (the current binding) of [k] in [t], or [None] if no such + binding exists. *) + val find : ('a, 'b) t -> 'a key -> 'b option + + (** [find_exn t k] returns the current binding of [k] in [t], or raises [Caml.Not_found] + or [Not_found_s] if no such binding exists. *) + val find_exn : ('a, 'b) t -> 'a key -> 'b + + (** [find_and_call t k ~if_found ~if_not_found] + + is equivalent to: + + [match find t k with Some v -> if_found v | None -> if_not_found k] + + except that it doesn't allocate the option. *) + val find_and_call + : ('a, 'b) t + -> 'a key + -> if_found:('b -> 'c) + -> if_not_found:('a key -> 'c) + -> 'c + + val findi_and_call + : ('a, 'b) t + -> 'a key + -> if_found:(key:'a key -> data:'b -> 'c) + -> if_not_found:('a key -> 'c) + -> 'c + + (** [find_and_remove t k] returns Some (the current binding) of k in t and removes it, + or None if no such binding exists. *) + val find_and_remove : ('a, 'b) t -> 'a key -> 'b option + + (** Merges two hashtables.
+ + The result of [merge f h1 h2] has as keys the set of all [k] in the union of the + sets of keys of [h1] and [h2] for which [d(k)] is not None, where: + + d(k) = + - [f ~key:k (Some d1) None] + if [k] in [h1] maps to d1, and [h2] does not have data for [k]; + + - [f ~key:k None (Some d2)] + if [k] in [h2] maps to d2, and [h1] does not have data for [k]; + + - [f ~key:k (Some d1) (Some d2)] + otherwise, where [k] in [h1] maps to [d1] and [k] in [h2] maps to [d2]. + + Each key [k] is mapped to a single piece of data [x], where [d(k) = Some x]. + + Example: + + {v + let h1 = Hashtbl.of_alist_exn (module Int) [(1, 5); (2, 3232)] in + let h2 = Hashtbl.of_alist_exn (module Int) [(1, 3)] in + Hashtbl.merge h1 h2 ~f:(fun ~key:_ -> function + | `Left x -> Some (`Left x) + | `Right x -> Some (`Right x) + | `Both (x, y) -> if x=y then None else Some (`Both (x,y)) + ) |> Hashtbl.to_alist;; + - : (int * [> `Both of int * int | `Left of int | `Right of int ]) list = + [(2, `Left 3232); (1, `Both (5, 3))] + v} *) + val merge + : ('k, 'a) t + -> ('k, 'b) t + -> f:(key:'k key -> [ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> 'c option) + -> ('k, 'c) t + + (** Every [key] in [src] will be removed or set in [dst] according to the return value + of [f]. *) + type 'a merge_into_action = Remove | Set_to of 'a + + val merge_into + : src:('k, 'a) t + -> dst:('k, 'b) t + -> f:(key:'k key -> 'a -> 'b option -> 'b merge_into_action) + -> unit + + (** Returns the list of all keys for given hashtable. *) + val keys : ('a, _) t -> 'a key list + + (** Returns the list of all data for given hashtable. *) + val data : (_, 'b) t -> 'b list + + (** [filter_inplace t ~f] removes all the elements from [t] that don't satisfy [f]. 
*) + val filter_keys_inplace : ('a, _) t -> f:('a key -> bool) -> unit + val filter_inplace : ( _, 'b) t -> f:('b -> bool) -> unit + val filteri_inplace : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> unit + + (** [map_inplace t ~f] applies [f] to all elements in [t], transforming them in + place. *) + val map_inplace : (_, 'b) t -> f:( 'b -> 'b) -> unit + val mapi_inplace : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'b) -> unit + + (** [filter_map_inplace] combines the effects of [map_inplace] and [filter_inplace]. *) + val filter_map_inplace : (_, 'b) t -> f:( 'b -> 'b option) -> unit + val filter_mapi_inplace : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'b option) -> unit + + (** [equal t1 t2 f] and [similar t1 t2 f] both return true iff [t1] and [t2] have the + same keys and for all keys [k], [f (find_exn t1 k) (find_exn t2 k)]. [equal] and + [similar] only differ in their types. *) + val equal : ('a, 'b ) t -> ('a, 'b ) t -> ('b -> 'b -> bool) -> bool + val similar : ('a, 'b1) t -> ('a, 'b2) t -> ('b1 -> 'b2 -> bool) -> bool + + (** Returns the list of all (key, data) pairs for given hashtable. *) + val to_alist : ('a, 'b) t -> ('a key * 'b) list + + val validate + : name:('a key -> string) + -> 'b Validate.check + -> ('a, 'b) t Validate.check + + (** [remove_if_zero]'s default is [false]. *) + val incr : ?by:int -> ?remove_if_zero:bool -> ('a, int) t -> 'a key -> unit + val decr : ?by:int -> ?remove_if_zero:bool -> ('a, int) t -> 'a key -> unit +end + +module type Multi = sig + type ('a, 'b) t + type 'a key + + (** [add_multi t ~key ~data] if [key] is present in the table then cons + [data] on the list, otherwise add [key] with a single element list. *) + val add_multi : ('a, 'b list) t -> key:'a key -> data:'b -> unit + + (** [remove_multi t key] updates the table, removing the head of the list bound to + [key]. If the list has only one element (or is empty) then the binding is + removed. 
*) + val remove_multi : ('a, _ list) t -> 'a key -> unit + + (** [find_multi t key] returns the empty list if [key] is not present in the table, + returns [t]'s values for [key] otherwise. *) + val find_multi : ('a, 'b list) t -> 'a key -> 'b list +end + +type ('key, 'data, 'z) create_options = + ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'key) + -> 'z + +type ('key, 'data, 'z) create_options_without_first_class_module = + ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> 'z + +module type Creators_generic = sig + type ('a, 'b) t + type 'a key + type ('key, 'data, 'z) create_options + + val create : ('a key, 'b, unit -> ('a, 'b) t) create_options + + val of_alist + : ('a key, + 'b, + ('a key * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_key of 'a key + ]) create_options + + val of_alist_report_all_dups + : ('a key, + 'b, + ('a key * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a key list + ]) create_options + + val of_alist_or_error + : ('a key, 'b, ('a key * 'b) list -> ('a, 'b) t Or_error.t) create_options + + val of_alist_exn : ('a key, 'b, ('a key * 'b) list -> ('a, 'b) t) create_options + + val of_alist_multi + : ('a key, 'b list, ('a key * 'b) list -> ('a, 'b list) t) create_options + + + (** {[ create_mapped get_key get_data [x1,...,xn] + = of_alist [get_key x1, get_data x1; ...; get_key xn, get_data xn] ]} *) + val create_mapped + : ('a key, + 'b, + get_key:('r -> 'a key) + -> get_data:('r -> 'b) + -> 'r list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a key list ]) create_options + + (** {[ create_with_key ~get_key [x1,...,xn] + = of_alist [get_key x1, x1; ...; get_key xn, xn] ]} *) + val create_with_key + : ('a key, + 'r, + get_key:('r -> 'a key) + -> 'r list + -> [ `Ok of ('a, 'r) t + | `Duplicate_keys of 'a key list ]) create_options + + val create_with_key_or_error + : ('a key, + 'r, + 
get_key:('r -> 'a key) + -> 'r list + -> ('a, 'r) t Or_error.t) create_options + + val create_with_key_exn + : ('a key, + 'r, + get_key:('r -> 'a key) + -> 'r list + -> ('a, 'r) t) create_options + + val group + : ('a key, + 'b, + get_key:('r -> 'a key) + -> get_data:('r -> 'b) + -> combine:('b -> 'b -> 'b) + -> 'r list + -> ('a, 'b) t) create_options +end + +module type Creators = sig + type ('a, 'b) t + + (** {2 Creators} *) + + (** The module you pass to [create] must have a type that is hashable, sexpable, and + comparable. + + Example: + + {v + Hashtbl.create (module Int);; + - : (int, '_a) Hashtbl.t = ;; + v} *) + val create + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a, 'b) t + + (** Example: + + {v + Hashtbl.of_alist (module Int) [(3, "something"); (2, "whatever")] + - : [ `Duplicate_key of int | `Ok of (int, string) Hashtbl.t ] = `Ok + v} *) + val of_alist + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_key of 'a + ] + + (** Whereas [of_alist] will report [Duplicate_key] no matter how many dups there are in + your list, [of_alist_report_all_dups] will report each and every duplicate entry. 
+ + For example: + + {v + Hashtbl.of_alist (module Int) [(1, "foo"); (1, "bar"); (2, "foo"); (2, "bar")];; + - : [ `Duplicate_key of int | `Ok of (int, string) Hashtbl.t ] = `Duplicate_key 1 + + Hashtbl.of_alist_report_all_dups (module Int) [(1, "foo"); (1, "bar"); (2, "foo"); (2, "bar")];; + - : [ `Duplicate_keys of int list | `Ok of (int, string) Hashtbl.t ] = `Duplicate_keys [1; 2] + v} *) + val of_alist_report_all_dups + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a list + ] + + val of_alist_or_error + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> ('a, 'b) t Or_error.t + + val of_alist_exn + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> ('a, 'b) t + + (** Creates a {{!Multi} "multi"} hashtable, i.e., a hashtable where each key points to a + list potentially containing multiple values. So instead of short-circuiting with a + [`Duplicate_key] variant on duplicates, as in [of_alist], [of_alist_multi] folds + those values into a list for the given key: + + {v + let h = Hashtbl.of_alist_multi (module Int) [(1, "a"); (1, "b"); (2, "c"); (2, "d")];; + val h : (int, string list) Hashtbl.t = + + Hashtbl.find_exn h 1;; + - : string list = ["b"; "a"] + v} *) + val of_alist_multi + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> ('a, 'b list) t + + (** Applies the [get_key] and [get_data] functions to the ['r list] to create the + initial keys and values, respectively, for the new hashtable. 
+ + {[ create_mapped get_key get_data [x1;...;xn] + = of_alist [get_key x1, get_data x1; ...; get_key xn, get_data xn] + ]} + + Example: + + {v + let h = + Hashtbl.create_mapped (module Int) + ~get_key:(fun x -> x) + ~get_data:(fun x -> x + 1) + [1; 2; 3];; + val h : [ `Duplicate_keys of int list | `Ok of (int, int) Hashtbl.t ] = `Ok + + let h = + match h with + | `Ok x -> x + | `Duplicate_keys _ -> failwith "" + in + Hashtbl.find_exn h 1;; + - : int = 2 + v} *) + val create_mapped + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> get_data:('r -> 'b) + -> 'r list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a list ] + + (** {[ create_with_key ~get_key [x1;...;xn] + = of_alist [get_key x1, x1; ...; get_key xn, xn] ]} *) + val create_with_key + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> 'r list + -> [ `Ok of ('a, 'r) t + | `Duplicate_keys of 'a list ] + + val create_with_key_or_error + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> 'r list + -> ('a, 'r) t Or_error.t + + val create_with_key_exn + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> 'r list + -> ('a, 'r) t + + (** Like [create_mapped], applies the [get_key] and [get_data] functions to the ['r + list] to create the initial keys and values, respectively, for the new hashtable -- + and then, like [add_multi], folds together values belonging to the same keys. Here, + though, the function used for the folding is given by [combine] (instead of just + being a [cons]). 
+ + Example: + + {v + Hashtbl.group (module Int) + ~get_key:(fun x -> x / 2) + ~get_data:(fun x -> x) + ~combine:(fun x y -> x * y) + [ 1; 2; 3; 4] + |> Hashtbl.to_alist;; + - : (int * int) list = [(2, 4); (1, 6); (0, 1)] + v} *) + val group + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> get_data:('r -> 'b) + -> combine:('b -> 'b -> 'b) + -> 'r list + -> ('a, 'b) t +end + +module type S_without_submodules = sig + + val hash : 'a -> int + val hash_param : int -> int -> 'a -> int + + + type ('a, 'b) t + + (** We provide a [sexp_of_t] but not a [t_of_sexp] for this type because one needs to be + explicit about the hash and comparison functions used when creating a hashtable. + Note that [Hashtbl.Poly.t] does have [[@@deriving_inline sexp][@@@end]], and uses OCaml's built-in + polymorphic comparison and polymorphic hashing. *) + val sexp_of_t : ('a -> Sexp.t) -> ('b -> Sexp.t) -> ('a, 'b) t -> Sexp.t + + include Creators + with type ('a, 'b) t := ('a, 'b) t + (** @inline *) + + include Accessors + with type ('a, 'b) t := ('a, 'b) t + with type 'a key = 'a + (** @inline *) + + include Multi + with type ('a, 'b) t := ('a, 'b) t + with type 'a key := 'a key + (** @inline *) + + val hashable_s : ('key, _) t -> (module Key with type t = 'key) + + include Invariant.S2 with type ('a, 'b) t := ('a, 'b) t + +end + +module type S_poly = sig + + type ('a, 'b) t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] + + val hashable : 'a Hashable.t + + include Invariant.S2 with type ('a, 'b) t := ('a, 'b) t + + include Creators_generic + with type ('a, 'b) t := ('a, 'b) t + with type 'a key = 'a + with type ('key, 'data, 'z) create_options + := ('key, 'data, 'z) create_options_without_first_class_module + + include Accessors + with type
('a, 'b) t := ('a, 'b) t + with type 'a key := 'a key + + include Multi + with type ('a, 'b) t := ('a, 'b) t + with type 'a key := 'a key +end + +module type For_deriving = sig + type ('k, 'v) t + + module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Key with type t := t + end + val sexp_of_m__t + : (module Sexp_of_m with type t = 'k) + -> ('v -> Sexp.t) + -> ('k, 'v) t + -> Sexp.t + val m__t_of_sexp + : (module M_of_sexp with type t = 'k) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v) t +end + +module type Hashtbl = sig + + (** A hash table is a mutable data structure implementing a map between keys and values. + It supports constant-time lookup and in-place modification. + + {1 Usage} + + As a simple example, we'll create a hash table with string keys using the + {{!create}[create]} constructor, which expects a module defining the key's type: + + {[ + let h = Hashtbl.create (module String);; + val h : (string, '_a) Hashtbl.t = + ]} + + We can set the values of individual keys with {{!set}[set]}. If the key already has + a value, it will be overwritten. 
+ + {v + Hashtbl.set h ~key:"foo" ~data:5;; + - : unit = () + + Hashtbl.set h ~key:"foo" ~data:6;; + - : unit = () + + Hashtbl.set h ~key:"bar" ~data:6;; + - : unit = () + v} + + We can access values by key, or dump all of the hash table's data: + + {v + Hashtbl.find h "foo";; + - : int option = Some 6 + + Hashtbl.find_exn h "foo";; + - : int = 6 + + Hashtbl.to_alist h;; + - : (string * int) list = [("foo", 6); ("bar", 6)] + v} + + {{!change}[change]} lets us change a key's value by applying the given function: + + {v + Hashtbl.change h "foo" (fun x -> + match x with + | Some x -> Some (x * 2) + | None -> None + );; + - : unit = () + + Hashtbl.to_alist h;; + - : (string * int) list = [("foo", 12); ("bar", 6)] + v} + + + We can use {{!merge}[merge]} to merge two hashtables with fine-grained control over + how we choose values when a key is present in the first ("left") hashtable, the + second ("right"), or both. Here, we'll tag each value with its source table(s) and + drop keys whose values are equal in both: + + {v + let h1 = Hashtbl.of_alist_exn (module Int) [(1, 5); (2, 3232)] in + let h2 = Hashtbl.of_alist_exn (module Int) [(1, 3)] in + Hashtbl.merge h1 h2 ~f:(fun ~key:_ -> function + | `Left x -> Some (`Left x) + | `Right x -> Some (`Right x) + | `Both (x, y) -> if x=y then None else Some (`Both (x,y)) + ) |> Hashtbl.to_alist;; + - : (int * [> `Both of int * int | `Left of int | `Right of int ]) list = + [(2, `Left 3232); (1, `Both (5, 3))] + v} + + {1 Interface} *) + + include S_without_submodules (** @inline *) + + module type Accessors = Accessors + module type Creators = Creators + module type Key = Key + module type Multi = Multi + module type S_poly = S_poly + module type S_without_submodules = S_without_submodules + + module type For_deriving = For_deriving + + type nonrec ('key, 'data, 'z) create_options = + ('key, 'data, 'z) create_options + + module Creators (Key : sig type 'a t val hashable : 'a t Hashable.t end) : sig + type ('a, 'b) t_ = ('a Key.t, 'b) t + val t_of_sexp : (Sexp.t ->
'a Key.t) -> (Sexp.t -> 'b) -> Sexp.t -> ('a, 'b) t_ + include Creators_generic + with type ('a, 'b) t := ('a, 'b) t_ + with type 'a key := 'a Key.t + with type ('key, 'data, 'a) create_options := + ('key, 'data, 'a) create_options_without_first_class_module + end + + module Poly : S_poly with type ('a, 'b) t = ('a, 'b) t + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_to_int_table = int Hashtbl.M(String).t + ]} + + which stands for: + + {[ + type string_to_int_table = (String.t, int) Hashtbl.t + ]} + + The point is that [int Hashtbl.M(String).t] supports deriving, whereas the second + syntax doesn't (because [t_of_sexp] doesn't know what comparison/hash function to + use). *) + module M (K : T.T) : sig + type nonrec 'v t = (K.t, 'v) t + end + + include For_deriving with type ('a, 'b) t := ('a, 'b) t + + (**/**) + (*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) + module Private : sig + + module type Creators_generic = Creators_generic + + type nonrec ('key, 'data, 'z) create_options_without_first_class_module = + ('key, 'data, 'z) create_options_without_first_class_module + + val hashable : ('key, _) t -> 'key Hashable.t + end +end diff --git a/src/hex_lexer.mll b/src/hex_lexer.mll new file mode 100644 index 0000000..0551e8d --- /dev/null +++ b/src/hex_lexer.mll @@ -0,0 +1,15 @@ +{ +type result = +| Neg of string +| Pos of string +} + +let hex_digit = ['0' - '9' 'A' - 'F' 'a' - 'f'] +let body = (hex_digit (hex_digit | '_')*) as body +let body_with_suffix = '0' ['X' 'x'] body +let pos = body_with_suffix +let neg = '-' body_with_suffix + +rule parse_hex = parse +| neg { Neg body } +| pos { Pos body } diff --git a/src/identifiable.ml b/src/identifiable.ml new file mode 100644 index 0000000..f4e059e --- /dev/null +++ b/src/identifiable.ml @@ -0,0 +1,58 @@ +open! 
Import + +module type S = sig + type t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable .S with type t := t + include Comparable .S with type t := t + include Pretty_printer.S with type t := t +end + +module Make (T : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t + val module_name : string + end) = struct + include T + include Comparable .Make (T) + include Pretty_printer.Register (T) +end + +module Make_using_comparator (T : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + include Stringable.S with type t := t + val module_name : string + end) = struct + include T + include Comparable .Make_using_comparator (T) + include Pretty_printer.Register (T) +end diff --git a/src/identifiable.mli b/src/identifiable.mli new file mode 100644 index 0000000..a0770bf --- /dev/null +++ b/src/identifiable.mli @@ -0,0 +1,77 @@ +(** A signature combining functionality that is commonly used for types that are intended + to act as names or identifiers. 
+ + Modules that satisfy [Identifiable] can be printed and parsed (both through string and + s-expression converters) and can be used in hash-based and comparison-based + containers (e.g., hashtables and maps). + + This module also provides functors for conveniently constructing identifiable + modules. *) + +open! Import + +module type S = sig + type t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t + include Comparable.S with type t := t + include Pretty_printer.S with type t := t +end + +(** Used for making an Identifiable module. Here's an example. + + {[ + module Id = struct + module T = struct + type t = A | B [@@deriving_inline compare, hash, sexp][@@@end] + let of_string s = t_of_sexp (sexp_of_string s) + let to_string t = string_of_sexp (sexp_of_t t) + let module_name = "My_library.Std.Id" + end + include T + include Identifiable.Make (T) + end + ]} *) +module Make (M : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t + val module_name : string (** For registering the pretty printer. 
*) + end) : S + with type t := M.t + +module Make_using_comparator (M : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + include Stringable.S with type t := t + val module_name : string + end) : S + with type t := M.t + with type comparator_witness := M.comparator_witness diff --git a/src/import.ml b/src/import.ml new file mode 100644 index 0000000..e3ba407 --- /dev/null +++ b/src/import.ml @@ -0,0 +1,8 @@ +include Import0 + +include Sexplib0.Sexp_conv +include Hash.Builtin +include Ppx_compare_lib.Builtin +include Int_replace_polymorphic_compare + +exception Not_found_s = Sexp.Not_found_s diff --git a/src/import0.ml b/src/import0.ml new file mode 100644 index 0000000..0431d2e --- /dev/null +++ b/src/import0.ml @@ -0,0 +1,392 @@ +(* This module is included in [Import]. It is aimed at modules that define the standard + combinators for [sexp_of], [of_sexp], [compare] and [hash] and are included in + [Import]. 
*) + +include + (Shadow_stdlib + : (module type of struct include Shadow_stdlib end + with type 'a ref := 'a ref + with type ('a, 'b, 'c) format := ('a, 'b, 'c) format + with type ('a, 'b, 'c, 'd) format4 := ('a, 'b, 'c, 'd) format4 + with type ('a, 'b, 'c, 'd, 'e, 'f) format6 := ('a, 'b, 'c, 'd, 'e, 'f) format6 + (* These modules are redefined in Base *) + with module Array := Shadow_stdlib.Array + with module Bool := Shadow_stdlib.Bool + with module Buffer := Shadow_stdlib.Buffer + with module Bytes := Shadow_stdlib.Bytes + with module Char := Shadow_stdlib.Char + with module Float := Shadow_stdlib.Float + with module Hashtbl := Shadow_stdlib.Hashtbl + with module Int := Shadow_stdlib.Int + with module Int32 := Shadow_stdlib.Int32 + with module Int64 := Shadow_stdlib.Int64 + with module Lazy := Shadow_stdlib.Lazy + with module List := Shadow_stdlib.List + with module Map := Shadow_stdlib.Map + with module Nativeint := Shadow_stdlib.Nativeint + with module Option := Shadow_stdlib.Option + with module Printf := Shadow_stdlib.Printf + with module Queue := Shadow_stdlib.Queue + with module Random := Shadow_stdlib.Random + with module Result := Shadow_stdlib.Result + with module Set := Shadow_stdlib.Set + with module Stack := Shadow_stdlib.Stack + with module String := Shadow_stdlib.String + with module Sys := Shadow_stdlib.Sys + with module Uchar := Shadow_stdlib.Uchar + with module Unit := Shadow_stdlib.Unit + )) [@ocaml.warning "-3"] +type 'a ref = 'a Caml.ref = { mutable contents: 'a } + +(* Reshuffle [Caml] so that we choose the modules using labels when available. 
*) +module Caml = struct + + + (** @canonical Caml.Arg *) + module Arg = Caml.Arg + + (** @canonical Caml.StdLabels.Array *) + module Array = Caml.StdLabels.Array + + (** @canonical Caml.Bool *) + module Bool = Caml.Bool + + (** @canonical Caml.Buffer *) + module Buffer = Caml.Buffer + + (** @canonical Caml.StdLabels.Bytes *) + module Bytes = Caml.StdLabels.Bytes + + (** @canonical Caml.Char *) + module Char = Caml.Char + + (** @canonical Caml.Ephemeron *) + module Ephemeron = Caml.Ephemeron + + (** @canonical Caml.Float *) + module Float = Caml.Float + + (** @canonical Caml.Format *) + module Format = Caml.Format + + (** @canonical Caml.Fun *) + module Fun = Caml.Fun + + (** @canonical Caml.Gc *) + module Gc = Caml.Gc + + (** @canonical Caml.MoreLabels.Hashtbl *) + module Hashtbl = Caml.MoreLabels.Hashtbl + + (** @canonical Caml.Int32 *) + module Int32 = Caml.Int32 + + (** @canonical Caml.Int *) + module Int = Caml.Int + + (** @canonical Caml.Int64 *) + module Int64 = Caml.Int64 + + (** @canonical Caml.Lazy *) + module Lazy = Caml.Lazy + + (** @canonical Caml.Lexing *) + module Lexing = Caml.Lexing + + (** @canonical Caml.StdLabels.List *) + module List = Caml.StdLabels.List + + (** @canonical Caml.MoreLabels.Map *) + module Map = Caml.MoreLabels.Map + + (** @canonical Caml.Nativeint *) + module Nativeint = Caml.Nativeint + + (** @canonical Caml.Obj *) + module Obj = Caml.Obj + + (** @canonical Caml.Option *) + module Option = Caml.Option + + (** @canonical Caml.Parsing *) + module Parsing = Caml.Parsing + + (** @canonical Caml.Printexc *) + module Printexc = Caml.Printexc + + (** @canonical Caml.Printf *) + module Printf = Caml.Printf + + (** @canonical Caml.Queue *) + module Queue = Caml.Queue + + (** @canonical Caml.Random *) + module Random = Caml.Random + + (** @canonical Caml.Result *) + module Result = Caml.Result + + (** @canonical Caml.Scanf *) + module Scanf = Caml.Scanf + + (** @canonical Caml.MoreLabels.Set *) + module Set = Caml.MoreLabels.Set + + (** 
@canonical Caml.Stack *) + module Stack = Caml.Stack + + (** @canonical Caml.Stream *) + module Stream = Caml.Stream + + (** @canonical Caml.StdLabels.String *) + module String = Caml.StdLabels.String + + (** @canonical Caml.Sys *) + module Sys = Caml.Sys + + (** @canonical Caml.Uchar *) + module Uchar = Caml.Uchar + + (** @canonical Caml.Unit *) + module Unit = Caml.Unit + + include Pervasives [@ocaml.warning "-3"] + + exception Not_found = Caml.Not_found +end + +external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + +(* These need to be declared as an external to get the lazy behavior *) +external ( && ) : bool -> bool -> bool = "%sequand" +external ( || ) : bool -> bool -> bool = "%sequor" +external not : bool -> bool = "%boolnot" + +(* This needs to be declared as an external for the warnings to work properly *) +external ignore : _ -> unit = "%ignore" + +let ( != ) = Caml.( != ) +let ( * ) = Caml.( * ) +let ( ** ) = Caml.( ** ) +let ( *. ) = Caml.( *. ) +let ( + ) = Caml.( + ) +let ( +. ) = Caml.( +. ) +let ( - ) = Caml.( - ) +let ( -. ) = Caml.( -. ) +let ( / ) = Caml.( / ) +let ( /. ) = Caml.( /. 
) + +(** @canonical Base.Poly *) +module Poly = Poly0 + +module Int_replace_polymorphic_compare = struct + let ( < ) (x : int) y = Poly.( < ) x y + let ( <= ) (x : int) y = Poly.( <= ) x y + let ( <> ) (x : int) y = Poly.( <> ) x y + let ( = ) (x : int) y = Poly.( = ) x y + let ( > ) (x : int) y = Poly.( > ) x y + let ( >= ) (x : int) y = Poly.( >= ) x y + + let ascending (x : int) y = Poly.ascending x y + let descending (x : int) y = Poly.descending x y + let compare (x : int) y = Poly.compare x y + let equal (x : int) y = Poly.equal x y + let max (x : int) y = if x >= y then x else y + let min (x : int) y = if x <= y then x else y +end + +include Int_replace_polymorphic_compare + +module Int32_replace_polymorphic_compare = struct + let ( < ) (x : Caml.Int32.t) y = Poly.( < ) x y + let ( <= ) (x : Caml.Int32.t) y = Poly.( <= ) x y + let ( <> ) (x : Caml.Int32.t) y = Poly.( <> ) x y + let ( = ) (x : Caml.Int32.t) y = Poly.( = ) x y + let ( > ) (x : Caml.Int32.t) y = Poly.( > ) x y + let ( >= ) (x : Caml.Int32.t) y = Poly.( >= ) x y + + let ascending (x : Caml.Int32.t) y = Poly.ascending x y + let descending (x : Caml.Int32.t) y = Poly.descending x y + let compare (x : Caml.Int32.t) y = Poly.compare x y + let equal (x : Caml.Int32.t) y = Poly.equal x y + let max (x : Caml.Int32.t) y = if x >= y then x else y + let min (x : Caml.Int32.t) y = if x <= y then x else y +end + +module Int64_replace_polymorphic_compare = struct + let ( < ) (x : Caml.Int64.t) y = Poly.( < ) x y + let ( <= ) (x : Caml.Int64.t) y = Poly.( <= ) x y + let ( <> ) (x : Caml.Int64.t) y = Poly.( <> ) x y + let ( = ) (x : Caml.Int64.t) y = Poly.( = ) x y + let ( > ) (x : Caml.Int64.t) y = Poly.( > ) x y + let ( >= ) (x : Caml.Int64.t) y = Poly.( >= ) x y + + let ascending (x : Caml.Int64.t) y = Poly.ascending x y + let descending (x : Caml.Int64.t) y = Poly.descending x y + let compare (x : Caml.Int64.t) y = Poly.compare x y + let equal (x : Caml.Int64.t) y = Poly.equal x y + let max (x : 
Caml.Int64.t) y = if x >= y then x else y + let min (x : Caml.Int64.t) y = if x <= y then x else y +end + +module Nativeint_replace_polymorphic_compare = struct + let ( < ) (x : Caml.Nativeint.t) y = Poly.( < ) x y + let ( <= ) (x : Caml.Nativeint.t) y = Poly.( <= ) x y + let ( <> ) (x : Caml.Nativeint.t) y = Poly.( <> ) x y + let ( = ) (x : Caml.Nativeint.t) y = Poly.( = ) x y + let ( > ) (x : Caml.Nativeint.t) y = Poly.( > ) x y + let ( >= ) (x : Caml.Nativeint.t) y = Poly.( >= ) x y + + let ascending (x : Caml.Nativeint.t) y = Poly.ascending x y + let descending (x : Caml.Nativeint.t) y = Poly.descending x y + let compare (x : Caml.Nativeint.t) y = Poly.compare x y + let equal (x : Caml.Nativeint.t) y = Poly.equal x y + let max (x : Caml.Nativeint.t) y = if x >= y then x else y + let min (x : Caml.Nativeint.t) y = if x <= y then x else y +end + +module Bool_replace_polymorphic_compare = struct + let ( < ) (x : bool) y = Poly.( < ) x y + let ( <= ) (x : bool) y = Poly.( <= ) x y + let ( <> ) (x : bool) y = Poly.( <> ) x y + let ( = ) (x : bool) y = Poly.( = ) x y + let ( > ) (x : bool) y = Poly.( > ) x y + let ( >= ) (x : bool) y = Poly.( >= ) x y + + let ascending (x : bool) y = Poly.ascending x y + let descending (x : bool) y = Poly.descending x y + let compare (x : bool) y = Poly.compare x y + let equal (x : bool) y = Poly.equal x y + let max (x : bool) y = if x >= y then x else y + let min (x : bool) y = if x <= y then x else y +end + +module Char_replace_polymorphic_compare = struct + let ( < ) (x : char) y = Poly.( < ) x y + let ( <= ) (x : char) y = Poly.( <= ) x y + let ( <> ) (x : char) y = Poly.( <> ) x y + let ( = ) (x : char) y = Poly.( = ) x y + let ( > ) (x : char) y = Poly.( > ) x y + let ( >= ) (x : char) y = Poly.( >= ) x y + + let ascending (x : char) y = Poly.ascending x y + let descending (x : char) y = Poly.descending x y + let compare (x : char) y = Poly.compare x y + let equal (x : char) y = Poly.equal x y + let max (x : char) y = if x >= y 
then x else y + let min (x : char) y = if x <= y then x else y +end + +module Uchar_replace_polymorphic_compare = struct + let i x = Caml.Uchar.to_int x + + let ( < ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( < ) (i x) (i y) + let ( <= ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( <= ) (i x) (i y) + let ( <> ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( <> ) (i x) (i y) + let ( = ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( = ) (i x) (i y) + let ( > ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( > ) (i x) (i y) + let ( >= ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( >= ) (i x) (i y) + + let ascending (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.ascending (i x) (i y) + let descending (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.descending (i x) (i y) + let compare (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.compare (i x) (i y) + let equal (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.equal (i x) (i y) + let max (x : Caml.Uchar.t) y = if x >= y then x else y + let min (x : Caml.Uchar.t) y = if x <= y then x else y +end + +module Float_replace_polymorphic_compare = struct + let ( < ) (x : float) y = Poly.( < ) x y + let ( <= ) (x : float) y = Poly.( <= ) x y + let ( <> ) (x : float) y = Poly.( <> ) x y + let ( = ) (x : float) y = Poly.( = ) x y + let ( > ) (x : float) y = Poly.( > ) x y + let ( >= ) (x : float) y = Poly.( >= ) x y + + let ascending (x : float) y = Poly.ascending x y + let descending (x : float) y = Poly.descending x y + let compare (x : float) y = Poly.compare x y + let equal (x : float) y = Poly.equal x y + let max (x : float) y = if x >= y then x else y + let min (x : float) y = if x <= y then x else y +end + +module String_replace_polymorphic_compare = struct + let ( < ) (x : string) y = Poly.( < ) x y + let ( <= ) (x : string) y = Poly.( <= ) x y + let ( <> ) (x : string) y = Poly.( <> ) x y + let ( = ) (x : string) y = Poly.( = 
) x y + let ( > ) (x : string) y = Poly.( > ) x y + let ( >= ) (x : string) y = Poly.( >= ) x y + + let ascending (x : string) y = Poly.ascending x y + let descending (x : string) y = Poly.descending x y + let compare (x : string) y = Poly.compare x y + let equal (x : string) y = Poly.equal x y + let max (x : string) y = if x >= y then x else y + let min (x : string) y = if x <= y then x else y +end + +module Bytes_replace_polymorphic_compare = struct + let ( < ) (x : bytes) y = Poly.( < ) x y + let ( <= ) (x : bytes) y = Poly.( <= ) x y + let ( <> ) (x : bytes) y = Poly.( <> ) x y + let ( = ) (x : bytes) y = Poly.( = ) x y + let ( > ) (x : bytes) y = Poly.( > ) x y + let ( >= ) (x : bytes) y = Poly.( >= ) x y + + let ascending (x : bytes) y = Poly.ascending x y + let descending (x : bytes) y = Poly.descending x y + let compare (x : bytes) y = Poly.compare x y + let equal (x : bytes) y = Poly.equal x y + let max (x : bytes) y = if x >= y then x else y + let min (x : bytes) y = if x <= y then x else y +end + +(* This needs to be defined as an external so that the compiler can specialize it as a + direct set or caml_modify *) +external ( := ) : 'a ref -> 'a -> unit = "%setfield0" + +(* These need to be defined as an external otherwise the compiler won't unbox + references *) +external ( ! ) : 'a ref -> 'a = "%field0" +external ref : 'a -> 'a ref = "%makemutable" + +let ( @ ) = Caml.( @ ) +let ( ^ ) = Caml.( ^ ) +let ( ~- ) = Caml.( ~- ) +let ( ~-. ) = Caml.( ~-. ) + +let ( asr ) = Caml.( asr ) +let ( land ) = Caml.( land ) +let lnot = Caml.lnot +let ( lor ) = Caml.( lor ) +let ( lsl ) = Caml.( lsl ) +let ( lsr ) = Caml.( lsr ) +let ( lxor ) = Caml.( lxor ) +let ( mod ) = Caml.( mod ) + +let abs = Caml.abs +let failwith = Caml.failwith +let fst = Caml.fst +let invalid_arg = Caml.invalid_arg +let snd = Caml.snd + +(* [raise] needs to be defined as an external as the compiler automatically replaces + '%raise' by '%reraise' when appropriate. 
*) +external raise : exn -> _ = "%raise" + +let phys_equal = Caml.( == ) + +let decr = Caml.decr +let incr = Caml.incr + +(* used by sexp_conv, which float0 depends on through option *) +let float_of_string = Caml.float_of_string + +(* [am_testing] is used in a few places to behave differently when in testing mode, such + as in [random.ml]. [am_testing] is implemented using [Base_am_testing], a weak C/js + primitive that returns [false], but when linking an inline-test-runner executable, is + overridden by another primitive that returns [true]. *) +external am_testing : unit -> bool = "Base_am_testing" +let am_testing = am_testing () diff --git a/src/indexed_container.ml b/src/indexed_container.ml new file mode 100644 index 0000000..809bfdf --- /dev/null +++ b/src/indexed_container.ml @@ -0,0 +1,65 @@ +include Indexed_container_intf + +let with_return = With_return.with_return + +let iteri ~fold t ~f = + ignore (fold t ~init:0 ~f:(fun i x -> f i x; i + 1) : int) +;; + +let foldi ~fold t ~init ~f = + let i = ref 0 in + fold t ~init ~f:(fun acc v -> + let acc = f !i acc v in + i := !i + 1; + acc) +;; + +let counti ~foldi t ~f = + foldi t ~init:0 ~f:(fun i n a -> if f i a then n + 1 else n) +;; + +let existsi ~iteri c ~f = + with_return (fun r -> + iteri c ~f:(fun i x -> if f i x then r.return true); + false) +;; + +let for_alli ~iteri c ~f = + with_return (fun r -> + iteri c ~f:(fun i x -> if not (f i x) then r.return false); + true) +;; + +let find_mapi ~iteri t ~f = + with_return (fun r -> + iteri t ~f:(fun i x -> match f i x with None -> () | Some _ as res -> r.return res); + None) +;; + +let findi ~iteri c ~f = + with_return (fun r -> + iteri c ~f:(fun i x -> if f i x then r.return (Some (i, x))); + None) +;; + +module Make (T : Make_arg) : S1 with type 'a t := 'a T.t = struct + include (Container.Make (T)) + + let iteri = + match T.iteri with + | `Custom iteri -> iteri + | `Define_using_fold -> fun t ~f -> iteri ~fold t ~f + ;; + + let foldi = + match T.foldi 
with + | `Custom foldi -> foldi + | `Define_using_fold -> fun t ~init ~f -> foldi ~fold t ~init ~f + ;; + + let counti t ~f = counti ~foldi t ~f + let existsi t ~f = existsi ~iteri t ~f + let for_alli t ~f = for_alli ~iteri t ~f + let find_mapi t ~f = find_mapi ~iteri t ~f + let findi t ~f = findi ~iteri t ~f +end diff --git a/src/indexed_container.mli b/src/indexed_container.mli new file mode 100644 index 0000000..f5a29e5 --- /dev/null +++ b/src/indexed_container.mli @@ -0,0 +1 @@ +include Indexed_container_intf.Indexed_container (** @inline *) diff --git a/src/indexed_container_intf.ml b/src/indexed_container_intf.ml new file mode 100644 index 0000000..3ae88d2 --- /dev/null +++ b/src/indexed_container_intf.ml @@ -0,0 +1,66 @@ + +type ('t, 'a, 'accum) fold = 't -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum +type ('t, 'a, 'accum) foldi = + 't -> init:'accum -> f:(int -> 'accum -> 'a -> 'accum) -> 'accum +type ('t, 'a) iteri = 't -> f:(int -> 'a -> unit) -> unit + +module type S1 = sig + include Container.S1 + + (** These are all like their equivalents in [Container] except that an index starting at + 0 is added as the first argument to [f]. *) + + val foldi : ('a t, 'a, _) foldi + + val iteri : ('a t, 'a) iteri + val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + val counti : 'a t -> f:(int -> 'a -> bool) -> int + val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option + +end + +module type Make_arg = sig + include Container_intf.Make_arg + + val iteri : + [ `Define_using_fold + | `Custom of ('a t, 'a) iteri + ] + + val foldi : + [ `Define_using_fold + | `Custom of ('a t, 'a, _) foldi + ] +end + +module type Indexed_container = sig + + (** Provides generic signatures for containers that support indexed iteration ([iteri], + [foldi], ...). 
In principle, any container that has [iter] can also implement [iteri], + but the idea is that [Indexed_container_intf] should be included only for containers + that have a meaningful underlying ordering. *) + + module type S1 = S1 + + (** Generic definitions of [foldi] and [iteri] in terms of [fold]. + + E.g., [iteri ~fold t ~f = ignore (fold t ~init:0 ~f:(fun i x -> f i x; i + 1))]. *) + + val foldi : fold:('t, 'a, 'accum) fold -> ('t, 'a, 'accum) foldi + val iteri : fold:('t, 'a, int) fold -> ('t, 'a) iteri + + (** Generic definitions of indexed container operations in terms of [foldi]. *) + + val counti : foldi:('t, 'a, int) foldi -> 't -> f:(int -> 'a -> bool) -> int + + (** Generic definitions of indexed container operations in terms of [iteri]. *) + + val existsi : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> bool) -> bool + val for_alli : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> bool) -> bool + val findi : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> bool) -> (int * 'a) option + val find_mapi : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> 'b option) -> 'b option + + module Make (T : Make_arg) : S1 with type 'a t := 'a T.t +end diff --git a/src/info.ml b/src/info.ml new file mode 100644 index 0000000..ce12e85 --- /dev/null +++ b/src/info.ml @@ -0,0 +1,240 @@ +(* This module is trying to minimize dependencies on modules in Core, so as to allow + [Info], [Error], and [Or_error] to be used in as many places as possible. Please avoid + adding new dependencies. *) + +open! 
Import + +include Info_intf + +module String = String0 + +module Message = struct + type t = + | Could_not_construct of Sexp.t + | String of string + | Exn of exn + | Sexp of Sexp.t + | Tag_sexp of string * Sexp.t * Source_code_position0.t option + | Tag_t of string * t + | Tag_arg of string * Sexp.t * t + | Of_list of int option * t list + | With_backtrace of t * string (* backtrace *) + [@@deriving_inline sexp_of] + let rec sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Could_not_construct v0 -> + let v0 = Sexp.sexp_of_t v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Could_not_construct"; v0] + | String v0 -> + let v0 = sexp_of_string v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "String"; v0] + | Exn v0 -> + let v0 = sexp_of_exn v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Exn"; v0] + | Sexp v0 -> + let v0 = Sexp.sexp_of_t v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Sexp"; v0] + | Tag_sexp (v0, v1, v2) -> + let v0 = sexp_of_string v0 + and v1 = Sexp.sexp_of_t v1 + and v2 = sexp_of_option Source_code_position0.sexp_of_t v2 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Tag_sexp"; v0; v1; v2] + | Tag_t (v0, v1) -> + let v0 = sexp_of_string v0 + and v1 = sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Tag_t"; v0; v1] + | Tag_arg (v0, v1, v2) -> + let v0 = sexp_of_string v0 + and v1 = Sexp.sexp_of_t v1 + and v2 = sexp_of_t v2 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Tag_arg"; v0; v1; v2] + | Of_list (v0, v1) -> + let v0 = sexp_of_option sexp_of_int v0 + and v1 = sexp_of_list sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Of_list"; v0; v1] + | With_backtrace (v0, v1) -> + let v0 = sexp_of_t v0 + and v1 = sexp_of_string v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "With_backtrace"; v0; v1] + [@@@end] + + let rec to_strings_hum t ac = + (* We use 
[Sexp.to_string_mach], despite the fact that we are implementing + [to_strings_hum], because we want the info to fit on a single line, and once we've + had to resort to sexps, the message is going to start not looking so pretty + anyway. *) + match t with + | Could_not_construct sexp -> + "could not construct info: " :: Sexp.to_string_mach sexp :: ac + | String string -> string :: ac + | Exn exn -> Sexp.to_string_mach (Exn.sexp_of_t exn) :: ac + | Sexp sexp -> Sexp.to_string_mach sexp :: ac + | Tag_sexp (tag, sexp, _) -> tag :: ": " :: Sexp.to_string_mach sexp :: ac + | Tag_t (tag, t) -> tag :: ": " :: to_strings_hum t ac + | Tag_arg (tag, sexp, t) -> + tag :: ": " :: Sexp.to_string_mach sexp :: ": " :: to_strings_hum t ac + | With_backtrace (t, backtrace) -> + to_strings_hum t ("\nBacktrace:\n" :: backtrace :: ac) + | Of_list (trunc_after, ts) -> + let ts = + match trunc_after with + | None -> ts + | Some max -> + let n = List.length ts in + if n <= max then + ts + else + List.take ts max @ [ String (Printf.sprintf "and %d more info" (n - max)) ] + in + List.fold (List.rev ts) ~init:ac ~f:(fun ac t -> + to_strings_hum t (if List.is_empty ac then ac else ("; " :: ac))) + ;; + + let to_string_hum_deprecated t = String.concat (to_strings_hum t []) + + let rec to_sexps_hum t ac = + match t with + | Could_not_construct _ as t -> sexp_of_t t :: ac + | String string -> Atom string :: ac + | Exn exn -> Exn.sexp_of_t exn :: ac + | Sexp sexp -> sexp :: ac + | Tag_sexp (tag, sexp, here) -> + List ( Atom tag + :: sexp + :: (match here with + | None -> [] + | Some here -> [ Source_code_position0.sexp_of_t here ])) + :: ac + | Tag_t (tag, t) -> List (Atom tag :: to_sexps_hum t []) :: ac + | Tag_arg (tag, sexp, t) -> List (Atom tag :: sexp :: to_sexps_hum t []) :: ac + | With_backtrace (t, backtrace) -> + Sexp.List [ to_sexp_hum t; Sexp.Atom backtrace ] :: ac + | Of_list (_, ts) -> + List.fold (List.rev ts) ~init:ac ~f:(fun ac t -> to_sexps_hum t ac) + and to_sexp_hum t = + 
match to_sexps_hum t [] with + | [sexp] -> sexp + | sexps -> Sexp.List sexps + ;; + + (* We use [protect] to guard against exceptions raised by user-supplied functions, so + that failure to produce one part of an info doesn't interfere with other parts. *) + let protect f = + try f () with exn -> Could_not_construct (Exn.sexp_of_t exn) + ;; + + let of_info info = protect (fun () -> Lazy.force info) + let to_info t = lazy t +end + +open Message + +type t = Message.t Lazy.t + +let invariant _ = () + +let to_message = Message.of_info +let of_message = Message.to_info + +(* It is OK to use [Message.to_sexp_hum], which is not stable, because [t_of_sexp] below + can handle any sexp. *) +let sexp_of_t t = Message.to_sexp_hum (to_message t) + +let t_of_sexp sexp = lazy (Message.Sexp sexp) + +let compare t1 t2 = + Sexp.compare (sexp_of_t t1) (sexp_of_t t2) +;; + +let hash_fold_t state t = Sexp.hash_fold_t state (sexp_of_t t) +let hash t = Hash.run hash_fold_t t + +let to_string_hum t = + match to_message t with + | String s -> s + | message -> Sexp.to_string_hum (Message.to_sexp_hum message) +;; + +let to_string_hum_deprecated t = Message.to_string_hum_deprecated (to_message t) + +let to_string_mach t = Sexp.to_string_mach (sexp_of_t t) + +let of_lazy l = lazy (protect (fun () -> String (Lazy.force l))) + +let of_lazy_t lazy_t = Lazy.join lazy_t + +let of_string message = Lazy.from_val (String message) + +let createf format = Printf.ksprintf of_string format + +let of_thunk f = lazy (protect (fun () -> String (f ()))) + +let create ?here ?strict tag x sexp_of_x = + match strict with + | None -> lazy (protect (fun () -> Tag_sexp (tag, sexp_of_x x, here))) + | Some () -> of_message ( Tag_sexp (tag, sexp_of_x x, here)) +;; + +let create_s sexp = Lazy.from_val (Sexp sexp) + +let tag t ~tag = lazy (Tag_t (tag, to_message t)) + +let tag_arg t tag x sexp_of_x = + lazy (protect (fun () -> Tag_arg (tag, sexp_of_x x, to_message t))) +;; + +let of_list ?trunc_after ts = + lazy 
(Of_list (trunc_after, List.map ts ~f:to_message)) +;; + +exception Exn of t + +let () = + (* We install a custom exn-converter rather than use + [exception Exn of t [@@deriving_inline sexp][@@@end]] to eliminate the extra wrapping of + "(Exn ...)". *) + Sexplib.Conv.Exn_converter.add [%extension_constructor Exn] + (function + | Exn t -> sexp_of_t t + | _ -> + (* Reaching this branch indicates a bug in sexplib. *) + assert false) +;; + +let to_exn t = + if not (Lazy.is_val t) + then Exn t + else + match Lazy.force t with + | Message.Exn exn -> exn + | _ -> Exn t +;; + +let of_exn ?backtrace exn = + let backtrace = + match backtrace with + | None -> None + | Some `Get -> Some (Caml.Printexc.get_backtrace ()) + | Some (`This s) -> Some s + in + match exn, backtrace with + | Exn t, None -> t + | Exn t, Some backtrace -> lazy (With_backtrace (to_message t, backtrace)) + | _ , None -> Lazy.from_val (Message.Exn exn) + | _ , Some backtrace -> lazy (With_backtrace (Sexp (Exn.sexp_of_t exn), backtrace)) +;; + +include Pretty_printer.Register_pp(struct + type nonrec t = t + let module_name = "Base.Info" + let pp ppf t = Caml.Format.pp_print_string ppf (to_string_hum t) + end) + +module Internal_repr = Message + diff --git a/src/info.mli b/src/info.mli new file mode 100644 index 0000000..fc27b4a --- /dev/null +++ b/src/info.mli @@ -0,0 +1 @@ +include Info_intf.Info (** @inline *) diff --git a/src/info_intf.ml b/src/info_intf.ml new file mode 100644 index 0000000..6ed86ef --- /dev/null +++ b/src/info_intf.ml @@ -0,0 +1,148 @@ +(** [Info] is a library for lazily constructing human-readable information as a string + or sexp, with a primary use being error messages. + + Using [Info] is often preferable to [sprintf] or manually constructing strings + because you don't have to eagerly construct the string -- you only need to pay when + you actually want to display the info, which for many applications is rare. 
Using + [Info] is also better than creating custom exceptions because you have more control + over the format. + + Info is intended to be constructed in the following style; for simple info, you + write: + + {[Info.of_string "Unable to find file"]} + + Or for a more descriptive [Info] without attaching any content (but evaluating the + result eagerly): + + {[Info.createf "Process %s exited with code %d" process exit_code]} + + For info where you want to attach some content, you would write: + + {[Info.create "Unable to find file" filename [%sexp_of: string]]} + + Or even, + + {[ + Info.create "price too big" (price, [`Max max_price]) + [%sexp_of: float * [`Max of float]] + ]} + + Note that an [Info.t] can be created from any arbitrary sexp with [Info.t_of_sexp]. +*) + +open! Import + +module type S = sig + + (** Serialization and comparison force the lazy message. *) + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Invariant_intf.S with type t := t + + (** [to_string_hum] forces the lazy message, which might be an expensive operation. + + [to_string_hum] usually produces a sexp; however, it is guaranteed that + [to_string_hum (of_string s) = s]. + + If this string is going to go into a log file, you may find it useful to ensure that + the string is only one line long. To do this, use [to_string_mach t]. *) + val to_string_hum : t -> string + + (** [to_string_mach t] outputs [t] as a sexp on a single line. *) + val to_string_mach : t -> string + + (** Old version (pre 109.61) of [to_string_hum] that some applications rely on. + + Calls should be replaced with [to_string_mach t], which outputs more parentheses and + backslashes. 
*) + val to_string_hum_deprecated : t -> string + + val of_string : string -> t + + (** Be careful that the body of the lazy or thunk does not access mutable data, since it + will only be called at an undetermined later point. *) + + val of_lazy : string Lazy.t -> t + val of_thunk : (unit -> string) -> t + val of_lazy_t : t Lazy.t -> t + + (** For [create message a sexp_of_a], [sexp_of_a a] is lazily computed, when the info is + converted to a sexp. So if [a] is mutated in the time between the call to [create] + and the sexp conversion, those mutations will be reflected in the sexp. Use + [~strict:()] to force [sexp_of_a a] to be computed immediately. *) + val create + : ?here : Source_code_position0.t + -> ?strict : unit + -> string + -> 'a + -> ('a -> Sexp.t) + -> t + + val create_s : Sexp.t -> t + + (** Constructs a [t] containing only a string from a format. This eagerly constructs + the string. *) + val createf : ('a, unit, string, t) format4 -> 'a + + (** Adds a string to the front. *) + val tag : t -> tag:string -> t + + (** Adds a string and some other data in the form of an s-expression at the front. *) + val tag_arg : t -> string -> 'a -> ('a -> Sexp.t) -> t + + (** Combines multiple infos into one. *) + val of_list : ?trunc_after:int -> t list -> t + + (** [of_exn] and [to_exn] are primarily used with [Error], but their definitions have to + be here because they refer to the underlying representation. + + [~backtrace:`Get] attaches the backtrace for the most recent exception. The same + caveats as for [Printexc.print_backtrace] apply. [~backtrace:(`This s)] attaches + the backtrace [s]. The default is no backtrace. *) + val of_exn : ?backtrace:[ `Get | `This of string ] -> exn -> t + val to_exn : t -> exn + + val pp : Formatter.t -> t -> unit + + module Internal_repr : sig + type info = t + + (** The internal representation. It is exposed so that we can write efficient + serializers outside of this module. 
*) + type t = + | Could_not_construct of Sexp.t + | String of string + | Exn of exn + | Sexp of Sexp.t + | Tag_sexp of string * Sexp.t * Source_code_position0.t option + | Tag_t of string * t + | Tag_arg of string * Sexp.t * t + | Of_list of int option * t list + | With_backtrace of t * string (** The second argument is the backtrace *) + [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val of_info : info -> t + val to_info : t -> info + end with type info := t +end + +module type Info = sig + module type S = S + + include S +end diff --git a/src/int.ml b/src/int.ml new file mode 100644 index 0000000..eb97185 --- /dev/null +++ b/src/int.ml @@ -0,0 +1,310 @@ +open! Import + +include Int_intf +include Int0 + +module T = struct + type t = int [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int + [@@@end] + + let compare (x : t) y = Bool.to_int (x > y) - Bool.to_int (x < y) + + let of_string s = + try of_string s with + | _ -> Printf.failwithf "Int.of_string: %S" s () + + let to_string = to_string +end + +let num_bits = Int_conversions.num_bits_int + +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let to_float = Caml.float_of_int +let of_float_unchecked = Caml.int_of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + Caml.int_of_float f + else + Printf.invalid_argf "Int.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () + +let zero = 0 +let one = 1 +let 
minus_one = -1 + +include T +include Comparator.Make(T) +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +module Conv = Int_conversions +include Conv.Make (T) +include Conv.Make_hex(struct + open Int_replace_polymorphic_compare + type t = int [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_int + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int in fun x -> func x + [@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = Printf.sprintf "%x" i + let of_string s = Caml.Scanf.sscanf s "%x" Fn.id + + let module_name = "Base.Int.Hex" + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Int" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so + they do not shadow its definitions. This is here so that efficient versions + of the comparison functions are available within this module. *) +open! 
Int_replace_polymorphic_compare + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let pred i = i - 1 +let succ i = i + 1 + +let to_int i = i +let to_int_exn = to_int +let of_int i = i +let of_int_exn = of_int + +let max_value = Caml.max_int +let min_value = Caml.min_int + +let max_value_30_bits = 0x3FFF_FFFF + +let of_int32 = Conv.int32_to_int +let of_int32_exn = Conv.int32_to_int_exn +let of_int32_trunc = Conv.int32_to_int_trunc +let to_int32 = Conv.int_to_int32 +let to_int32_exn = Conv.int_to_int32_exn +let to_int32_trunc = Conv.int_to_int32_trunc +let of_int64 = Conv.int64_to_int +let of_int64_exn = Conv.int64_to_int_exn +let of_int64_trunc = Conv.int64_to_int_trunc +let to_int64 = Conv.int_to_int64 +let of_nativeint = Conv.nativeint_to_int +let of_nativeint_exn = Conv.nativeint_to_int_exn +let of_nativeint_trunc = Conv.nativeint_to_int_trunc +let to_nativeint = Conv.int_to_nativeint +let to_nativeint_exn = to_nativeint + +let abs x = abs x + +let ( + ) x y = ( + ) x y +let ( - ) x y = ( - ) x y +let ( * ) x y = ( * ) x y +let ( / ) x y = ( / ) x y + +let neg x = -x +let ( ~- ) = neg + +(* note that rem is not same as % *) +let rem a b = a mod b + +let incr = Caml.incr +let decr = Caml.decr + +let shift_right a b = a asr b +let shift_right_logical a b = a lsr b +let shift_left a b = a lsl b +let bit_not a = lnot a +let bit_or a b = a lor b +let bit_and a b = a land b +let bit_xor a b = a lxor b + +let pow = Int_math.int_pow +let ( ** ) b e = pow b e + +module Pow2 = struct + open! 
Import + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 x = + if x <= 0 then non_positive_argument (); + let x = x - 1 in + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + (* The next line is superfluous on 32-bit architectures, but it's faster to do it + anyway than to branch *) + let x = x lor (x lsr 32) in + x + 1 + + (** "floor power of 2" - Largest power of 2 less than or equal to x. *) + let floor_pow2 x = + if x <= 0 then non_positive_argument (); + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + (* The next line is superfluous on 32-bit architectures, but it's faster to do it + anyway than to branch *) + let x = x lor (x lsr 32) in + x - (x lsr 1) + + let is_pow2 x = + if x <= 0 then non_positive_argument (); + (x land (x-1)) = 0 + ;; + + (* C stub for int clz to use the CLZ/BSR instruction where possible *) + external int_clz : int -> int = "Base_int_math_int_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if i <= 0 then + raise_s (Sexp.message "[Int.floor_log2] got invalid input" + ["", sexp_of_int i]); + Sys.word_size_in_bits - 1 - int_clz i + ;; + + let ceil_log2 i = + if i <= 0 then + raise_s (Sexp.message "[Int.ceil_log2] got invalid input" + ["", sexp_of_int i]); + if i = 1 + then 0 + else Sys.word_size_in_bits - int_clz (i - 1) + ;; +end +include Pow2 + +(* This is already defined by Comparable.Validate_with_zero, but Sign.of_int is + more direct. 
*) +let sign = Sign.of_int + +let popcount = Popcount.int_popcount + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Int_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + module F = Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + include F + + (* These inlined versions of (%), (/%), and (//) perform better than their functorized + counterparts in [F] (see benchmarks below). + + The reason these functions are inlined in [Int] but not in any of the other integer + modules is that they existed in [Int] and [Int] alone prior to the introduction of + the [Int_math.Make] functor, and we didn't want to degrade their performance. + + We won't pre-emptively do the same for new functions, unless someone cares, on a case + by case fashion. *) + + let ( % ) x y = + if y <= zero then + Printf.invalid_argf + "%s %% %s in core_int.ml: modulus should be positive" + (to_string x) (to_string y) (); + let rval = rem x y in + if rval < zero + then rval + y + else rval + ;; + + let ( /% ) x y = + if y <= zero then + Printf.invalid_argf + "%s /%% %s in core_int.ml: divisor should be positive" + (to_string x) (to_string y) (); + if x < zero + then (x + one) / y - one + else x / y + ;; + + let (//) x y = to_float x /. 
to_float y + ;; + + let ( land ) = ( land ) + let ( lor ) = ( lor ) + let ( lxor ) = ( lxor ) + let ( lnot ) = ( lnot ) + let ( lsl ) = ( lsl ) + let ( asr ) = ( asr ) + let ( lsr ) = ( lsr ) +end + +include O (* [Int] and [Int.O] agree value-wise *) + +module Private = struct + module O_F = O.F +end + +(* Include type-specific [Replace_polymorphic_compare] at the end, after including functor + application that could shadow its definitions. This is here so that efficient versions + of the comparison functions are exported by this module. *) +include Int_replace_polymorphic_compare diff --git a/src/int.mli b/src/int.mli new file mode 100644 index 0000000..6643709 --- /dev/null +++ b/src/int.mli @@ -0,0 +1 @@ +include Int_intf.Int (** @inline *) diff --git a/src/int0.ml b/src/int0.ml new file mode 100644 index 0000000..2306cc4 --- /dev/null +++ b/src/int0.ml @@ -0,0 +1,23 @@ +(* [Int0] defines integer functions that are primitives or can be simply + defined in terms of [Caml]. [Int0] is intended to completely express the + part of [Caml] that [Base] uses for integers -- no other file in Base other + than int0.ml should use these functions directly through [Caml]. [Int0] has + few dependencies, and so is available early in Base's build order. + + All Base files that need to use ints and come before [Base.Int] in build + order should do: + + {[ + module Int = Int0 + ]} + + Defining [module Int = Int0] is also necessary because it prevents ocamldep + from mistakenly causing a file to depend on [Base.Int]. *) + +let to_string = Caml.string_of_int +let of_string = Caml.int_of_string +let to_float = Caml.float_of_int +let of_float = Caml.int_of_float +let max_value = Caml.max_int +let min_value = Caml.min_int +let succ = Caml.succ diff --git a/src/int32.ml b/src/int32.ml new file mode 100644 index 0000000..8ecf0d6 --- /dev/null +++ b/src/int32.ml @@ -0,0 +1,274 @@ +open! Import +open! 
Caml.Int32 + +module T = struct + type t = int32 [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int32 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int32 in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int32_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int32 + [@@@end] + let compare (x : t) y = compare x y + + let to_string = to_string + let of_string = of_string +end + +include T +include Comparator.Make(T) + +let num_bits = 32 + +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let float_of_bits = float_of_bits +let bits_of_float = bits_of_float +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = lognot +let bit_xor = logxor +let bit_or = logor +let bit_and = logand +let min_value = min_int +let max_value = max_int +let abs = abs +let pred = pred +let succ = succ +let rem = rem +let neg = neg +let minus_one = minus_one +let one = one +let zero = zero +let compare = compare +let to_float = to_float +let of_float_unchecked = of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + of_float f + else + Printf.invalid_argf "Int32.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () +;; + +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +module Infix_compare = struct + open Poly + + let ( >= ) (x : t) y = x >= y + let ( <= ) (x : t) y = x <= y + let ( = ) (x : t) y = x = y + let ( > ) (x : t) y = x > y + let ( < ) (x : t) y = x < y + let ( <> ) (x : t) y = x <> y +end + +module Compare = struct + include Infix_compare + + let compare = compare + let ascending = compare + let descending x y = compare y x + let min (x 
: t) y = if x < y then x else y + let max (x : t) y = if x > y then x else y + let equal (x : t) y = x = y + let between t ~low ~high = low <= t && t <= high + let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + + let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + + let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) +end + +include Compare + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let incr r = r := !r + one +let decr r = r := !r - one + +let of_int32 t = t +let of_int32_exn = of_int32 +let to_int32 t = t +let to_int32_exn = to_int32 + +let popcount = Popcount.int32_popcount + +module Conv = Int_conversions +let of_int = Conv.int_to_int32 +let of_int_exn = Conv.int_to_int32_exn +let of_int_trunc = Conv.int_to_int32_trunc +let to_int = Conv.int32_to_int +let to_int_exn = Conv.int32_to_int_exn +let to_int_trunc = Conv.int32_to_int_trunc +let of_int64 = Conv.int64_to_int32 +let of_int64_exn = Conv.int64_to_int32_exn +let of_int64_trunc = Conv.int64_to_int32_trunc +let to_int64 = Conv.int32_to_int64 +let of_nativeint = Conv.nativeint_to_int32 +let of_nativeint_exn = Conv.nativeint_to_int32_exn +let of_nativeint_trunc = Conv.nativeint_to_int32_trunc +let to_nativeint = Conv.int32_to_nativeint +let to_nativeint_exn = to_nativeint + +let pow b e = of_int_exn (Int_math.int_pow (to_int_exn b) (to_int_exn e)) +let ( ** ) b e = pow b e + +module Pow2 = struct + open! 
Import + open Int32_replace_polymorphic_compare + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + let ( lor ) = Caml.Int32.logor;; + let ( lsr ) = Caml.Int32.shift_right_logical;; + let ( land ) = Caml.Int32.logand;; + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 x = + if x <= Caml.Int32.zero then non_positive_argument (); + let x = Caml.Int32.pred x in + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + Caml.Int32.succ x + ;; + + (** "floor power of 2" - Largest power of 2 less than or equal to x. *) + let floor_pow2 x = + if x <= Caml.Int32.zero then non_positive_argument (); + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + Caml.Int32.sub x (x lsr 1) + ;; + + let is_pow2 x = + if x <= Caml.Int32.zero then non_positive_argument (); + (x land (Caml.Int32.pred x)) = Caml.Int32.zero + ;; + + (* C stub for int clz to use the CLZ/BSR instruction where possible. 
*) + external int32_clz : int32 -> int = "Base_int_math_int32_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if i <= Caml.Int32.zero then + raise_s (Sexp.message "[Int32.floor_log2] got invalid input" + ["", sexp_of_int32 i]); + num_bits - 1 - int32_clz i + ;; + + (** Hacker's Delight Second Edition p106 *) + let ceil_log2 i = + if i <= Caml.Int32.zero then + raise_s (Sexp.message "[Int32.ceil_log2] got invalid input" + ["", sexp_of_int32 i]); + (* The [i = 1] check is needed because clz(0) is undefined *) + if Caml.Int32.equal i Caml.Int32.one + then 0 + else num_bits - int32_clz (Caml.Int32.pred i) + ;; +end +include Pow2 + +include Conv.Make (T) + +include Conv.Make_hex(struct + + type t = int32 [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_int32 + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int32 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int32 in fun x -> func x + [@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = Printf.sprintf "%lx" i + let of_string s = Caml.Scanf.sscanf s "%lx" Fn.id + + let module_name = "Base.Int32.Hex" + + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Int32" + end) + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = 
bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Int32] and [Int32.O] agree value-wise *) diff --git a/src/int32.mli b/src/int32.mli new file mode 100644 index 0000000..bbefbd1 --- /dev/null +++ b/src/int32.mli @@ -0,0 +1,48 @@ +(** An int of exactly 32 bits, regardless of the machine. + + Side note: There's not much reason to want an int of at least 32 bits (i.e., 32 on + 32-bit machines and 63 on 64-bit machines) because [Int63] is basically just as + efficient. + + Overflow issues are {i not} generally considered and explicitly handled. This may be + more of an issue for 32-bit ints than 64-bit ints. + + [Int32.t] is boxed on both 32-bit and 64-bit machines. *) + +open! Import + +include Int_intf.S with type t = int32 + +(** {2 Conversion functions} *) + +val of_int : int -> t option +val to_int : t -> int option + +val of_int32 : int32 -> t +val to_int32 : t -> int32 + +val of_nativeint : nativeint -> t option +val to_nativeint : t -> nativeint + +val of_int64 : int64 -> t option + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x]. *) + +val of_int_trunc : int -> t +val to_int_trunc : t -> int +val of_nativeint_trunc : nativeint -> t +val of_int64_trunc : int64 -> t + +(** {3 Low-level float conversions} *) + +(** Rounds a regular 64-bit OCaml float to a 32-bit IEEE-754 "single" float, and returns + its bit representation. We make no promises about the exact rounding behavior, or + what happens in case of over- or underflow. *) +val bits_of_float : float -> t + +(** Creates a 32-bit IEEE-754 "single" float from the given bits, and converts it to a + regular 64-bit OCaml float. *) +val float_of_bits : t -> float diff --git a/src/int63.ml b/src/int63.ml new file mode 100644 index 0000000..ac40b85 --- /dev/null +++ b/src/int63.ml @@ -0,0 +1,77 @@ +open! 
Import + +let raise_s = Error.raise_s +module Repr = Int63_emul.Repr + +include Int63_backend + +module Overflow_exn = struct + let ( + ) t u = + let sum = t + u in + if bit_or (bit_xor t u) (bit_xor t (bit_not sum)) < zero + then sum + else raise_s (Sexp.message "( + ) overflow" + [ "t" , sexp_of_t t + ; "u" , sexp_of_t u + ; "sum", sexp_of_t sum + ]) + ;; + + let ( - ) t u = + let diff = t - u in + let pos_diff = t > u in + if t <> u && Bool.(<>) pos_diff (is_positive diff) then + raise_s (Sexp.message "( - ) overflow" + [ "t" , sexp_of_t t + ; "u" , sexp_of_t u + ; "diff", sexp_of_t diff + ]) + else diff + ;; + + let abs t = if t = min_value then failwith "abs overflow" else abs t + let neg t = if t = min_value then failwith "neg overflow" else neg t +end + +let () = assert (Int.(=) num_bits 63) + +let random_of_int ?(state = Random.State.default) bound = + of_int (Random.State.int state (to_int_exn bound)) + +let random_of_int64 ?(state = Random.State.default) bound = + of_int64_exn (Random.State.int64 state (to_int64 bound)) + +let random = + match Word_size.word_size with + | W64 -> random_of_int + | W32 -> random_of_int64 + +let random_incl_of_int ?(state = Random.State.default) lo hi = + of_int (Random.State.int_incl state (to_int_exn lo) (to_int_exn hi)) + +let random_incl_of_int64 ?(state = Random.State.default) lo hi = + of_int64_exn (Random.State.int64_incl state (to_int64 lo) (to_int64 hi)) + +let random_incl = + match Word_size.word_size with + | W64 -> random_incl_of_int + | W32 -> random_incl_of_int64 + +let floor_log2 t = + match Word_size.word_size with + | W64 -> t |> to_int_exn |> Int.floor_log2 + | W32 -> + if t <= zero + then raise_s (Sexp.message "[Int.floor_log2] got invalid input" + ["", sexp_of_t t]); + let floor_log2 = ref (Int.( - ) num_bits 2) in + while equal zero (bit_and t (shift_left one !floor_log2)) do + floor_log2 := Int.( - ) !floor_log2 1 + done; + !floor_log2 +;; + +module Private = struct + module Repr = Repr + let repr = 
repr +end diff --git a/src/int63.mli b/src/int63.mli new file mode 100644 index 0000000..6a5c3d8 --- /dev/null +++ b/src/int63.mli @@ -0,0 +1,86 @@ +(** 63-bit integers. + + The size of Int63 is always 63 bits. On a 64-bit platform it is just an int + (63-bits), and on a 32-bit platform it is an int64 wrapped to respect the + semantics of 63-bit integers. + + Because [Int63] has different representations on 32-bit and 64-bit platforms, + marshalling [Int63] will not work between 32-bit and 64-bit platforms -- [unmarshal] + will segfault. *) + +open! Import + +(** In 64-bit architectures, we expose [type t = private int] so that the compiler can + omit [caml_modify] when dealing with record fields holding [Int63.t]. + + Code should not explicitly make use of the [private], e.g., via [(i :> int)], since + such code will not compile on 32-bit platforms. *) +include Int_intf.S with type t = Int63_backend.t + +(** {2 Arithmetic with overflow} + + Unlike the usual operations, these never overflow, preferring instead to raise. *) + +module Overflow_exn : sig + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val abs : t -> t + val neg : t -> t +end + +(** {2 Conversion functions} *) + +val of_int : int -> t +val to_int : t -> int option + +val of_int32 : Int32.t -> t +val to_int32 : t -> Int32.t option + +val of_int64 : Int64.t -> t option + +val of_nativeint : nativeint -> t option +val to_nativeint : t -> nativeint option + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x]. *) + +val to_int_trunc : t -> int +val to_int32_trunc : t -> Int32.t +val of_int64_trunc : Int64.t -> t +val of_nativeint_trunc : nativeint -> t +val to_nativeint_trunc : t -> nativeint + +(** {2 Random generation} *) + +(** [random ~state bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. 
+ + The default [~state] is [Random.State.default]. *) +val random : ?state:Random.State.t -> t -> t + +(** [random_incl ~state lo hi] returns a random integer between [lo] (inclusive) and [hi] + (inclusive). Raises if [lo > hi]. + + The default [~state] is [Random.State.default]. *) +val random_incl : ?state:Random.State.t -> t -> t -> t + +(** [floor_log2 x] returns the floor of log-base-2 of [x], and raises if [x <= 0]. *) +val floor_log2 : t -> int + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + (** [val repr] states how [Int63.t] is represented, i.e., as an [int] or an [int64], and + can be used for building [Int63] operations that behave differently depending on the + representation (e.g., see core_int63.ml). *) + module Repr : sig + type ('underlying_type, 'intermediate_type) t = + | Int : (int , int ) t + | Int64 : (int64 , Int63_emul.t) t + end + val repr : (t, t) Repr.t +end diff --git a/src/int63_backends.ml b/src/int63_backends.ml new file mode 100644 index 0000000..5efde7d --- /dev/null +++ b/src/int63_backends.ml @@ -0,0 +1,50 @@ +open! 
Import + +let raise_s = Error.raise_s + +module type Int_or_more = sig + type t [@@deriving_inline hash] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + include Int_intf.S with type t := t + val of_int : int -> t + val to_int : t -> int option + val to_int_trunc : t -> int + val of_int32 : int32 -> t + val to_int32 : t -> Int32.t option + val to_int32_trunc : t -> Int32.t + val of_int64 : Int64.t -> t option + val of_int64_trunc : Int64.t -> t + val of_nativeint : nativeint -> t option + val to_nativeint : t -> nativeint option + val of_nativeint_trunc : nativeint -> t + val to_nativeint_trunc : t -> nativeint + val of_float_unchecked : float -> t + val repr : (t, t) Int63_emul.Repr.t +end + +module Native : Int_or_more with type t = private int = struct + include Int + let to_int x = Some x + let to_int_trunc x = x + (* [of_int32_exn] is a safe operation on platforms with 64-bit word sizes. *) + let of_int32 = of_int32_exn + let to_nativeint_trunc x = to_nativeint x + let to_nativeint x = Some (to_nativeint x) + let repr = Int63_emul.Repr.Int +end + +module Emulated : Int_or_more with type t = Int63_emul.t = Int63_emul + +let dynamic : (module Int_or_more) = + match Word_size.word_size with + | W64 -> (module Native : Int_or_more) + | W32 -> (module Emulated : Int_or_more) + +module Dynamic = (val dynamic) diff --git a/src/int63_emul.ml b/src/int63_emul.ml new file mode 100644 index 0000000..d91e3c9 --- /dev/null +++ b/src/int63_emul.ml @@ -0,0 +1,399 @@ +(* A 63bit integer is a 64bit integer with its bits shifted to the left + and its lowest bit set to 0. + This is the same kind of encoding as OCaml int on 64bit architecture. + The only difference being the lowest bit (immediate bit) set to 1. *) + +open! 
Import +include Int64_replace_polymorphic_compare + +module T0 = struct + module T = struct + type t = int64 [@@deriving_inline compare, hash, sexp] + let compare : t -> t -> int = compare_int64 + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int64 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int64 in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int64_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int64 + [@@@end] + end + include T + include Comparator.Make(T) +end + +module Conv = Int_conversions + +module W : sig + type t = int64 + include (module type of struct include T0 end with type t := t) + val wrap_exn : Caml.Int64.t -> t + val wrap_modulo : Caml.Int64.t -> t + val unwrap : t -> Caml.Int64.t + + (** Returns a non-negative int64 that is equal to the input int63 modulo 2^63. *) + val unwrap_unsigned : t -> Caml.Int64.t + val add : t -> t -> t + val sub : t -> t -> t + val neg : t -> t + val abs : t -> t + val succ : t -> t + val pred : t -> t + val mul : t -> t -> t + val pow : t -> t -> t + val div : t -> t -> t + val rem : t -> t -> t + val popcount : t -> int + val bit_not : t -> t + val bit_xor : t -> t -> t + val bit_or : t -> t -> t + val bit_and : t -> t -> t + val shift_left : t -> int -> t + val shift_right : t -> int -> t + val shift_right_logical : t -> int -> t + val min_value : t + val max_value : t + + val to_int64 : t -> Caml.Int64.t + val of_int64 : Caml.Int64.t -> t option + val of_int64_exn : Caml.Int64.t -> t + val of_int64_trunc : Caml.Int64.t -> t + + val compare : t -> t -> int + + val ceil_pow2 : t -> t + val floor_pow2 : t -> t + val ceil_log2 : t -> int + val floor_log2 : t -> int + val is_pow2 : t -> bool +end = struct + type t = int64 + include (T0 : module type of struct include T0 end with type t := t) + + let wrap_exn x = + (* Raises if the int64 value does not fit on int63. 
*) + Conv.int64_fit_on_int63_exn x; + Caml.Int64.mul x 2L + let wrap x = + if Conv.int64_is_representable_as_int63 x + then Some (Caml.Int64.mul x 2L) + else None + let wrap_modulo x = + Caml.Int64.mul x 2L + let unwrap x = + Caml.Int64.shift_right x 1 + let unwrap_unsigned x = + Caml.Int64.shift_right_logical x 1 + + (* This does not use wrap or unwrap to avoid generating exceptions in the case of + overflows. This is to preserve the semantics of int type on 64 bit architecture. *) + let f2 f a b = Caml.Int64.mul (f (Caml.Int64.shift_right a 1) (Caml.Int64.shift_right b 1)) 2L + + let mask = 0xffff_ffff_ffff_fffeL + + let m x = Caml.Int64.logand x mask + + let add x y = Caml.Int64.add x y + let sub x y = Caml.Int64.sub x y + let neg x = Caml.Int64.neg x + let abs x = Caml.Int64.abs x + let one = wrap_exn 1L + let succ a = add a one + let pred a = sub a one + let min_value = m Caml.Int64.min_int + let max_value = m Caml.Int64.max_int + let bit_not x = m (Caml.Int64.lognot x) + let bit_and = Caml.Int64.logand + let bit_xor = Caml.Int64.logxor + let bit_or = Caml.Int64.logor + let shift_left x i = Caml.Int64.shift_left x i + let shift_right x i = m (Caml.Int64.shift_right x i) + let shift_right_logical x i = m (Caml.Int64.shift_right_logical x i) + let pow = f2 Int_math.int63_pow_on_int64 + let mul a b = Caml.Int64.mul a (Caml.Int64.shift_right b 1) + let div a b = wrap_modulo (Caml.Int64.div a b) + let rem a b = Caml.Int64.rem a b + let popcount x = Popcount.int64_popcount x + + let to_int64 t = unwrap t + let of_int64 t = wrap t + let of_int64_exn t = wrap_exn t + let of_int64_trunc t = wrap_modulo t + + let t_of_sexp x = wrap_exn (int64_of_sexp x) + let sexp_of_t x = sexp_of_int64 (unwrap x) + + let compare (x : t) y = compare x y + + let is_pow2 x = Int64.is_pow2 (unwrap x) + let floor_pow2 x = Int64.floor_pow2 (unwrap x) |> wrap_exn + let ceil_pow2 x = Int64.ceil_pow2 (unwrap x) |> wrap_exn (* rounds up on the unwrapped value; [wrap_exn] raises if the result does not fit in 63 bits *) + let floor_log2 x = Int64.floor_log2 (unwrap x) + let ceil_log2 x = 
Int64.ceil_log2 (unwrap x) +end + +open W + +module T = struct + type t = W.t [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + W.hash_fold_t + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = W.hash in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = W.t_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = W.sexp_of_t + [@@@end] + type comparator_witness = W.comparator_witness + let comparator = W.comparator + + let compare = W.compare + + (* We don't expect [hash] to follow the behavior of int in 64bit architecture *) + let _ = hash + let hash (x : t) = Caml.Hashtbl.hash x + + let invalid_str x = failwith (Printf.sprintf "Int63.of_string: invalid input %S" x) + + (* + "sign" refers to whether the number starts with a '-' + "signedness = false" means the rest of the number is parsed as unsigned and then cast + to signed with wrap-around modulo 2^i + "signedness = true" means no such craziness happens + + The terminology and the logic is due to the code in byterun/ints.c in ocaml 4.03 + ([parse_sign_and_base] function). + + Signedness equals true for plain decimal number (e.g. 1235, -6789) + + Signedness equals false in the following cases: + - [0xffff], [-0xffff] (hexadecimal representation) + - [0b0101], [-0b0101] (binary representation) + - [0o1237], [-0o1237] (octal representation) + - [0u9812], [-0u9812] (unsigned decimal representation - available from OCaml 4.03) *) + let sign_and_signedness x = + let len = String.length x in + let open Int_replace_polymorphic_compare in + let pos,sign = + if 0 < len + then match x.[0] with + | '-' -> 1,`Neg + | '+' -> 1, `Pos + | _ -> 0, `Pos + else + 0, `Pos + in + if pos + 2 < len then + let c1 = x.[pos] in + let c2 = x.[pos + 1] in + match c1, c2 with + | '0', ('0' .. 
'9') -> sign,true + | '0', _ -> sign,false + | _ -> sign,true + else sign, true + + let to_string x = Caml.Int64.to_string (unwrap x) + + let of_string str = + try + let sign,signedness = sign_and_signedness str in + if signedness + then of_int64_exn (Caml.Int64.of_string str) + else + let pos_str = + match sign with + | `Neg -> String.sub str ~pos:1 ~len:(String.length str - 1) + | `Pos -> str + in + let int64 = Caml.Int64.of_string pos_str in + (* unsigned 63-bit int must parse as a positive signed 64-bit int *) + if Int64_replace_polymorphic_compare.(<) int64 0L then invalid_str str; + let int63 = wrap_modulo int64 in + match sign with + | `Neg -> neg int63 + | `Pos -> int63 + with _ -> invalid_str str +end + +include T + +let num_bits = 63 + +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = bit_not +let bit_xor = bit_xor +let bit_or = bit_or +let bit_and = bit_and +let popcount = popcount +let abs = abs +let pred = pred +let succ = succ +let pow = pow +let rem = rem +let neg = neg +let max_value = max_value +let min_value = min_value +let minus_one = wrap_exn Caml.Int64.minus_one +let one = wrap_exn Caml.Int64.one +let zero = wrap_exn Caml.Int64.zero +let is_pow2 = is_pow2 +let floor_pow2 = floor_pow2 +let ceil_pow2 = ceil_pow2 +let floor_log2 = floor_log2 +let ceil_log2 = ceil_log2 +let to_float x = Caml.Int64.to_float (unwrap x) +let of_float_unchecked x = wrap_modulo (Caml.Int64.of_float x) +let of_float t = + let open Float_replace_polymorphic_compare in + if t >= float_lower_bound && t <= float_upper_bound then + wrap_modulo (Caml.Int64.of_float t) + else + Printf.invalid_argf "Int63.of_float: argument (%f) is out of range or NaN" + (Float0.box t) + () +let of_int64 = of_int64 +let of_int64_exn = of_int64_exn +let of_int64_trunc = of_int64_trunc +let to_int64 = 
to_int64 + +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = (* does not itself check that [min <= max] *) + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let ( ** ) b e = pow b e + +let incr r = r := !r + one +let decr r = r := !r - one + +(* We can reuse the conversion functions from/to int64 here: each one goes through [unwrap] or a [wrap]/[of_int64] variant. *) +let of_int x = wrap_exn (Conv.int_to_int64 x) +let of_int_exn x = of_int x +let to_int x = Conv.int64_to_int (unwrap x) +let to_int_exn x = Conv.int64_to_int_exn (unwrap x) +let to_int_trunc x = Conv.int64_to_int_trunc (unwrap x) + +let of_int32 x = wrap_exn (Conv.int32_to_int64 x) +let of_int32_exn x = of_int32 x +let to_int32 x = Conv.int64_to_int32 (unwrap x) +let to_int32_exn x = Conv.int64_to_int32_exn (unwrap x) +let to_int32_trunc x = Conv.int64_to_int32_trunc (unwrap x) + +let of_nativeint x = of_int64 (Conv.nativeint_to_int64 x) +let of_nativeint_exn x = wrap_exn (Conv.nativeint_to_int64 x) +let of_nativeint_trunc x = of_int64_trunc (Conv.nativeint_to_int64 x) +let to_nativeint x = Conv.int64_to_nativeint (unwrap x) +let to_nativeint_exn x = Conv.int64_to_nativeint_exn (unwrap x) +let to_nativeint_trunc x = Conv.int64_to_nativeint_trunc (unwrap x) + +include Conv.Make (T) + +include Conv.Make_hex(struct + + type t = T.t [@@deriving_inline compare, hash] + let compare : t -> t -> int = T.compare + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + T.hash_fold_t + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = T.hash in fun x -> func x +
[@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = + (* the use of [unwrap_unsigned] here is important for the case of [min_value] *) + Printf.sprintf "%Lx" (unwrap_unsigned i) + let of_string s = of_string ("0x"^s) (* prepend a hex prefix and reuse the general [of_string] *) + let module_name = "Base.Int63.Hex" + + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string x = to_string x + let module_name = "Base.Int63" + end) + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Int64_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Int63] and [Int63.O] agree value-wise *) + +module Repr = struct + type emulated = t + type ('underlying_type, 'intermediate_type) t = + | Int : (int , int ) t + | Int64 : (int64 , emulated) t +end + +let repr = Repr.Int64 (* this implementation always reports the [Int64] case *) + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor applications that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Int64_replace_polymorphic_compare diff --git a/src/int63_emul.mli b/src/int63_emul.mli new file mode 100644 index 0000000..2c159fb --- /dev/null +++ b/src/int63_emul.mli @@ -0,0 +1,34 @@ +open!
Import + +include Int_intf.S + +val of_int : int -> t +val to_int : t -> int option (* optional: the implementation goes through [Conv.int64_to_int] *) +val to_int_trunc : t -> int + +val of_int32 : int32 -> t +val to_int32 : t -> Int32.t option +val to_int32_trunc : t -> Int32.t + +val of_int64 : Int64.t -> t option +val of_int64_trunc : Int64.t -> t + +val of_nativeint : nativeint -> t option +val to_nativeint : t -> nativeint option +val of_nativeint_trunc : nativeint -> t +val to_nativeint_trunc : t -> nativeint + +(*_ exported for Core_kernel *) +module W : sig + val wrap_exn : int64 -> t + val unwrap : t -> int64 +end + +module Repr : sig + type emulated = t + type ('underlying_type, 'intermediate_type) t = + | Int : (int , int ) t + | Int64 : (int64 , emulated) t +end with type emulated := t + +val repr : (t, t) Repr.t diff --git a/src/int64.ml b/src/int64.ml new file mode 100644 index 0000000..3023a4d --- /dev/null +++ b/src/int64.ml @@ -0,0 +1,260 @@ +open! Import +open! Caml.Int64 + +module T = struct + type t = int64 [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int64 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int64 in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int64_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int64 + [@@@end] + + let compare = Int64_replace_polymorphic_compare.compare + + let to_string = to_string (* presumably [Caml.Int64.to_string], via the [open!] at the top of the file *) + let of_string = of_string +end + +include T +include Comparator.Make(T) + +let num_bits = 64 +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let float_of_bits = float_of_bits +let bits_of_float = bits_of_float +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = lognot +let bit_xor = logxor +let bit_or = logor +let bit_and = logand +let min_value = min_int +let max_value = max_int +let abs = abs +let pred
= pred +let succ = succ +let pow = Int_math.int64_pow +let rem = rem +let neg = neg +let minus_one = minus_one +let one = one +let zero = zero +let to_float = to_float +let of_float_unchecked = Caml.Int64.of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + Caml.Int64.of_float f + else (* NaN fails both comparisons, so it also ends up here *) + Printf.invalid_argf "Int64.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () + +let ( ** ) b e = pow b e + +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open Int64_replace_polymorphic_compare + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = (* does not itself check that [min <= max] *) + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let incr r = r := !r + one +let decr r = r := !r - one + +let of_int64 t = t (* identity: [t] is [int64] in this module *) +let of_int64_exn = of_int64 +let to_int64 t = t + +let popcount = Popcount.int64_popcount + +module Conv = Int_conversions +let of_int = Conv.int_to_int64 +let of_int_exn = of_int +let to_int = Conv.int64_to_int +let to_int_exn = Conv.int64_to_int_exn +let to_int_trunc = Conv.int64_to_int_trunc +let of_int32 = Conv.int32_to_int64 +let of_int32_exn = of_int32 +let to_int32 = Conv.int64_to_int32 +let to_int32_exn = Conv.int64_to_int32_exn +let to_int32_trunc = Conv.int64_to_int32_trunc +let of_nativeint =
Conv.nativeint_to_int64 +let of_nativeint_exn = of_nativeint +let to_nativeint = Conv.int64_to_nativeint +let to_nativeint_exn = Conv.int64_to_nativeint_exn +let to_nativeint_trunc = Conv.int64_to_nativeint_trunc + +module Pow2 = struct + open! Import + open Int64_replace_polymorphic_compare + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + let ( lor ) = Caml.Int64.logor;; + let ( lsr ) = Caml.Int64.shift_right_logical;; + let ( land ) = Caml.Int64.logand;; + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 x = + if x <= Caml.Int64.zero then non_positive_argument (); + let x = Caml.Int64.pred x in (* subtract one first so that an exact power of 2 maps to itself *) + let x = x lor (x lsr 1) in (* this and the following steps copy the highest set bit into every lower bit *) + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + let x = x lor (x lsr 32) in + Caml.Int64.succ x + ;; + + (** "floor power of 2" - Largest power of 2 less than or equal to x.
*) + let floor_pow2 x = + if x <= Caml.Int64.zero then non_positive_argument (); + let x = x lor (x lsr 1) in (* this and the following steps copy the highest set bit into every lower bit *) + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + let x = x lor (x lsr 32) in + Caml.Int64.sub x (x lsr 1) (* keeps only the highest set bit *) + ;; + + let is_pow2 x = + if x <= Caml.Int64.zero then non_positive_argument (); + (x land (Caml.Int64.pred x)) = Caml.Int64.zero (* true when clearing the lowest set bit leaves zero *) + ;; + + (* C stub for int clz to use the CLZ/BSR instruction where possible *) + external int64_clz : int64 -> int = "Base_int_math_int64_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if i <= Caml.Int64.zero then + raise_s (Sexp.message "[Int64.floor_log2] got invalid input" + ["", sexp_of_int64 i]); + num_bits - 1 - int64_clz i + ;; + + (** Hacker's Delight Second Edition p106 *) + let ceil_log2 i = + if i <= Caml.Int64.zero then (* int64-specialised [<=], as in [floor_log2], instead of polymorphic compare *) + raise_s (Sexp.message "[Int64.ceil_log2] got invalid input" + ["", sexp_of_int64 i]); + if Caml.Int64.equal i Caml.Int64.one + then 0 + else num_bits - int64_clz (Caml.Int64.pred i) + ;; +end +include Pow2 + +include Conv.Make (T) + +include Conv.Make_hex(struct + + type t = int64 [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_int64 + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int64 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int64 in fun x -> func x + [@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = Printf.sprintf "%Lx" i + let of_string s = Caml.Scanf.sscanf s "%Lx" Fn.id + + let module_name = "Base.Int64.Hex" + + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Int64" + end) + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include
(Int64_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Int64] and [Int64.O] agree value-wise *) + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor applications that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Int64_replace_polymorphic_compare diff --git a/src/int64.mli b/src/int64.mli new file mode 100644 index 0000000..ffbf26c --- /dev/null +++ b/src/int64.mli @@ -0,0 +1,32 @@ +(** 64-bit integers. *) + +open! Import + +include Int_intf.S with type t = int64 + +(** {2 Conversion functions} *) + +val of_int : int -> t +val to_int : t -> int option + +val of_int32 : int32 -> t +val to_int32 : t -> int32 option + +val of_nativeint : nativeint -> t +val to_nativeint : t -> nativeint option + +val of_int64 : t -> t (* the identity in the implementation *) + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x].
*) + +val to_int_trunc : t -> int +val to_int32_trunc : t -> int32 +val to_nativeint_trunc : t -> nativeint + +(** {3 Low-level float conversions} *) + +val bits_of_float : float -> t +val float_of_bits : t -> float diff --git a/src/int_conversions.ml b/src/int_conversions.ml new file mode 100644 index 0000000..5a34845 --- /dev/null +++ b/src/int_conversions.ml @@ -0,0 +1,359 @@ +open! Import + +module Int = Int0 +module Sys = Sys0 + +let [@inline never] convert_failure x a b to_string = (* shared raising path: [a] and [b] name the source and target types *) + Printf.failwithf + "conversion from %s to %s failed: %s is out of range" + a + b + (to_string x) + () + +let num_bits_int = Sys.int_size_in_bits +let num_bits_int32 = 32 +let num_bits_int64 = 64 +let num_bits_nativeint = Word_size.num_bits Word_size.word_size + +let () = (* the code below only supports these three widths of [int] *) + assert (num_bits_int = 63 + || num_bits_int = 31 + || num_bits_int = 32) + +let min_int32 = Caml.Int32.min_int +let max_int32 = Caml.Int32.max_int +let min_int64 = Caml.Int64.min_int +let max_int64 = Caml.Int64.max_int +let min_nativeint = Caml.Nativeint.min_int +let max_nativeint = Caml.Nativeint.max_int + +let int_to_string = Caml.string_of_int +let int32_to_string = Caml.Int32.to_string +let int64_to_string = Caml.Int64.to_string +let nativeint_to_string = Caml.Nativeint.to_string + +(* int <-> int32 *) + +let int_to_int32_failure x = convert_failure x "int" "int32" int_to_string +let int32_to_int_failure x = convert_failure x "int32" "int" int32_to_string + +let int32_to_int_trunc = Caml.Int32.to_int +let int_to_int32_trunc = Caml.Int32.of_int + +let int_is_representable_as_int32 = (* the width test runs once, when this value is built; the result is a closure *) + if num_bits_int <= num_bits_int32 + then (fun _ -> true) + else + let min = int32_to_int_trunc min_int32 in + let max = int32_to_int_trunc max_int32 in + (fun x -> compare_int min x <= 0 && compare_int x max <= 0) + +let int32_is_representable_as_int = + if num_bits_int32 <= num_bits_int + then (fun _ -> true) + else + let min = int_to_int32_trunc Int.min_value in + let max = int_to_int32_trunc Int.max_value in + (fun x ->
compare_int32 min x <= 0 && compare_int32 x max <= 0) + +let int_to_int32 x = + if int_is_representable_as_int32 x + then Some (int_to_int32_trunc x) + else None + +let int32_to_int x = + if int32_is_representable_as_int x + then Some (int32_to_int_trunc x) + else None + +let int_to_int32_exn x = + if int_is_representable_as_int32 x + then int_to_int32_trunc x + else int_to_int32_failure x + +let int32_to_int_exn x = + if int32_is_representable_as_int x + then int32_to_int_trunc x + else int32_to_int_failure x + +(* int <-> int64 *) + +let int64_to_int_failure x = convert_failure x "int64" "int" int64_to_string + +let () = assert (num_bits_int < num_bits_int64) (* [int] is strictly narrower, so only the int64 -> int direction needs a range check *) + +let int_to_int64 = Caml.Int64.of_int +let int64_to_int_trunc = Caml.Int64.to_int + +let int64_is_representable_as_int = + let min = int_to_int64 Int.min_value in + let max = int_to_int64 Int.max_value in + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_to_int x = + if int64_is_representable_as_int x + then Some (int64_to_int_trunc x) + else None + +let int64_to_int_exn x = + if int64_is_representable_as_int x + then int64_to_int_trunc x + else int64_to_int_failure x + +(* int <-> nativeint *) + +let nativeint_to_int_failure x = convert_failure x "nativeint" "int" nativeint_to_string + +let () = assert (num_bits_int <= num_bits_nativeint) + +let int_to_nativeint = Caml.Nativeint.of_int +let nativeint_to_int_trunc = Caml.Nativeint.to_int + +let nativeint_is_representable_as_int = + if num_bits_nativeint <= num_bits_int + then (fun _ -> true) + else + let min = int_to_nativeint Int.min_value in + let max = int_to_nativeint Int.max_value in + (fun x -> compare_nativeint min x <= 0 && compare_nativeint x max <= 0) + +let nativeint_to_int x = + if nativeint_is_representable_as_int x + then Some (nativeint_to_int_trunc x) + else None + +let nativeint_to_int_exn x = + if nativeint_is_representable_as_int x + then nativeint_to_int_trunc x + else nativeint_to_int_failure x + +(* int32
<-> int64 *) + +let int64_to_int32_failure x = convert_failure x "int64" "int32" int64_to_string + +let () = assert (num_bits_int32 < num_bits_int64) (* only the int64 -> int32 direction needs a range check *) + +let int32_to_int64 = Caml.Int64.of_int32 +let int64_to_int32_trunc = Caml.Int64.to_int32 + +let int64_is_representable_as_int32 = + let min = int32_to_int64 min_int32 in + let max = int32_to_int64 max_int32 in + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_to_int32 x = + if int64_is_representable_as_int32 x + then Some (int64_to_int32_trunc x) + else None + +let int64_to_int32_exn x = + if int64_is_representable_as_int32 x + then int64_to_int32_trunc x + else int64_to_int32_failure x + +(* int32 <-> nativeint *) + +let nativeint_to_int32_failure x = + convert_failure x "nativeint" "int32" nativeint_to_string + +let () = assert (num_bits_int32 <= num_bits_nativeint) + +let int32_to_nativeint = Caml.Nativeint.of_int32 +let nativeint_to_int32_trunc = Caml.Nativeint.to_int32 + +let nativeint_is_representable_as_int32 = + if num_bits_nativeint <= num_bits_int32 + then (fun _ -> true) + else + let min = int32_to_nativeint min_int32 in + let max = int32_to_nativeint max_int32 in + (fun x -> compare_nativeint min x <= 0 && compare_nativeint x max <= 0) + +let nativeint_to_int32 x = + if nativeint_is_representable_as_int32 x + then Some (nativeint_to_int32_trunc x) + else None + +let nativeint_to_int32_exn x = + if nativeint_is_representable_as_int32 x + then nativeint_to_int32_trunc x + else nativeint_to_int32_failure x + + +(* int64 <-> nativeint *) + +let int64_to_nativeint_failure x = convert_failure x "int64" "nativeint" int64_to_string + +let () = assert (num_bits_int64 >= num_bits_nativeint) + +let int64_to_nativeint_trunc = Caml.Int64.to_nativeint +let nativeint_to_int64 = Caml.Int64.of_nativeint + +let int64_is_representable_as_nativeint = + if num_bits_int64 <= num_bits_nativeint + then (fun _ -> true) + else + let min = nativeint_to_int64 min_nativeint in + let max =
nativeint_to_int64 max_nativeint in + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_to_nativeint x = + if int64_is_representable_as_nativeint x + then Some (int64_to_nativeint_trunc x) + else None + +let int64_to_nativeint_exn x = + if int64_is_representable_as_nativeint x + then int64_to_nativeint_trunc x + else int64_to_nativeint_failure x + +(* int64 <-> int63 *) + +let int64_to_int63_failure x = convert_failure x "int64" "int63" int64_to_string + +let int64_is_representable_as_int63 = + let min = Caml.Int64.shift_right min_int64 1 in (* sign-preserving shift of the int64 minimum: -2^62 *) + let max = Caml.Int64.shift_right max_int64 1 in (* 2^62 - 1 *) + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_fit_on_int63_exn x = + if int64_is_representable_as_int63 x + then () + else int64_to_int63_failure x + +(* string conversions *) + +let insert_delimiter_every input ~delimiter ~chars_per_delimiter = + let input_length = String.length input in + if input_length <= chars_per_delimiter then (* short inputs are returned unchanged *) + input + else begin + let has_sign = match input.[0] with '+' | '-' -> true | _ -> false in + let num_digits = if has_sign then input_length - 1 else input_length in + let num_delimiters = (num_digits - 1) / chars_per_delimiter in (* delimiters go between groups only, never in front *) + let output_length = input_length + num_delimiters in + let output = Bytes.create output_length in + let input_pos = ref (input_length - 1) in + let output_pos = ref (output_length - 1) in + let num_chars_until_delimiter = ref chars_per_delimiter in + let first_digit_pos = if has_sign then 1 else 0 in + while !input_pos >= first_digit_pos do (* fill [output] from right to left *) + if !num_chars_until_delimiter = 0 then begin + Bytes.set output !output_pos delimiter; + decr output_pos; + num_chars_until_delimiter := chars_per_delimiter; + end; + Bytes.set output !output_pos input.[!input_pos]; + decr input_pos; + decr output_pos; + decr num_chars_until_delimiter; + done; + if has_sign then Bytes.set output 0 input.[0]; (* the sign is copied last and is not part of any group *) + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:output; + end +;; +
+let insert_delimiter input ~delimiter = + insert_delimiter_every input ~delimiter ~chars_per_delimiter:3 + +let insert_underscores input = + insert_delimiter input ~delimiter:'_' + +let sexp_of_int_style = Sexp.of_int_style + +module Make (I : sig + type t + val to_string : t -> string + end) = struct + + open I + + let chars_per_delimiter = 3 (* decimal output is grouped in threes *) + + let to_string_hum ?(delimiter='_') t = + insert_delimiter_every (to_string t) ~delimiter ~chars_per_delimiter + + let sexp_of_t t = + let s = to_string t in + Sexp.Atom + (match !sexp_of_int_style with (* the style is read from the global ref on every call *) + | `Underscores -> insert_delimiter_every s ~chars_per_delimiter ~delimiter:'_' + | `No_underscores -> s) + ;; +end + +module Make_hex (I : sig + type t [@@deriving_inline compare, hash] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + val to_string : t -> string + val of_string : string -> t + val zero : t + val (<) : t -> t -> bool + val neg : t -> t + val module_name : string + end) = +struct + + module T_hex = struct + + type t = I.t [@@deriving_inline compare, hash] + let compare : t -> t -> int = I.compare + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + I.hash_fold_t + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = I.hash in fun x -> func x + [@@@end] + + let chars_per_delimiter = 4 (* hex output is grouped in fours *) + + let to_string' ?delimiter t = + let make_suffix = + match delimiter with + | None -> I.to_string + | Some delimiter -> + (fun t -> + insert_delimiter_every (I.to_string t) ~delimiter ~chars_per_delimiter) + in + if I.(<) t I.zero + then "-0x" ^ make_suffix (I.neg t) (* negative values print an explicit sign followed by the negated value *) + else "0x" ^ make_suffix t + + let to_string t = to_string' t ?delimiter:None + + let to_string_hum ?(delimiter='_') t = to_string' t ~delimiter + + let invalid str = + failwith (Printf.sprintf "%s.of_string:
invalid input %S" I.module_name str) + + let of_string_with_delimiter str = + I.of_string (String.filter str ~f:(fun c -> Char.( <> ) c '_')) (* drop every '_' before handing the digits to [I.of_string] *) + + let of_string str = + let module L = Hex_lexer in + let lex = Caml.Lexing.from_string str in + let result = Option.try_with (fun () -> L.parse_hex lex) in + if lex.lex_curr_pos = lex.lex_buffer_len then ( (* accept only if the lexer consumed the whole input *) + match result with + | None -> invalid str + | Some (Neg body) -> I.neg (of_string_with_delimiter body) + | Some (Pos body) -> of_string_with_delimiter body + ) else + invalid str + end + + module Hex = struct + include T_hex + include Sexpable.Of_stringable(T_hex) + end + +end diff --git a/src/int_conversions.mli b/src/int_conversions.mli new file mode 100644 index 0000000..23255ae --- /dev/null +++ b/src/int_conversions.mli @@ -0,0 +1,134 @@ +(** Conversions between various integer types *) + +open! Import + +(** OCaml has the following integer types, with the following bit widths + on 32-bit and 64-bit architectures. + + {v + arch arch + type 32b 64b + ---------------------- + int 31 63 (32 when compiled to JavaScript) + nativeint 32 64 + int32 32 32 + int64 64 64 + v} + + In both cases, the following inequalities hold: + + {[ + width(int) < width(nativeint) + && width(int32) <= width(nativeint) <= width(int64) + ]} + + The conversion functions come in one of two flavors.
+ + If width(foo) <= width(bar) on both 32-bit and 64-bit architectures, then we have + + {[ val foo_to_bar : foo -> bar ]} + + otherwise the conversion can fail and we have + + {[ + val foo_to_bar : foo -> bar option + val foo_to_bar_exn : foo -> bar + ]} *) +val int_to_int32 : int -> int32 option +val int_to_int32_exn : int -> int32 +val int_to_int32_trunc : int -> int32 +val int_to_int64 : int -> int64 +val int_to_nativeint : int -> nativeint + +val int32_to_int : int32 -> int option +val int32_to_int_exn : int32 -> int +val int32_to_int_trunc : int32 -> int +val int32_to_int64 : int32 -> int64 +val int32_to_nativeint : int32 -> nativeint + +val int64_to_int : int64 -> int option +val int64_to_int_exn : int64 -> int +val int64_to_int_trunc : int64 -> int +val int64_to_int32 : int64 -> int32 option +val int64_to_int32_exn : int64 -> int32 +val int64_to_int32_trunc : int64 -> int32 +val int64_to_nativeint : int64 -> nativeint option +val int64_to_nativeint_exn : int64 -> nativeint +val int64_to_nativeint_trunc : int64 -> nativeint + +val int64_fit_on_int63_exn : int64 -> unit (* returns unit when the value fits in 63 bits, raises otherwise *) +val int64_is_representable_as_int63 : int64 -> bool + +val nativeint_to_int : nativeint -> int option +val nativeint_to_int_exn : nativeint -> int +val nativeint_to_int_trunc : nativeint -> int +val nativeint_to_int32 : nativeint -> int32 option +val nativeint_to_int32_exn : nativeint -> int32 +val nativeint_to_int32_trunc : nativeint -> int32 +val nativeint_to_int64 : nativeint -> int64 + +val num_bits_int : int +val num_bits_int32 : int +val num_bits_int64 : int +val num_bits_nativeint : int + +(** human-friendly string (and possibly sexp) conversions *) +module Make (I : sig + + type t + + val to_string : t -> string + + end) : sig + + val to_string_hum + : ?delimiter:char (** defaults to ['_'] *) + -> I.t + -> string + + val sexp_of_t : I.t -> Sexp.t + +end + +module Make_hex (I : sig + type t [@@deriving_inline compare, hash] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val
hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + + (** [to_string] and [of_string] convert between [t] and unsigned, + unprefixed hexadecimal. + They must be able to handle all non-negative values and also + [min_value]. [to_string min_value] must write a positive hex + representation. *) + val to_string : t -> string + val of_string : string -> t + val zero : t + val (<) : t -> t -> bool + val neg : t -> t + val module_name : string + end) + : Int_intf.Hexable with type t := I.t +(** In the output, [to_string], [of_string], [sexp_of_t], and [t_of_sexp] convert + between [t] and signed hexadecimal with an optional "0x" or "0X" prefix. *) + +(** Global ref affecting whether the [sexp_of_t] returned by [Make] + is consistent with the [to_string] input or the [to_string_hum] output. *) +val sexp_of_int_style : [ `No_underscores | `Underscores ] ref + +(** Utility for defining to_string_hum on numeric types -- takes a string matching + (-|+)?[0-9a-fA-F]+ and puts [delimiter] every [chars_per_delimiter] characters + starting from the right. A leading sign stays in front and is not counted. *) +val insert_delimiter_every : string -> delimiter:char -> chars_per_delimiter:int -> string + +(** [insert_delimiter_every ~chars_per_delimiter:3] *) +val insert_delimiter : string -> delimiter:char -> string + +(** [insert_delimiter ~delimiter:'_'] *) +val insert_underscores : string -> string diff --git a/src/int_intf.ml b/src/int_intf.ml new file mode 100644 index 0000000..2c75eaa --- /dev/null +++ b/src/int_intf.ml @@ -0,0 +1,363 @@ +(** An interface to use for int-like types, e.g., {{!Base.Int}[Int]} and + {{!Base.Int64}[Int64]}. *) + +open! Import + +module type Round = sig + type t + + (** [round] rounds an int to a multiple of a given [to_multiple_of] argument, according + to a direction [dir], with default [dir] being [`Nearest]. [round] will raise if + [to_multiple_of <= 0].
+ + {v + | `Down | rounds toward Int.neg_infinity | + | `Up | rounds toward Int.infinity | + | `Nearest | rounds to the nearest multiple, or `Up in case of a tie | + | `Zero | rounds toward zero | + v} + + Here are some examples for [round ~to_multiple_of:10] for each direction: + + {v + | `Down | {10 .. 19} --> 10 | { 0 ... 9} --> 0 | {-10 ... -1} --> -10 | + | `Up | { 1 .. 10} --> 10 | {-9 ... 0} --> 0 | {-19 .. -10} --> -10 | + | `Zero | {10 .. 19} --> 10 | {-9 ... 9} --> 0 | {-19 .. -10} --> -10 | + | `Nearest | { 5 .. 14} --> 10 | {-5 ... 4} --> 0 | {-15 ... -6} --> -10 | + v} + + For convenience and performance, there are variants of [round] with [dir] + hard-coded (the [round_*] functions below). If you are writing performance-critical code you should use these. *) + + val round : ?dir:[ `Zero | `Nearest | `Up | `Down ] -> t -> to_multiple_of:t -> t + + val round_towards_zero : t -> to_multiple_of:t -> t + val round_down : t -> to_multiple_of:t -> t + val round_up : t -> to_multiple_of:t -> t + val round_nearest : t -> to_multiple_of:t -> t +end + +module type Hexable = sig (* the signature that [Int_conversions.Make_hex] is declared to return *) + type t + module Hex : sig + type nonrec t = t [@@deriving_inline sexp, compare, hash] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + + include Stringable.S with type t := t + + val to_string_hum : ?delimiter:char -> t -> string + end +end + +module type S_common = sig + type t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + +
include Identifiable.S with type t := t + include Comparable.With_zero with type t := t + include Hexable with type t := t + + (** [delimiter] is an underscore by default. *) + val to_string_hum : ?delimiter:char -> t -> string + + (** {2 Infix operators and constants} *) + + val zero : t + val one : t + val minus_one : t + + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + + (** Integer exponentiation *) + val ( ** ) : t -> t -> t + + (** Negation *) + + val neg : t -> t + val ( ~- ) : t -> t + + (** There are two pairs of integer division and remainder functions, [/%] and [%], and + [/] and [rem]. They both satisfy the same equation relating the quotient and the + remainder: + + {[ + x = (x /% y) * y + (x % y); + x = (x / y) * y + (rem x y); + ]} + + The functions return the same values if [x] and [y] are positive. They all raise + if [y = 0]. + + The functions differ if [x < 0] or [y < 0]. + + If [y < 0], then [%] and [/%] raise, whereas [/] and [rem] do not. + + [x % y] always returns a value between 0 and [y - 1], even when [x < 0]. On the + other hand, [rem x y] returns a negative value if and only if [x < 0]; that value + satisfies [abs (rem x y) <= abs y - 1]. *) + + val ( /% ) : t -> t -> t + val ( % ) : t -> t -> t + val ( / ) : t -> t -> t + val rem : t -> t -> t + + (** Float division of integers. *) + val ( // ) : t -> t -> float + + (** Same as [bit_and]. *) + val ( land ) : t -> t -> t + + (** Same as [bit_or]. *) + val ( lor ) : t -> t -> t + + (** Same as [bit_xor]. *) + val ( lxor ) : t -> t -> t + + (** Same as [bit_not]. *) + val lnot : t -> t + + (** Same as [shift_left]. *) + val ( lsl ) : t -> int -> t + + (** Same as [shift_right]. *) + val ( asr ) : t -> int -> t + + (** {2 Other common functions} *) + + include Round with type t := t + + (** Returns the absolute value of the argument. May be negative if the input is + [min_value]. 
*) + val abs : t -> t + + (** {2 Successor and predecessor functions} *) + + val succ : t -> t + val pred : t -> t + + (** {2 Exponentiation} *) + + (** [pow base exponent] returns [base] raised to the power of [exponent]. It is OK if + [base <= 0]. [pow] raises if [exponent < 0], or if an integer overflow would occur. *) + val pow : t -> t -> t + + (** {2 Bit-wise logical operations } *) + + (** These are identical to [land], [lor], etc. except they're not infix and have + different names. *) + val bit_and : t -> t -> t + val bit_or : t -> t -> t + val bit_xor : t -> t -> t + val bit_not : t -> t + + (** Returns the number of 1 bits in the binary representation of the input. *) + val popcount : t -> int + + (** {2 Bit-shifting operations } + + The results are unspecified for negative shifts and shifts [>= num_bits]. *) + + (** Shifts left, filling in with zeroes. *) + val shift_left : t -> int -> t + + (** Shifts right, preserving the sign of the input. *) + val shift_right : t -> int -> t + + (** {2 Increment and decrement functions for integer references } *) + + val decr : t ref -> unit + val incr : t ref -> unit + + (** {2 Conversion functions to related integer types} *) + + val of_int32_exn : int32 -> t + val to_int32_exn : t -> int32 + val of_int64_exn : int64 -> t + val to_int64 : t -> int64 + val of_nativeint_exn : nativeint -> t + val to_nativeint_exn : t -> nativeint + + (** [of_float_unchecked] truncates the given floating point number to an integer, + rounding towards zero. + The result is unspecified if the argument is NaN or falls outside the range + of representable integers.
*) + val of_float_unchecked : float -> t +end + +module type Operators_unbounded = sig + type t + + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + val ( / ) : t -> t -> t + val ( ~- ) : t -> t + val ( ** ) : t -> t -> t + include Comparisons.Infix with type t := t + + val abs : t -> t + val neg : t -> t + val zero : t + + val ( % ) : t -> t -> t + val ( /% ) : t -> t -> t + val ( // ) : t -> t -> float + + val ( land ) : t -> t -> t + val ( lor ) : t -> t -> t + val ( lxor ) : t -> t -> t + val lnot : t -> t + + val ( lsl ) : t -> int -> t + val ( asr ) : t -> int -> t +end + +module type Operators = sig + include Operators_unbounded + val ( lsr ) : t -> int -> t +end + +(** [S_unbounded] is a generic interface for unbounded integers, e.g. [Bignum.Bigint]. + [S_unbounded] is a restriction of [S] (below) that omits values that depend on + fixed-size integers. *) +module type S_unbounded = sig + include S_common (** @inline *) + + (** A sub-module designed to be opened to make working with ints more convenient. *) + module O : Operators_unbounded with type t := t +end + +(** [S] is a generic interface for fixed-size integers. *) +module type S = sig + include S_common (** @inline *) + + (** The number of bits available in this integer type. Note that the integer + representations are signed. *) + val num_bits : int + + (** The largest representable integer. *) + val max_value : t + + (** The smallest representable integer. *) + val min_value : t + + (** Same as [shift_right_logical]. *) + val ( lsr ) : t -> int -> t + + (** Shifts right, filling in with zeroes, which will not preserve the sign of the + input. *) + val shift_right_logical : t -> int -> t + + (** [ceil_pow2 x] returns the smallest power of 2 that is greater than or equal to [x]. + The implementation may only be called for [x > 0]. 
Example: [ceil_pow2 17 = 32] *) + val ceil_pow2 : t -> t + + (** [floor_pow2 x] returns the largest power of 2 that is less than or equal to [x]. The + implementation may only be called for [x > 0]. Example: [floor_pow2 17 = 16] *) + val floor_pow2 : t -> t + + (** [ceil_log2 x] returns the ceiling of log-base-2 of [x], and raises if [x <= 0]. *) + val ceil_log2 : t -> int + + (** [floor_log2 x] returns the floor of log-base-2 of [x], and raises if [x <= 0]. *) + val floor_log2 : t -> int + + (** [is_pow2 x] returns true iff [x] is a power of 2. [is_pow2] raises if [x <= 0]. *) + val is_pow2 : t -> bool + + (** A sub-module designed to be opened to make working with ints more convenient. *) + module O : Operators with type t := t +end + +include + (struct + (** Various functors whose type-correctness ensures desired relationships between + interfaces. *) + + module Check_O_contained_in_S (M : S) = (M : module type of M.O) + module Check_O_contained_in_S_unbounded (M : S_unbounded) = (M : module type of M.O) + module Check_S_unbounded_in_S (M : S) = (M : S_unbounded) + end : sig end) + +module type Int_without_module_types = sig + include S with type t = int + + (** [max_value_30_bits = 2^30 - 1]. It is useful for writing tests that work on both + 64-bit and 32-bit platforms. *) + val max_value_30_bits : t + + (** {2 Conversion functions} *) + + val of_int : int -> t + val to_int : t -> int + val of_int32 : int32 -> t option + val to_int32 : t -> int32 option + val of_int64 : int64 -> t option + val of_nativeint : nativeint -> t option + val to_nativeint : t -> nativeint + + (** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases + where optional conversions return [Some x], truncating conversions return [x]. 
*) + + val of_int32_trunc : int32 -> t + val to_int32_trunc : t -> int32 + val of_int64_trunc : int64 -> t + val of_nativeint_trunc : nativeint -> t + + (**/**) + (*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) + module Private : sig + (*_ For ../bench/bench_int.ml *) + module O_F : sig + val ( % ) : int -> int -> int + val ( /% ) : int -> int -> int + val ( // ) : int -> int -> float + end + end +end + +(** OCaml's native integer type. + + The number of bits in an integer is platform dependent, being 31-bits on a 32-bit + platform, and 63-bits on a 64-bit platform. [int] is a signed integer type. [int]s + are also subject to overflow, meaning that [Int.max_value + 1 = Int.min_value]. + + [int]s always fit in a machine word. *) +module type Int = sig + include Int_without_module_types + + (** {2 Module types specifying integer operations.} *) + module type Hexable = Hexable + module type Int_without_module_types = Int_without_module_types + module type Operators = Operators + module type Operators_unbounded = Operators_unbounded + module type Round = Round + module type S = S + module type S_common = S_common + module type S_unbounded = S_unbounded +end diff --git a/src/int_math.ml b/src/int_math.ml new file mode 100644 index 0000000..63db7b8 --- /dev/null +++ b/src/int_math.ml @@ -0,0 +1,144 @@ +open! Import + +let invalid_argf = Printf.invalid_argf + +let negative_exponent () = + Printf.invalid_argf "exponent can not be negative" () + +let overflow () = + Printf.invalid_argf "integer overflow in pow" () + +(* To implement [int64_pow], we use C code rather than OCaml to eliminate allocation. 
*)
+external int_math_int_pow : int -> int -> int = "Base_int_math_int_pow_stub" [@@noalloc]
+external int_math_int64_pow : int64 -> int64 -> int64 = "Base_int_math_int64_pow_stub"
+
+let int_pow base exponent = (* checked [int] power: validates the arguments, then defers to the C stub *)
+  if exponent < 0 then negative_exponent ();
+
+  if abs(base) > 1 && (* bases -1, 0 and 1 skip the overflow test entirely *)
+     (exponent > 63 ||
+      abs(base) > Pow_overflow_bounds.int_positive_overflow_bounds.(exponent))
+  then overflow ();
+
+  int_math_int_pow base exponent
+;;
+
+module Int64_with_comparisons = struct (* [Caml.Int64] plus comparison primitives typed at [int64] *)
+  include Caml.Int64
+  external ( < ) : int64 -> int64 -> bool = "%lessthan"
+  external ( > ) : int64 -> int64 -> bool = "%greaterthan"
+  external ( >= ) : int64 -> int64 -> bool = "%greaterequal"
+end
+
+(* Checked [int64] power.  We don't do [abs] in the int64 case to avoid allocation, so positive and negative bases are tested against separate bound tables. *)
+let int64_pow base exponent =
+  let open Int64_with_comparisons in
+  if exponent < 0L then negative_exponent ();
+
+  if (base > 1L || base < (-1L)) &&
+     (exponent > 63L ||
+      (base >= 0L &&
+       base > Pow_overflow_bounds.int64_positive_overflow_bounds.(to_int exponent))
+      ||
+      (base < 0L &&
+       base < Pow_overflow_bounds.int64_negative_overflow_bounds.(to_int exponent)))
+  then overflow ();
+
+  int_math_int64_pow base exponent
+;;
+
+
+let int63_pow_on_int64 base exponent = (* same shape as [int_pow], on [int64] values checked against the int63 bound table *)
+  let open Int64_with_comparisons in
+  if exponent < 0L then negative_exponent ();
+
+  if abs(base) > 1L &&
+     (exponent > 63L ||
+      abs(base) > Pow_overflow_bounds.int63_on_int64_positive_overflow_bounds.(to_int exponent))
+  then overflow ();
+
+  int_math_int64_pow base exponent
+;;
+
+module type T = sig (* the primitive operations [Make] needs from an integer module *)
+  type t
+  include Floatable.S with type t := t
+  include Stringable.S with type t := t
+
+  val ( + ) : t -> t -> t
+  val ( - ) : t -> t -> t
+  val ( * ) : t -> t -> t
+  val ( / ) : t -> t -> t
+  val ( ~- ) : t -> t
+  include Comparisons.Infix with type t := t
+
+  val abs : t -> t
+  val neg : t -> t
+  val zero : t
+  val of_int_exn : int -> t
+  val rem : t -> t -> t
+end
+
+module Make (X : T) = struct (* derives [%], [/%], [//] and the rounding functions from [X] *)
+  open X
+
+  let ( % ) x y = (* remainder shifted into the range 0 .. y - 1; rejects [y <= zero] through [invalid_argf] *)
+    if y <= zero then
+      invalid_argf
+        "%s %% %s in core_int.ml: modulus should be positive"
+        (to_string x) (to_string y) ();
+    let rval = X.rem x y in
+    if rval < zero
+    then rval + y
+    else rval
+  ;;
+
+  let one = of_int_exn 1
+  ;;
+
+  let ( /% ) x y = (* division rounding toward negative infinity, assuming [X.( / )] truncates toward zero; rejects [y <= zero] *)
+    if y <= zero then
+      invalid_argf
+        "%s /%% %s in core_int.ml: divisor should be positive"
+        (to_string x) (to_string y) ();
+    if x < zero
+    then (x + one) / y - one
+    else x / y
+  ;;
+
+  (** float division of integers *)
+  let (//) x y = to_float x /. to_float y
+  ;;
+
+  let round_down i ~to_multiple_of:modulus = i - (i % modulus) (* largest multiple of [modulus] that is [<= i] *)
+  ;;
+
+  let round_up i ~to_multiple_of:modulus = (* smallest multiple of [modulus] that is [>= i] *)
+    let remainder = i % modulus in
+    if remainder = zero
+    then i
+    else i + modulus - remainder
+  ;;
+
+  let round_towards_zero i ~to_multiple_of = (* rounds down for positive [i], up for negative [i] *)
+    if i = zero then zero else
+    if i > zero
+    then round_down i ~to_multiple_of
+    else round_up i ~to_multiple_of
+  ;;
+
+  let round_nearest i ~to_multiple_of:modulus = (* nearest multiple of [modulus]; an exact tie rounds up *)
+    let remainder = i % modulus in
+    if remainder < modulus - remainder (* same test as [2 * remainder < modulus], but cannot overflow since [0 <= remainder < modulus] *)
+    then i - remainder
+    else i - remainder + modulus
+  ;;
+
+  let round ?(dir=`Nearest) i ~to_multiple_of = (* dispatches on [dir]; [`Nearest] is the default *)
+    match dir with
+    | `Nearest -> round_nearest i ~to_multiple_of
+    | `Down -> round_down i ~to_multiple_of
+    | `Up -> round_up i ~to_multiple_of
+    | `Zero -> round_towards_zero i ~to_multiple_of
+  ;;
+end
diff --git a/src/int_math.mli b/src/int_math.mli
new file mode 100644
index 0000000..bfd0f82
--- /dev/null
+++ b/src/int_math.mli
@@ -0,0 +1,39 @@
+(** This module is not exposed in Core.  Instead, these functions are accessed and
+    commented in the various Core modules implementing [Int_intf.S]. *)
+
+open! Import
+
+(*_ This interface is not defined in int_intf.ml because we don't want users of Core to
+    think about it.
*)
+module type T = sig
+  type t
+  include Floatable.S with type t := t
+  include Stringable.S with type t := t
+
+  val ( + ) : t -> t -> t
+  val ( - ) : t -> t -> t
+  val ( * ) : t -> t -> t
+  val ( / ) : t -> t -> t
+  val ( ~- ) : t -> t
+  include Comparisons.Infix with type t := t
+
+  val abs : t -> t
+  val neg : t -> t
+  val zero : t
+  val of_int_exn : int -> t
+
+  val rem : t -> t -> t
+end
+
+(** derived operations common to various integer modules *)
+module Make (X : T) : sig
+  val ( % ) : X.t -> X.t -> X.t
+  val ( /% ) : X.t -> X.t -> X.t
+  val ( // ) : X.t -> X.t -> float
+  include Int_intf.Round with type t := X.t
+end
+
+val int_pow : int -> int -> int (*_ rejects a negative exponent or an overflowing result via [invalid_argf] *)
+val int64_pow : int64 -> int64 -> int64 (*_ same checks as [int_pow], on [int64] *)
+
+val int63_pow_on_int64 : int64 -> int64 -> int64 (*_ [int64] arguments checked against the int63 overflow bounds *)
diff --git a/src/int_math_stubs.c b/src/int_math_stubs.c
new file mode 100644
index 0000000..733a327
--- /dev/null
+++ b/src/int_math_stubs.c
@@ -0,0 +1,92 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef _MSC_VER
+
+#include
+
+#define __builtin_popcountll __popcnt64
+#define __builtin_popcount __popcnt
+
+static uint32_t __inline __builtin_clz(uint32_t x) /* MSVC stand-in for the GCC builtin; like it, undefined for x == 0 */
+{
+  unsigned long r = 0; /* the intrinsic writes the bit index through an unsigned long pointer */
+  _BitScanReverse(&r, x); /* r = index of the highest set bit, counting from bit 0 */
+  return 31 - (uint32_t)r; /* leading zeros = 31 - index of the top set bit */
+}
+
+static uint64_t __inline __builtin_clzll(uint64_t x) /* 64-bit counterpart of the shim above */
+{
+  unsigned long r = 0; /* index type is unsigned long even for the 64-bit intrinsic */
+  _BitScanReverse64(&r, x); /* r = index of the highest set bit, counting from bit 0 */
+  return 63 - (uint64_t)r; /* leading zeros = 63 - index of the top set bit */
+}
+
+#endif
+
+static int64_t int_pow(int64_t base, int64_t exponent) { /* base^exponent, two exponent bits per iteration; no sign or overflow checks here */
+  int64_t ret = 1;
+  int64_t mul[4]; /* mul[d] holds (current power of base)^d for the 2-bit digit d */
+  mul[0] = 1;
+  mul[1] = base;
+  mul[3] = 1; /* makes the first [mul[1] *= mul[3]] a no-op */
+
+  while(exponent != 0) {
+    mul[1] *= mul[3]; /* previous power times its cube = previous power to the 4th */
+    mul[2] = mul[1] * mul[1];
+    mul[3] = mul[2] * mul[1];
+    ret *= mul[exponent & 3]; /* multiply in the factor for the low two exponent bits */
+    exponent >>= 2;
+  }
+
+  return ret;
+}
+
+CAMLprim value Base_int_math_int_pow_stub(value base, value exponent) { /* tagged-int entry point; returns an immediate, no boxing */
+  return (Val_long(int_pow(Long_val(base), Long_val(exponent))));
+}
+
+CAMLprim value Base_int_math_int64_pow_stub(value base, value exponent) { /* boxed int64 entry point */
+  CAMLparam2(base, exponent); /* register the arguments as roots because caml_copy_int64 allocates */
+  CAMLreturn(caml_copy_int64(int_pow(Int64_val(base), Int64_val(exponent))));
+}
+
+/* This
implementation is faster than [__builtin_popcount(v) - 1], even though
+ * it seems more complicated.  The [&] clears the shifted sign bit after
+ * [Long_val] or [Int_val]. */
+CAMLprim value Base_int_math_int_popcount(value v) { /* number of set bits in an OCaml int, returned as an OCaml int */
+#ifdef ARCH_SIXTYFOUR
+  return Val_int (__builtin_popcountll (Long_val (v) & ~((uint64_t)1 << 63))); /* mask off bit 63 before counting */
+#else
+  return Val_int (__builtin_popcount (Int_val (v) & ~((uint32_t)1 << 31))); /* mask off bit 31 before counting */
+#endif
+}
+
+/* The specification of all below [clz] functions is undefined for [v = 0]. */
+CAMLprim value Base_int_math_int_clz(value v) { /* leading zeros of the untagged OCaml int, counted at full C word width */
+#ifdef ARCH_SIXTYFOUR
+  return Val_int (__builtin_clzll (Long_val(v)));
+#else
+  return Val_int (__builtin_clz (Int_val (v)));
+#endif
+}
+
+CAMLprim value Base_int_math_int32_clz(value v) { /* leading zeros of the int32 read with Int32_val */
+  return Val_int (__builtin_clz (Int32_val(v)));
+}
+
+CAMLprim value Base_int_math_int64_clz(value v) { /* leading zeros of the int64 read with Int64_val */
+  return Val_int (__builtin_clzll (Int64_val(v)));
+}
+
+CAMLprim value Base_int_math_nativeint_clz(value v) { /* leading zeros of the nativeint; 64- or 32-bit builtin chosen by ARCH_SIXTYFOUR */
+#ifdef ARCH_SIXTYFOUR
+  return Val_int (__builtin_clzll (Nativeint_val(v)));
+#else
+  return Val_int (__builtin_clz (Nativeint_val(v)));
+#endif
+}
diff --git a/src/intable.ml b/src/intable.ml
new file mode 100644
index 0000000..2300503
--- /dev/null
+++ b/src/intable.ml
@@ -0,0 +1,11 @@
+(** Module type for types that convert to and from [int]; this file defines only the signature. *)
+
+open!
Import
+
+module type S = sig
+  type t
+
+  val of_int_exn : int -> t (* by its name, raises when the [int] is not representable in [t] - confirm against implementers *)
+  val to_int_exn : t -> int (* by its name, raises when the value does not fit in [int] - confirm against implementers *)
+end
+
diff --git a/src/internalhash.h b/src/internalhash.h
new file mode 100644
index 0000000..b752b91
--- /dev/null
+++ b/src/internalhash.h
@@ -0,0 +1,3 @@
+#include
+#include
+CAMLexport uint32_t Base_internalhash_fold_blob(uint32_t h, mlsize_t len, uint8_t *s);
diff --git a/src/internalhash_stubs.c b/src/internalhash_stubs.c
new file mode 100644
index 0000000..f3d0b39
--- /dev/null
+++ b/src/internalhash_stubs.c
@@ -0,0 +1,101 @@
+#include
+#include
+#include
+#include "internalhash.h"
+
+/* This pretends that the state of the OCaml internal hash function, which is an
+   int32, is actually stored in an OCaml int. */
+
+CAMLprim value Base_internalhash_fold_int32(value st, value i) /* mixes a boxed int32 into the state */
+{
+  return Val_long(caml_hash_mix_uint32(Long_val(st), Int32_val(i)));
+}
+
+CAMLprim value Base_internalhash_fold_nativeint(value st, value i) /* mixes a boxed nativeint into the state */
+{
+  return Val_long(caml_hash_mix_intnat(Long_val(st), Nativeint_val(i)));
+}
+
+CAMLprim value Base_internalhash_fold_int64(value st, value i) /* mixes a boxed int64 into the state */
+{
+  return Val_long(caml_hash_mix_int64(Long_val(st), Int64_val(i)));
+}
+
+CAMLprim value Base_internalhash_fold_int(value st, value i) /* mixes an immediate OCaml int into the state */
+{
+  return Val_long(caml_hash_mix_intnat(Long_val(st), Long_val(i)));
+}
+
+CAMLprim value Base_internalhash_fold_float(value st, value i) /* mixes a boxed float into the state */
+{
+  return Val_long(caml_hash_mix_double(Long_val(st), Double_val(i)));
+}
+
+/* This code mimics what hashtbl.hash does in OCaml's hash.c */
+#define FINAL_MIX(h) \
+  h ^= h >> 16; \
+  h *= 0x85ebca6b; \
+  h ^= h >> 13; \
+  h *= 0xc2b2ae35; \
+  h ^= h >> 16;
+
+CAMLprim value Base_internalhash_get_hash_value(value st) /* finalises the state into the hash handed back to OCaml */
+{
+  uint32_t h = Int_val(st);
+  FINAL_MIX(h);
+  return Val_int(h & 0x3FFFFFFFU); /*30 bits*/
+}
+
+/* Macros copied from hash.c in ocaml distribution */
+#define ROTL32(x,n) ((x) << n | (x) >> (32-n))
+
+#define MIX(h,d) \
+  d *= 0xcc9e2d51; \
+  d = ROTL32(d, 15); \
+  d *= 0x1b873593; \
+  h ^= d; \
+  h = ROTL32(h, 13); \
+  h = h * 5 + 0xe6546b64;
+
+/* Version of [caml_hash_mix_string] from hash.c - adapted for arbitrary char arrays */
+CAMLexport uint32_t Base_internalhash_fold_blob(uint32_t h, mlsize_t len, uint8_t *s)
+{
+  mlsize_t i;
+  uint32_t w;
+
+  /* Mix by 32-bit blocks (little-endian) */
+  for (i = 0; i + 4 <= len; i += 4) {
+#ifdef ARCH_BIG_ENDIAN
+    w = s[i]
+      | (s[i+1] << 8)
+      | (s[i+2] << 16)
+      | ((uint32_t) s[i+3] << 24); /* widen first: a byte >= 0x80 shifted as int would overflow into the sign bit */
+#else
+    w = *((uint32_t *) &(s[i])); /* NOTE(review): direct 32-bit load assumes unaligned reads of s are tolerated - confirm for blob callers */
+#endif
+    MIX(h, w);
+  }
+  /* Finish with up to 3 bytes */
+  w = 0;
+  switch (len & 3) {
+  case 3: w  = s[i+2] << 16;   /* fallthrough */
+  case 2: w |= s[i+1] << 8;    /* fallthrough */
+  case 1: w |= s[i];
+          MIX(h, w);
+  default: /*skip*/;     /* len & 3 == 0, no extra bytes, do nothing */
+  }
+  /* Finally, mix in the length.  Ignore the upper 32 bits, generally 0. */
+  h ^= (uint32_t) len;
+  return h;
+}
+
+CAMLprim value Base_internalhash_fold_string(value st, value v_str) /* OCaml-facing wrapper: folds the bytes of a string into the state */
+{
+  uint32_t h = Long_val(st);
+  mlsize_t len = caml_string_length(v_str);
+  uint8_t *s = (uint8_t *) String_val(v_str);
+
+  h = Base_internalhash_fold_blob(h, len, s);
+
+  return Val_long(h);
+}
diff --git a/src/invariant.ml b/src/invariant.ml
new file mode 100644
index 0000000..026a9e7
--- /dev/null
+++ b/src/invariant.ml
@@ -0,0 +1,26 @@
+open!
Import
+
+include Invariant_intf
+
+let raise_s = Error.raise_s
+
+let invariant here t sexp_of_t f : unit = (* runs [f ()]; any exception is re-raised wrapped with the position and [sexp_of_t t] *)
+  try
+    f ()
+  with exn ->
+    raise_s
+      (Sexp.message "invariant failed"
+         [ "" , Source_code_position0.sexp_of_t here
+         ; "exn", sexp_of_exn exn
+         ; "" , sexp_of_t t ])
+;;
+
+let check_field t f field = (* applies [f] to the field's value in [t]; a failure is re-raised tagged with the field name *)
+  try
+    f (Field.get field t)
+  with exn ->
+    raise_s (Sexp.message "problem with field"
+               [ "field", sexp_of_string (Field.name field)
+               ; "exn" , sexp_of_exn exn
+               ])
+;;
diff --git a/src/invariant.mli b/src/invariant.mli
new file mode 100644
index 0000000..23ca91b
--- /dev/null
+++ b/src/invariant.mli
@@ -0,0 +1 @@
+include Invariant_intf.Invariant (** @inline *)
diff --git a/src/invariant_intf.ml b/src/invariant_intf.ml
new file mode 100644
index 0000000..4e657a4
--- /dev/null
+++ b/src/invariant_intf.ml
@@ -0,0 +1,98 @@
+open! Import
+
+type 'a t = 'a -> unit
+
+type 'a inv = 'a t
+
+module type S = sig
+  type t
+  val invariant : t inv
+end
+
+module type S1 = sig
+  type 'a t
+  val invariant : 'a inv -> 'a t inv
+end
+
+module type S2 = sig
+  type ('a, 'b) t
+  val invariant : 'a inv -> 'b inv -> ('a, 'b) t inv
+end
+
+module type S3 = sig
+  type ('a, 'b, 'c) t
+  val invariant : 'a inv -> 'b inv -> 'c inv -> ('a, 'b, 'c) t inv
+end
+
+module type Invariant = sig
+
+  (** This module defines signatures that are to be included in other signatures to ensure
+      a consistent interface to invariant-style functions.  There is a signature ([S],
+      [S1], [S2], [S3]) for each arity of type.  Usage looks like:
+
+      {[
+        type t
+        include Invariant.S with type t := t
+      ]}
+
+      or
+
+      {[
+        type 'a t
+        include Invariant.S1 with type 'a t := 'a t
+      ]}
+  *)
+
+  type nonrec 'a t = 'a t
+
+  module type S = S
+  module type S1 = S1
+  module type S2 = S2
+  module type S3 = S3
+
+  (** [invariant here t sexp_of_t f] runs [f ()], and if [f] raises, wraps the exception
+      in an [Error.t] that states "invariant failed" and includes both the exception
+      raised by [f], as well as [sexp_of_t t].  Idiomatic usage looks like:
+
+      {[
+        invariant [%here] t [%sexp_of: t] (fun () ->
+          ... check t's invariants ... )
+      ]}
+
+      For polymorphic types:
+
+      {[
+        let invariant check_a t =
+          Invariant.invariant [%here] t [%sexp_of: _ t] (fun () -> ... )
+      ]}
+
+      It's okay to use [ [%sexp_of: _ t] ] because the exceptions raised by [check_a] will
+      show the parts that are sexp_opaque at top-level. *)
+  val invariant
+    : Source_code_position0.t
+    -> 'a
+    -> ('a -> Sexp.t)
+    -> (unit -> unit)
+    -> unit
+
+  (** [check_field] is used when checking invariants using [Fields.iter].  It wraps an
+      exception raised when checking a field with the field's name.  Idiomatic usage looks
+      like:
+
+      {[
+        type t =
+          { foo : Foo.t;
+            bar : Bar.t;
+          }
+        [@@deriving_inline fields][@@@end]
+
+        let invariant t : unit =
+          Invariant.invariant [%here] t [%sexp_of: t] (fun () ->
+            let check f = Invariant.check_field t f in
+            Fields.iter
+              ~foo:(check Foo.invariant)
+              ~bar:(check Bar.invariant))
+        ;;
+      ]} *)
+  val check_field : 'a -> 'b t -> ('a, 'b) Field.t -> unit
+end
diff --git a/src/lazy.ml b/src/lazy.ml
new file mode 100644
index 0000000..6a32e4a
--- /dev/null
+++ b/src/lazy.ml
@@ -0,0 +1,44 @@
+open! Import
+
+type 'a t = 'a lazy_t [@@deriving_inline sexp]
+let t_of_sexp :
+  'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t =
+  lazy_t_of_sexp
+let sexp_of_t :
+  'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t =
+  sexp_of_lazy_t
+[@@@end]
+
+include (Caml.Lazy : module type of Caml.Lazy with type 'a t := 'a t)
+
+let map t ~f = lazy (f (force t)) (* builds a new suspension; [t] is only forced when the result is *)
+
+let compare compare_a t1 t2 = (* physically equal suspensions compare equal without being forced; otherwise both are forced *)
+  if phys_equal t1 t2
+  then 0
+  else compare_a (force t1) (force t2)
+;;
+
+let hash_fold_t = Hash.Builtin.hash_fold_lazy_t
+
+include Monad.Make (struct
+    type nonrec 'a t = 'a t
+
+    let return x = from_val x
+
+    let bind t ~f = lazy (force (f (force t))) (* neither [t] nor [f]'s result is forced until the outer suspension is *)
+
+    let map = map (* capture the [map] defined above before the name is shadowed on the next binding *)
+
+    let map = `Custom map
+  end)
+
+module T_unforcing = struct (* [sexp_of_t] variant that never forces its argument *)
+  type nonrec 'a t = 'a t
+
+  let sexp_of_t sexp_of_a t = (* forced values print normally; unforced ones print a fixed placeholder atom *)
+    if is_val t
+    then sexp_of_a (force t)
+    else sexp_of_string "<unforced lazy>"
+  ;;
+end
diff --git a/src/lazy.mli b/src/lazy.mli
new file mode 100644
index 0000000..1a0c599
--- /dev/null
+++ b/src/lazy.mli
@@ -0,0 +1,81 @@
+(*_ JS-only: This file is a modified version of lazy.mli from the OCaml distribution. *)
+
+(** A value of type ['a Lazy.t] is a deferred computation, called a suspension, that has a
+    result of type ['a].
+
+    The special expression syntax [lazy (expr)] makes a suspension of the computation of
+    [expr], without computing [expr] itself yet.  "Forcing" the suspension will then
+    compute [expr] and return its result.
+
+    Note: [lazy_t] is the built-in type constructor used by the compiler for the [lazy]
+    keyword.  You should not use it directly.  Always use [Lazy.t] instead.
+
+    Note: [Lazy.force] is not thread-safe.  If you use this module in a multi-threaded
+    program, you will need to add some locks.
+
+    Note: if the program is compiled with the [-rectypes] option, ill-founded recursive
+    definitions of the form [let rec x = lazy x] or [let rec x = lazy(lazy(...(lazy x)))]
+    are accepted by the type-checker and lead, when forced, to ill-formed values that
+    trigger infinite loops in the garbage collector and other parts of the run-time
+    system.
Without the [-rectypes] option, such ill-founded recursive definitions are
+    rejected by the type-checker. *)
+
+open! Import
+
+type 'a t = 'a lazy_t [@@deriving_inline compare, hash, sexp]
+include
+sig
+  [@@@ocaml.warning "-32"]
+  val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int
+  val hash_fold_t :
+    (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) ->
+      Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state
+  include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t
+end[@@ocaml.doc "@inline"]
+[@@@end]
+
+include Monad.S with type 'a t := 'a t
+
+exception Undefined
+
+(** [force x] forces the suspension [x] and returns its result.  If [x] has already been
+    forced, [Lazy.force x] returns the same value again without recomputing it.  If it
+    raised an exception, the same exception is raised again.  Raises [Undefined] if the
+    forcing of [x] tries to force [x] itself recursively. *)
+external force : 'a t -> 'a = "%lazy_force"
+
+(** Like [force] except that [force_val x] does not use an exception handler, so it may be
+    more efficient.  However, if the computation of [x] raises an exception, it is
+    unspecified whether [force_val x] raises the same exception or [Undefined]. *)
+val force_val : 'a t -> 'a
+
+(** [from_fun f] is the same as [lazy (f ())] but slightly more efficient if [f] is a
+    variable.  [from_fun] should only be used if the function [f] is already defined.  In
+    particular it is always less efficient to write [from_fun (fun () -> expr)] than [lazy
+    expr]. *)
+val from_fun : (unit -> 'a) -> 'a t
+
+(** [from_val v] returns an already-forced suspension of [v] (where [v] can be any
+    expression).  Essentially, [from_val expr] is the same as [let var = expr in lazy
+    var]. *)
+val from_val : 'a -> 'a t
+
+(** [is_val x] returns [true] if [x] has already been forced and did not raise an
+    exception. *)
+val is_val : 'a t -> bool
+
+(** This type offers a serialization function [sexp_of_t] that won't force its argument.
+    Instead, it will serialize the ['a] if it is available, or just use a custom string
+    indicating it is not forced.  Note that this is not a round-trippable type, thus the
+    type does not expose [of_sexp].  To be used in debug code, while tracking a Heisenbug,
+    etc. *)
+module T_unforcing : sig
+  type nonrec 'a t = 'a t [@@deriving_inline sexp_of]
+  include
+  sig
+    [@@@ocaml.warning "-32"]
+    val sexp_of_t :
+      ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t
+  end[@@ocaml.doc "@inline"]
+  [@@@end]
+end
diff --git a/src/linked_queue.ml b/src/linked_queue.ml
new file mode 100644
index 0000000..bc5d7af
--- /dev/null
+++ b/src/linked_queue.ml
@@ -0,0 +1,159 @@
+open! Import
+
+include Linked_queue0
+
+
+let enqueue t x = Linked_queue0.push x t (* queue-first argument order over [Linked_queue0.push] *)
+
+let dequeue t = if is_empty t then None else Some (Linked_queue0.pop t) (* [None] on an empty queue instead of calling [pop] *)
+let dequeue_exn = Linked_queue0.pop (* no emptiness check: behaves exactly as [Linked_queue0.pop] *)
+
+let peek t = if is_empty t then None else Some (Linked_queue0.peek t) (* [None] on an empty queue instead of calling [peek] *)
+let peek_exn = Linked_queue0.peek (* no emptiness check: behaves exactly as [Linked_queue0.peek] *)
+
+
+module C = (* container operations derived from this module's [fold], [iter] and [length] *)
+  Indexed_container.Make (struct
+    type nonrec 'a t = 'a t
+    let fold = fold
+    let iter = `Custom iter
+    let length = `Custom length
+    let foldi = `Define_using_fold
+    let iteri = `Define_using_fold
+  end)
+
+let count = C.count
+let exists = C.exists
+let find = C.find
+let find_map = C.find_map
+let fold_result = C.fold_result
+let fold_until = C.fold_until
+let for_all = C.for_all
+let max_elt = C.max_elt
+let mem = C.mem
+let min_elt = C.min_elt
+let sum = C.sum
+let to_list = C.to_list
+
+let counti = C.counti
+let existsi = C.existsi
+let find_mapi = C.find_mapi
+let findi = C.findi
+let foldi = C.foldi
+let for_alli = C.for_alli
+let iteri = C.iteri
+
+let transfer ~src ~dst = Linked_queue0.transfer src dst (* labelled wrapper over the positional [Linked_queue0.transfer] *)
+
+let concat_map t ~f = (* new queue holding, in order, every element of each list [f a] *)
+  let res = create () in
+  iter t ~f:(fun a ->
+    List.iter (f a) ~f:(fun b -> enqueue res b));
+  res
+;;
+
+let concat_mapi t ~f = (* as [concat_map], with [f] also given the element index *)
+  let res = create () in
+  iteri t ~f:(fun i a ->
+    List.iter (f i a) ~f:(fun b -> enqueue res b));
+  res
+;;
+
+let filter_map t ~f = (* new queue of the [b]s for which [f a = Some b]; [t] is left untouched *)
+  let res = create () in
+  iter t ~f:(fun a ->
+    match f a with
+    | None -> ()
+    | Some b -> enqueue res b);
+  res;
+;;
+
+let filter_mapi t ~f = (* as [filter_map], with [f] also given the element index *)
+  let res = create () in
+  iteri t ~f:(fun i a ->
+    match f i a with
+    | None -> ()
+    | Some b -> enqueue res b);
+  res;
+;;
+
+let filter t ~f = (* new queue of the elements satisfying [f] *)
+  let res = create () in
+  iter t ~f:(fun a -> if f a then enqueue res a);
+  res;
+;;
+
+let filteri t ~f = (* as [filter], with [f] also given the element index *)
+  let res = create () in
+  iteri t ~f:(fun i a -> if f i a then enqueue res a);
+  res;
+;;
+
+let map t ~f = (* new queue of [f a] for each element [a] *)
+  let res = create () in
+  iter t ~f:(fun a -> enqueue res (f a));
+  res;
+;;
+
+let mapi t ~f = (* as [map], with [f] also given the element index *)
+  let res = create () in
+  iteri t ~f:(fun i a -> enqueue res (f i a));
+  res;
+;;
+
+let filter_inplace q ~f = (* filters into a temporary queue, then empties [q] and moves the survivors back *)
+  let q' = filter q ~f in
+  clear q;
+  transfer ~src:q' ~dst:q;
+;;
+
+let filteri_inplace q ~f = (* as [filter_inplace], with [f] also given the element index *)
+  let q' = filteri q ~f in
+  clear q;
+  transfer ~src:q' ~dst:q;
+;;
+
+let enqueue_all t list = (* enqueues the list elements in list order *)
+  List.iter list ~f:(fun x -> enqueue t x)
+;;
+
+let of_list list = (* fresh queue filled in list order *)
+  let t = create () in
+  List.iter list ~f:(fun x -> enqueue t x);
+  t
+;;
+
+let of_array array = (* fresh queue filled in array order *)
+  let t = create () in
+  Array.iter array ~f:(fun x -> enqueue t x);
+  t
+;;
+
+let init len ~f = (* enqueues [f 0] .. [f (len - 1)]; the loop body never runs when [len <= 0] *)
+  let t = create () in
+  for i = 0 to len - 1 do
+    enqueue t (f i)
+  done;
+  t
+;;
+
+let to_array t = (* copies the elements into a fresh array in iteration order *)
+  match length t with
+  | 0 -> [||]
+  | len ->
+    let arr = Array.create ~len (peek_exn t) in (* the queue is non-empty here, so its front element can serve as the filler *)
+    let i = ref 0 in
+    iter t ~f:(fun v ->
+      arr.(!i) <- v;
+      incr i);
+    arr
+;;
+
+let t_of_sexp a_of_sexp sexp = of_list (list_of_sexp a_of_sexp sexp) (* parsed as a list, then enqueued in order *)
+let sexp_of_t sexp_of_a t = sexp_of_list sexp_of_a (to_list t) (* rendered as the list given by [to_list] *)
+
+let singleton a = (* fresh queue containing only [a] *)
+  let t = create () in
+  enqueue t a;
+  t
+;;
diff --git a/src/linked_queue.mli b/src/linked_queue.mli
new file mode 100644
index 0000000..158e8f8
--- /dev/null
+++ b/src/linked_queue.mli
@@ -0,0 +1,19 @@
+(** This module is a Base-style wrapper
around OCaml's standard [Queue] module. *) + +open! Import + +include Queue_intf.S with type 'a t = 'a Caml.Queue.t (** @inline *) + +(** [create ()] returns an empty queue. *) +val create : unit -> _ t + +(** [transfer ~src ~dst] adds all of the elements of [src] to the end of [dst], then + clears [src]. It is equivalent to the sequence: + + {[ + iter ~src ~f:(enqueue dst); + clear src + ]} + + but runs in constant time. *) +val transfer : src:'a t -> dst:'a t -> unit diff --git a/src/linked_queue0.ml b/src/linked_queue0.ml new file mode 100644 index 0000000..fcd67f4 --- /dev/null +++ b/src/linked_queue0.ml @@ -0,0 +1,16 @@ +open! Import0 + +type 'a t = 'a Caml.Queue.t + +let create = Caml.Queue.create +let clear = Caml.Queue.clear +let copy = Caml.Queue.copy +let is_empty = Caml.Queue.is_empty +let length = Caml.Queue.length +let peek = Caml.Queue.peek +let pop = Caml.Queue.pop +let push = Caml.Queue.push +let transfer = Caml.Queue.transfer + +let iter t ~f = Caml.Queue.iter f t +let fold t ~init ~f = Caml.Queue.fold f init t diff --git a/src/list.ml b/src/list.ml new file mode 100644 index 0000000..fc5f13c --- /dev/null +++ b/src/list.ml @@ -0,0 +1,1065 @@ +open! Import + +module Array = Array0 + + +include List1 (* This itself includes [List0]. *) + +let invalid_argf = Printf.invalid_argf + +module T = struct + type 'a t = 'a list [@@deriving_inline sexp] + let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + list_of_sexp + let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_list + [@@@end] +end + +module Or_unequal_lengths = struct + type 'a t = + | Ok of 'a + | Unequal_lengths + [@@deriving_inline compare, sexp_of] + let compare : 'a . 
('a -> 'a -> int) -> 'a t -> 'a t -> int = + fun _cmp__a -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (Ok _a__003_, Ok _b__004_) -> _cmp__a _a__003_ _b__004_ + | (Ok _, _) -> (-1) + | (_, Ok _) -> 1 + | (Unequal_lengths, Unequal_lengths) -> 0) + let sexp_of_t : type a. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | Ok v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Ok"; v0] + | Unequal_lengths -> Ppx_sexp_conv_lib.Sexp.Atom "Unequal_lengths" + [@@@end] +end + +include T + +let of_list t = t + +let range' ~compare ~stride ?(start=`inclusive) ?(stop=`exclusive) start_i stop_i = + let next_i = stride start_i in + let order x y = Ordering.of_int (compare x y) in + let raise_stride_cannot_return_same_value () = + invalid_arg "List.range': stride function cannot return the same value" + in + let initial_stride_order = + match order start_i next_i with + | Equal -> raise_stride_cannot_return_same_value () + | Less -> `Less + | Greater -> `Greater + in + let rec loop i accum = + let i_to_stop_order = order i stop_i in + match i_to_stop_order, initial_stride_order with + | (Less, `Less) + | (Greater, `Greater) -> begin + (* haven't yet reached [stop_i]. Continue. *) + let next_i = stride i in + match order i next_i, initial_stride_order with + | Equal, _ -> raise_stride_cannot_return_same_value () + | Less, `Greater + | Greater, `Less -> + invalid_arg "List.range': stride function cannot change direction" + | Less, `Less + | Greater, `Greater -> loop next_i (i :: accum) + end + | (Less, `Greater) + | (Greater, `Less) -> + (* stepped past [stop_i]. Finished. *) + accum + | (Equal, _) -> + (* reached [stop_i]. Finished. 
*) + match stop with + | `inclusive -> i :: accum + | `exclusive -> accum + in + let start_i = + match start with + | `inclusive -> start_i + | `exclusive -> next_i + in + rev (loop start_i []) +;; + +let range ?(stride=1) ?(start=`inclusive) ?(stop=`exclusive) start_i stop_i = + if stride = 0 then invalid_arg "List.range: stride must be non-zero"; + range' ~compare ~stride:(fun x -> x + stride) ~start ~stop start_i stop_i +;; + +let hd t = + match t with + | [] -> None + | x :: _ -> Some x +;; + +let tl t = + match t with + | [] -> None + | _ :: t' -> Some t' +;; + +let nth t n = + if n < 0 then None else + let rec nth_aux t n = + match t with + | [] -> None + | a :: t -> if n = 0 then Some a else nth_aux t (n-1) + in nth_aux t n +;; + +let nth_exn t n = + match nth t n with + | None -> + invalid_argf "List.nth_exn %d called on list of length %d" + n (length t) () + | Some a -> a +;; + +let unordered_append l1 l2 = + match l1, l2 with + | [], l | l, [] -> l + | _ -> rev_append l1 l2 + +let check_length2_exn name l1 l2 = + let n1 = length l1 in + let n2 = length l2 in + if n1 <> n2 + then raise (invalid_argf "length mismatch in %s: %d <> %d " name n1 n2 ()) +;; + +let check_length3_exn name l1 l2 l3 = + let n1 = length l1 in + let n2 = length l2 in + let n3 = length l3 in + if n1 <> n2 || n2 <> n3 + then raise (invalid_argf "length mismatch in %s: %d <> %d || %d <> %d" + name n1 n2 n2 n3 ()) +;; + +let check_length2 l1 l2 ~f = + if length l1 <> length l2 + then Or_unequal_lengths.Unequal_lengths + else Ok (f l1 l2) +;; + +let check_length3 l1 l2 l3 ~f = + let n1 = length l1 in + let n2 = length l2 in + let n3 = length l3 in + if n1 <> n2 || n2 <> n3 + then Or_unequal_lengths.Unequal_lengths + else Ok (f l1 l2 l3) +;; + +let iter2 l1 l2 ~f = check_length2 l1 l2 ~f:(iter2_ok ~f) + +let iter2_exn l1 l2 ~f = + check_length2_exn "iter2_exn" l1 l2; + iter2_ok l1 l2 ~f; +;; + +let rev_map2 l1 l2 ~f = check_length2 l1 l2 ~f:(rev_map2_ok ~f) + +let rev_map2_exn l1 l2 ~f = + 
check_length2_exn "rev_map2_exn" l1 l2; + rev_map2_ok l1 l2 ~f; +;; + +(* [fold2], [for_all2] and [exists2] wrap their [_ok] counterparts in [check_length2]; the [_exn] variants call [check_length2_exn] first and then run the [_ok] version directly. *) let fold2 l1 l2 ~init ~f = check_length2 l1 l2 ~f:(fold2_ok ~init ~f) + +let fold2_exn l1 l2 ~init ~f = + check_length2_exn "fold2_exn" l1 l2; + fold2_ok l1 l2 ~init ~f; +;; + +let for_all2 l1 l2 ~f = check_length2 l1 l2 ~f:(for_all2_ok ~f) + +let for_all2_exn l1 l2 ~f = + check_length2_exn "for_all2_exn" l1 l2; + for_all2_ok l1 l2 ~f; +;; + +let exists2 l1 l2 ~f = check_length2 l1 l2 ~f:(exists2_ok ~f) + +let exists2_exn l1 l2 ~f = + check_length2_exn "exists2_exn" l1 l2; + exists2_ok l1 l2 ~f; +;; + +(* [mem t a ~equal] is [true] iff [equal a b] holds for some element [b] of [t]; the scan stops at the first match. *) let mem t a ~equal = + let rec loop equal a = function + | [] -> false + | b :: bs -> equal a b || loop equal a bs + in + loop equal a t + +(* This is a copy of the code from the standard library, with an extra eta-expansion to + avoid creating partial closures (which showed up for [filter] in profiling). The result holds the elements satisfying [f] in reverse order of the input. *) +let rev_filter t ~f = + let rec find ~f accu = function + | [] -> accu + | x :: l -> if f x then find ~f (x :: accu) l else find ~f accu l + in + find ~f [] t +;; + +(* [filter] restores the input order by reversing the result of [rev_filter]. *) let filter t ~f = rev (rev_filter t ~f) + +(* [find_map t ~f] returns the first [Some _] produced by [f], scanning left to right, or [None] if [f] returns [None] on every element. *) let find_map t ~f = + let rec loop = function + | [] -> None + | x :: l -> + match f x with + | None -> loop l + | Some _ as r -> r + in + loop t +;; + +(* Like [find_map], but raises [Caml.Not_found] instead of returning [None]. *) let find_map_exn t ~f = + match find_map t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +(* [find t ~f] returns the first element satisfying [f], or [None]. *) let find t ~f = + let rec loop = function + | [] -> None + | x :: l -> if f x then Some x else loop l + in + loop t +;; + +(* [findi t ~f] returns the first [(i, x)] for which [f i x] is true, where [i] is the 0-based index of [x], or [None]. *) let findi t ~f = + let rec loop i t = + match t with + | [] -> None + | x :: l -> if f i x then Some (i, x) else loop (i + 1) l + in + loop 0 t +;; + +(* Indexed variant of [find_map]: [f] also receives the 0-based position of each element. *) let find_mapi t ~f = + let rec loop i t = + match t with + | [] -> None + | x :: l -> + match f i x with + | Some _ as result -> result + | None -> loop (i + 1) l + in + loop 0 t +;; + +(* Like [find_mapi], but raises [Caml.Not_found] instead of returning [None]. *) let find_mapi_exn t ~f = + match find_mapi t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +let for_alli t ~f = + let rec loop i t = + match t with + | [] -> true + | hd :: tl 
-> f i hd && loop (i+1) tl + in + loop 0 t + +let existsi t ~f = + let rec loop i t = + match t with + | [] -> false + | hd :: tl -> f i hd || loop (i+1) tl + in + loop 0 t + +(** For the container interface. *) +let fold_left = fold +let to_array = Array.of_list +let to_list t = t + +(** Tail recursive versions of standard [List] module *) + +let slow_append l1 l2 = rev_append (rev l1) l2 + +(* There are a few optimized list operations here, including append and map. There are + basically two optimizations in play: loop unrolling, and dynamic switching between + stack and heap allocation. + + The loop-unrolling is straightforward, we just unroll 5 levels of the loop. This makes + each iteration faster, and also reduces the number of stack frames consumed per list + element. + + The dynamic switching is done by counting the number of stack frames, and then + switching to the "slow" implementation when we exceed a given limit. This means that + short lists use the fast stack-allocation method, and long lists use a slower one that + doesn't require stack space. 
*) +let rec count_append l1 l2 count = + match l2 with + | [] -> l1 + | _ -> + match l1 with + | [] -> l2 + | [x1] -> x1 :: l2 + | [x1; x2] -> x1 :: x2 :: l2 + | [x1; x2; x3] -> x1 :: x2 :: x3 :: l2 + | [x1; x2; x3; x4] -> x1 :: x2 :: x3 :: x4 :: l2 + | x1 :: x2 :: x3 :: x4 :: x5 :: tl -> + x1 :: x2 :: x3 :: x4 :: x5 :: + (if count > 1000 + then slow_append tl l2 + else count_append tl l2 (count + 1)) + +let append l1 l2 = count_append l1 l2 0 + +let slow_map l ~f = rev (rev_map l ~f) + +let rec count_map ~f l ctr = + match l with + | [] -> [] + | [x1] -> + let f1 = f x1 in + [f1] + | [x1; x2] -> + let f1 = f x1 in + let f2 = f x2 in + [f1; f2] + | [x1; x2; x3] -> + let f1 = f x1 in + let f2 = f x2 in + let f3 = f x3 in + [f1; f2; f3] + | [x1; x2; x3; x4] -> + let f1 = f x1 in + let f2 = f x2 in + let f3 = f x3 in + let f4 = f x4 in + [f1; f2; f3; f4] + | x1 :: x2 :: x3 :: x4 :: x5 :: tl -> + let f1 = f x1 in + let f2 = f x2 in + let f3 = f x3 in + let f4 = f x4 in + let f5 = f x5 in + f1 :: f2 :: f3 :: f4 :: f5 :: + (if ctr > 1000 + then slow_map ~f tl + else count_map ~f tl (ctr + 1)) + +let map l ~f = count_map ~f l 0 + +let folding_map t ~init ~f = + let acc = ref init in + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) +;; + +let fold_map t ~init ~f = + let acc = ref init in + let result = + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) + in + !acc, result +;; + +let (>>|) l f = map l ~f + +let map2_ok l1 l2 ~f = rev (rev_map2_ok l1 l2 ~f) + +let map2 l1 l2 ~f = check_length2 l1 l2 ~f:(map2_ok ~f) + +let map2_exn l1 l2 ~f = + check_length2_exn "map2_exn" l1 l2; + map2_ok l1 l2 ~f +;; + +let rev_map3_ok l1 l2 l3 ~f = + let rec loop l1 l2 l3 ac = + match (l1, l2, l3) with + | ([], [], []) -> ac + | (x1 :: l1, x2 :: l2, x3 :: l3) -> loop l1 l2 l3 (f x1 x2 x3 :: ac) + | _ -> assert false + in + loop l1 l2 l3 []; +;; + +let rev_map3 l1 l2 l3 ~f = + check_length3 l1 l2 l3 ~f:(rev_map3_ok ~f) +;; + +let 
rev_map3_exn l1 l2 l3 ~f = + check_length3_exn "rev_map3_exn" l1 l2 l3; + rev_map3_ok l1 l2 l3 ~f +;; + +let map3_ok l1 l2 l3 ~f = rev (rev_map3_ok l1 l2 l3 ~f) + +let map3 l1 l2 l3 ~f = check_length3 l1 l2 l3 ~f:(map3_ok ~f) + +let map3_exn l1 l2 l3 ~f = + check_length3_exn "map3_exn" l1 l2 l3; + map3_ok l1 l2 l3 ~f; +;; + +let rec rev_map_append l1 l2 ~f = + match l1 with + | [] -> l2 + | h :: t -> rev_map_append ~f t (f h :: l2) + +let fold_right l ~f ~init = + match l with + | [] -> init (* avoid the allocation of [~f] below *) + | _ -> fold ~f:(fun a b -> f b a) ~init (rev l) + +let unzip list = + let rec loop list l1 l2 = + match list with + | [] -> (rev l1, rev l2) + | (x, y) :: tl -> loop tl (x :: l1) (y :: l2) + in + loop list [] [] + +let unzip3 list = + let rec loop list l1 l2 l3 = + match list with + | [] -> (rev l1, rev l2, rev l3) + | (x, y, z) :: tl -> loop tl (x :: l1) (y :: l2) (z :: l3) + in + loop list [] [] [] + +let zip_exn l1 l2 = + check_length2_exn "zip_exn" l1 l2; + map2_ok ~f:(fun a b -> (a, b)) l1 l2 + +let zip l1 l2 = map2 ~f:(fun a b -> (a, b)) l1 l2 + +(** Additional list operations *) + +let rev_mapi l ~f = + let rec loop i acc = function + | [] -> acc + | h :: t -> loop (i + 1) (f i h :: acc) t + in + loop 0 [] l + +let mapi l ~f = rev (rev_mapi l ~f) + +let folding_mapi t ~init ~f = + let acc = ref init in + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) +;; + +let fold_mapi t ~init ~f = + let acc = ref init in + let result = + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) + in + !acc, result +;; + +let iteri l ~f = + ignore (fold l ~init:0 ~f:(fun i x -> f i x; i + 1)); +;; + +let foldi t ~init ~f = + snd (fold t ~init:(0, init) ~f:(fun (i, acc) v -> (i + 1, f i acc v))) +;; + +let filteri l ~f = + rev (foldi l + ~f:(fun pos acc x -> + if f pos x then x :: acc else acc) + ~init:[]) + +let reduce l ~f = match l with + | [] -> None + | hd :: tl -> Some (fold ~init:hd ~f 
tl) + +let reduce_exn l ~f = + match reduce l ~f with + | None -> raise (Invalid_argument "List.reduce_exn") + | Some v -> v + +let reduce_balanced l ~f = + (* Call the "size" of a value the number of list elements that have been combined into + it via calls to [f]. We proceed by using [f] to combine elements in the accumulator + of the same size until we can't combine any more, then getting a new element from the + input list and repeating. + + With this strategy, in the accumulator: + - we only ever have elements of sizes a power of two + - we never have more than one element of each size + - the sum of all the element sizes is equal to the number of elements consumed + + These conditions enforce that list of elements of each size is precisely the binary + expansion of the number of elements consumed: if you've consumed 13 = 0b1101 + elements, you have one element of size 8, one of size 4, and one of size 1. Hence + when a new element comes along, the number of combinings you need to do is the number + of trailing 1s in the binary expansion of [num], the number of elements that have + already gone into the accumulator. The accumulator is in ascending order of size, so + the next element to combine with is always the head of the list. *) + let rec step_accum num acc x = + if num land 1 = 0 + then x :: acc + else + match acc with + | [] -> assert false + (* New elements from later in the input list go on the front of the accumulator, so + the accumulator is in reverse order wrt the original list order, hence [f y x] + instead of [f x y]. *) + | y :: ys -> step_accum (num asr 1) ys (f y x) + in + (* Experimentally, inlining [foldi] and unrolling this loop a few times can reduce + runtime down to a third and allocation to 1/16th or so in the microbenchmarks below. + However, in most use cases [f] is likely to be expensive (otherwise why do you care + about the order of reduction?) so the overhead of this function itself doesn't really + matter. 
If you come up with a use-case where it does, then that's something you might + want to try: see hg log -pr 49ef065f429d. *) + match foldi l ~init:[] ~f:step_accum with + | [] -> None + | x :: xs -> Some (fold xs ~init:x ~f:(fun x y -> f y x)) + +let reduce_balanced_exn l ~f = + match reduce_balanced l ~f with + | None -> raise (Invalid_argument "List.reduce_balanced_exn") + | Some v -> v + +let groupi l ~break = + let groups = + foldi l ~init:[] ~f:(fun i acc x -> + match acc with + | [] -> [[x]] + | current_group :: tl -> + if break i (hd_exn current_group) x then + [x] :: current_group :: tl (* start new group *) + else + (x :: current_group) :: tl) (* extend current group *) + in + match groups with + | [] -> [] + | l -> rev_map l ~f:rev + +let group l ~break = groupi l ~break:(fun _ x y -> break x y) + +let concat_map l ~f = + let rec aux acc = function + | [] -> rev acc + | hd :: tl -> aux (rev_append (f hd) acc) tl + in + aux [] l + +let concat_mapi l ~f = + let rec aux cont acc = function + | [] -> rev acc + | hd :: tl -> aux (cont + 1) (rev_append (f cont hd) acc) tl + in + aux 0 [] l + +let merge l1 l2 ~compare = + let rec loop acc l1 l2 = + match l1,l2 with + | [], l2 -> rev_append acc l2 + | l1, [] -> rev_append acc l1 + | h1 :: t1, h2 :: t2 -> + if compare h1 h2 <= 0 + then loop (h1 :: acc) t1 l2 + else loop (h2 :: acc) l1 t2 + in + loop [] l1 l2 +;; + + +include struct + (* We are explicit about what we import from the general Monad functor so that we don't + accidentally rebind more efficient list-specific functions. 
*) + module Monad = Monad.Make (struct + type 'a t = 'a list + let bind x ~f = concat_map x ~f + let map = `Custom map + let return x = [x] + end) + open Monad + module Monad_infix = Monad_infix + module Let_syntax = Let_syntax + let ignore_m = ignore_m + let join = join + let bind = bind + let (>>=) t f = bind t ~f + let return = return + let all = all + let all_unit = all_unit + let all_ignore = all_unit +end + +(** returns final element of list *) +let rec last_exn list = match list with + | [x] -> x + | _ :: tl -> last_exn tl + | [] -> raise (Invalid_argument "List.last") + +(** optionally returns final element of list *) +let rec last list = match list with + | [x] -> Some x + | _ :: tl -> last tl + | [] -> None +;; + +let rec is_prefix list ~prefix ~equal = + match prefix with + | [] -> true + | hd::tl -> + match list with + | [] -> false + | hd'::tl' -> equal hd hd' && is_prefix tl' ~prefix:tl ~equal +;; + +let find_consecutive_duplicate t ~equal = + match t with + | [] -> None + | a1 :: t -> + let rec loop a1 t = + match t with + | [] -> None + | a2 :: t -> if equal a1 a2 then Some (a1, a2) else loop a2 t + in + loop a1 t +;; + +(* returns list without adjacent duplicates *) +let remove_consecutive_duplicates ?(which_to_keep=`Last) list ~equal = + let rec loop to_keep accum = function + | [] -> to_keep :: accum + | hd :: tl -> + if equal hd to_keep then ( + let to_keep = + match which_to_keep with + | `First -> to_keep + | `Last -> hd + in + loop to_keep accum tl + ) else ( + loop hd (to_keep :: accum) tl + ) + in + match list with + | [] -> [] + | hd :: tl -> rev (loop hd [] tl) +;; + +(** returns sorted version of list with duplicates removed *) +let dedup_and_sort ~compare list = + match list with + | [] -> [] (* performance hack *) + | _ -> + let equal x x' = compare x x' = 0 in + let sorted = sort ~compare list in + remove_consecutive_duplicates ~equal sorted + +let dedup = dedup_and_sort + +let find_a_dup ~compare l = + let sorted = sort ~compare l in 
+ let rec loop l = match l with + | [] | [_] -> None + | hd1 :: (hd2 :: _ as tl) -> + if compare hd1 hd2 = 0 then Some hd1 else loop tl + in + loop sorted +;; + +(* [contains_dup] is [true] exactly when [find_a_dup] returns [Some _]. *) let contains_dup ~compare lst = + match find_a_dup ~compare lst with + | Some _ -> true + | None -> false +;; + +(* [find_all_dups ~compare l] sorts [l] and returns one representative for every run of two or more elements that compare equal. *) let find_all_dups ~compare l = + (* We add this reversal, so we can skip a [rev] at the end. We could skip + [rev] anyway since we do not give any ordering guarantees, but it is + nice to get results in natural order. NOTE(review): negating by multiplication leaves the sign unchanged if [compare] ever returns [min_int]; calling [compare b a] would avoid that -- confirm before changing. *) + let compare a b = (-1) * compare a b in + let sorted = sort ~compare l in + (* Walk the list and record the first of each consecutive run of identical elements *) + let rec loop sorted prev ~already_recorded acc = + match sorted with + | [] -> acc + | hd :: tl -> + if compare prev hd <> 0 + then loop tl hd ~already_recorded:false acc + else if already_recorded + then loop tl hd ~already_recorded:true acc + else loop tl hd ~already_recorded:true (hd :: acc) + in + match sorted with + | [] -> [] + | hd :: tl -> loop tl hd ~already_recorded:false [] +;; + +(* The four functions below delegate to [Container], passing this module's [fold]. *) let count t ~f = Container.count ~fold t ~f +let sum m t ~f = Container.sum ~fold m t ~f +let min_elt t ~compare = Container.min_elt ~fold t ~compare +let max_elt t ~compare = Container.max_elt ~fold t ~compare + +(* [counti t ~f] counts the elements [a] at index [idx] for which [f idx a] is true. *) let counti t ~f = foldi t ~init:0 ~f:(fun idx count a -> if f idx a then count + 1 else count) + +(* [init n ~f] is [[f 0; f 1; ...; f (n-1)]]; [f] is applied from index [n-1] down to [0]. Raises via [invalid_argf] when [n < 0]. *) let init n ~f = + if n < 0 then invalid_argf "List.init %d" n (); + let rec loop i accum = + assert (i >= 0); + if i = 0 then accum + else loop (i-1) (f (i-1) :: accum) + in + loop n [] +;; + +(* [rev_filter_map l ~f] keeps the [x] of every [Some x] returned by [f]; results are accumulated in reverse order of the input. *) let rev_filter_map l ~f = + let rec loop l accum = + match l with + | [] -> accum + | hd :: tl -> + match f hd with + | Some x -> loop tl (x :: accum) + | None -> loop tl accum + in + loop l [] +;; + +(* [filter_map] restores the input order by reversing the result of [rev_filter_map]. *) let filter_map l ~f = rev (rev_filter_map l ~f) + +let rev_filter_mapi l ~f = + let rec loop i l accum = + match l with + | [] -> accum + | hd :: tl -> + match f i hd with + | Some x -> loop (i + 1) tl (x :: accum) + | None -> loop (i 
+ 1) tl accum + in + loop 0 l [] +;; + +let filter_mapi l ~f = rev (rev_filter_mapi l ~f) + +let filter_opt l = filter_map l ~f:Fn.id + +let partition3_map t ~f = + let rec loop t fst snd trd = + match t with + | [] -> (rev fst, rev snd, rev trd) + | x :: t -> + match f x with + | `Fst y -> loop t (y :: fst) snd trd + | `Snd y -> loop t fst (y :: snd) trd + | `Trd y -> loop t fst snd (y :: trd) + in + loop t [] [] [] +;; + +let partition_tf t ~f = + let f x = if f x then `Fst x else `Snd x in + partition_map t ~f +;; + +let partition_result t = + let f x = + match x with + | Ok v -> `Fst v + | Error e -> `Snd e + in + partition_map t ~f + +module Assoc = struct + + type ('a, 'b) t = ('a * 'b) list [@@deriving_inline sexp] + let t_of_sexp : + 'a 'b . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> + (Ppx_sexp_conv_lib.Sexp.t -> 'b) -> + Ppx_sexp_conv_lib.Sexp.t -> ('a, 'b) t + = + let _tp_loc = "src/list.ml.Assoc.t" in + fun _of_a -> + fun _of_b -> + fun t -> + list_of_sexp + (function + | Ppx_sexp_conv_lib.Sexp.List (v0::v1::[]) -> + let v0 = _of_a v0 + and v1 = _of_b v1 in (v0, v1) + | sexp -> + Ppx_sexp_conv_lib.Conv_error.tuple_of_size_n_expected _tp_loc + 2 sexp) t + let sexp_of_t : + 'a 'b . 
+ ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_b -> + fun v -> + sexp_of_list + (function + | (v0, v1) -> + let v0 = _of_a v0 + and v1 = _of_b v1 in Ppx_sexp_conv_lib.Sexp.List [v0; v1]) v + [@@@end] + + let find t ~equal key = + match find t ~f:(fun (key', _) -> equal key key') with + | None -> None + | Some x -> Some (snd x) + + let find_exn t ~equal key = + match find t key ~equal with + | None -> raise Caml.Not_found + | Some value -> value + + let mem t ~equal key = + match find t ~equal key with + | None -> false + | Some _ -> true + ;; + + let remove t ~equal key = + filter t ~f:(fun (key', _) -> not (equal key key')) + + let add t ~equal key value = + (* the remove doesn't change the map semantics, but keeps the list small *) + (key, value) :: remove t ~equal key + + let inverse t = map t ~f:(fun (x, y) -> (y, x)) + + let map t ~f = map t ~f:(fun (key, value) -> (key, f value)) + +end + +let sub l ~pos ~len = + (* We use [pos > length l - len] rather than [pos + len > length l] to avoid the + possibility of overflow. 
*) + if pos < 0 || len < 0 || pos > length l - len then invalid_arg "List.sub"; + rev + (foldi l ~init:[] + ~f:(fun i acc el -> + if i >= pos && i < (pos + len) + then el :: acc + else acc + ) + ) +;; + +(* [split_n t_orig n] returns the first [n] elements and the remainder. When [n <= 0] the first part is empty; when [t_orig] has fewer than [n] elements, [t_orig] itself is returned as the first part. *) let split_n t_orig n = + if n <= 0 then + ([], t_orig) + else + let rec loop n t accum = + if n = 0 then + (rev accum, t) + else + match t with + | [] -> (t_orig, []) (* in this case, t_orig = rev accum *) + | hd :: tl -> loop (n - 1) tl (hd :: accum) + in + loop n t_orig [] + +(* copied from [split_n] to avoid allocating a tuple *) +let take t_orig n = + if n <= 0 then + [] + else + let rec loop n t accum = + if n = 0 then + rev accum + else + match t with + | [] -> t_orig + | hd :: tl -> loop (n - 1) tl (hd :: accum) + in + loop n t_orig [] + +(* [drop t n] removes up to [n] leading elements; once [n] is no longer positive or the list is exhausted, the current list is returned unchanged. *) let rec drop t n = + match t with + | _ :: tl when n > 0 -> drop tl (n - 1) + | t -> t + +(* [chunks_of l ~length] repeatedly applies [split_n] to cut [l] into consecutive sublists of [length] elements (the last may be shorter). Raises via [invalid_argf] when [length <= 0]. NOTE(review): [aux] threads [of_length] through every call but only ever reads the outer [length]. *) let chunks_of l ~length = + if length <= 0 + then invalid_argf "List.chunks_of: Expected length > 0, got %d" length (); + let rec aux of_length acc l = + match l with + | [] -> rev acc + | _ :: _ -> + let sublist, l = split_n l length in + aux of_length (sublist :: acc) l + in + aux length [] l +;; + +(* [split_while xs ~f] returns the longest prefix whose elements all satisfy [f], together with the rest of the list. *) let split_while xs ~f = + let rec loop acc = function + | hd :: tl when f hd -> loop (hd :: acc) tl + | t -> (rev acc, t) + in + loop [] xs +;; + +(* copied from [split_while] to avoid allocating a tuple *) +let take_while xs ~f = + let rec loop acc = function + | hd :: tl when f hd -> loop (hd :: acc) tl + | _ -> rev acc + in + loop [] xs +;; + +(* [drop_while t ~f] skips leading elements while [f] holds and returns the remaining suffix. *) let rec drop_while t ~f = + match t with + | hd :: tl when f hd -> drop_while tl ~f + | t -> t + +(* [cartesian_product list1 list2] pairs every element of [list1] with every element of [list2], ordered by [list1] first and then by [list2]. *) let cartesian_product list1 list2 = + if is_empty list2 then [] else + let rec loop l1 l2 accum = match l1 with + | [] -> accum + | (hd :: tl) -> + loop tl l2 + (rev_append + (map ~f:(fun x -> (hd,x)) l2) + accum) + in + rev (loop list1 list2 []) + +(* [concat l] appends the inner lists in order, folding from the right with [append]. *) let concat l = fold_right l ~init:[] ~f:append + +(* Like [concat] but built with [rev_append] in a left fold, so the original element order is not preserved. *) let concat_no_order l = fold l ~init:[] ~f:(fun acc l -> rev_append l acc) + +let cons x l = x :: l + 
+(* [is_sorted l ~compare] checks that every adjacent pair [x1; x2] satisfies [compare x1 x2 <= 0]; empty and singleton lists are sorted. *) let is_sorted l ~compare = + let rec loop l = + match l with + | [] | [_] -> true + | x1 :: ((x2 :: _) as rest) -> + compare x1 x2 <= 0 && loop rest + in loop l + +(* Same as [is_sorted] but requires [compare x1 x2 < 0], so adjacent elements that compare equal are rejected. *) let is_sorted_strictly l ~compare = + let rec loop l = + match l with + | [] | [_] -> true + | x1 :: ((x2 :: _) as rest) -> + compare x1 x2 < 0 && loop rest + in loop l +;; + +module Infix = struct + let ( @ ) = append +end + +(* [permute] shuffles longer lists by converting to an array, calling [Array_permute.permute], and converting back. *) let permute ?(random_state = Random.State.default) list = + match list with + (* special cases to speed things up in trivial cases *) + | [] | [_] -> list + | [ x; y ] -> if Random.State.bool random_state then [ y; x ] else list + | _ -> + let arr = Array.of_list list in + Array_permute.permute arr ~random_state; + Array.to_list arr; +;; + +(* [random_element_exn] picks the element at an index drawn with [Random.State.int] below [length list]; fails with [failwith] on the empty list. *) let random_element_exn ?(random_state = Random.State.default) list = + if is_empty list + then failwith "List.random_element_exn: empty list" + else nth_exn list (Random.State.int random_state (length list)) +;; + +(* NOTE(review): the wildcard handler below maps every exception raised by [random_element_exn] to [None], not only the empty-list failure. *) let random_element ?(random_state = Random.State.default) list = + try Some (random_element_exn ~random_state list) + with _ -> None +;; + +(* Lexicographic comparison using [cmp] on elements; a list that is a proper prefix of the other compares smaller. *) let rec compare cmp a b = + match a, b with + | [], [] -> 0 + | [], _ -> -1 + | _ , [] -> 1 + | x :: xs, y :: ys -> + let n = cmp x y in + if n = 0 then compare cmp xs ys + else n +;; + +let hash_fold_t = hash_fold_list + +(* Two lists are equal when they have the same length and [equal] holds pairwise; stops at the first mismatch. *) let equal equal t1 t2 = + let rec loop ~equal t1 t2 = + match t1, t2 with + | [], [] -> true + | x1 :: t1, x2 :: t2 -> equal x1 x2 && loop ~equal t1 t2 + | _ -> false + in + loop ~equal t1 t2 +;; + +(* [transpose] peels one column off the rows per step. It returns [None] as soon as some rows are exhausted while others are not, and [Some columns] once every row is empty. *) let transpose = + let rec transpose_aux t rev_columns = + match partition_map t ~f:(function [] -> `Snd () | x :: xs -> `Fst (x, xs)) with + | (_ :: _, _ :: _) -> None + | ([], _) -> Some (rev_append rev_columns []) + | (heads_and_tails, []) -> + let (column, trimmed_rows) = unzip heads_and_tails in + transpose_aux trimmed_rows (column :: rev_columns) + in + fun t -> + transpose_aux t [] + +exception Transpose_got_lists_of_different_lengths of int list [@@deriving_inline sexp] +let 
() = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Transpose_got_lists_of_different_lengths]) + (function + | Transpose_got_lists_of_different_lengths v0 -> + let v0 = sexp_of_list sexp_of_int v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom + "src/list.ml.Transpose_got_lists_of_different_lengths"; + v0] + | _ -> assert false) +[@@@end] + +let transpose_exn l = + match transpose l with + | Some l -> l + | None -> + raise (Transpose_got_lists_of_different_lengths (map l ~f:length)) + +let intersperse t ~sep = + match t with + | [] -> [] + | x :: xs -> x :: fold_right xs ~init:[] ~f:(fun y acc -> sep :: y :: acc) + +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t diff --git a/src/list.mli b/src/list.mli new file mode 100644 index 0000000..0a1729a --- /dev/null +++ b/src/list.mli @@ -0,0 +1,514 @@ +(** Immutable, singly-linked lists, giving fast access to the front of the list, and slow + (i.e., O(n)) access to the back of the list. The comparison functions on lists are + lexicographic. *) + +open! Import + +type 'a t = 'a list [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Container.S1 with type 'a t := 'a t +include Monad.S with type 'a t := 'a t + +(** [Or_unequal_lengths] is used for functions that take multiple lists and that only make + sense if all the lists have the same length, e.g., [iter2], [map3]. Such functions + check the list lengths prior to doing anything else, and return [Unequal_lengths] if + not all the lists have the same length. 
*) +module Or_unequal_lengths : sig + type 'a t = + | Ok of 'a + | Unequal_lengths + [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +(** [of_list] is the identity function. It is useful so that the [List] module matches + the same signature that other container modules do, namely: + + {[ + val of_list : 'a List.t -> 'a t + ]} *) +val of_list : 'a t -> 'a t + +val nth : 'a t -> int -> 'a option + +(** Return the [n]-th element of the given list. The first element (head of the list) is + at position 0. Raise if the list is too short or [n] is negative. *) +val nth_exn : 'a t -> int -> 'a + +(** List reversal. *) +val rev : 'a t -> 'a t + +(** [rev_append l1 l2] reverses [l1] and concatenates it to [l2]. This is equivalent + to [(]{!List.rev}[ l1) @ l2], but [rev_append] is more efficient. *) +val rev_append : 'a t -> 'a t -> 'a t + +(** [unordered_append l1 l2] has the same elements as [l1 @ l2], but in some + unspecified order. Generally takes time proportional to length of first list, but is + O(1) if either list is empty. *) +val unordered_append : 'a t -> 'a t -> 'a t + +(** [rev_map l ~f] gives the same result as {!List.rev}[ (]{!ListLabels.map}[ f l)], + but is more efficient. *) +val rev_map : 'a t -> f:('a -> 'b) -> 'b t + +(** [iter2 [a1; ...; an] [b1; ...; bn] ~f] calls in turn [f a1 b1; ...; f an bn]. + The exn version will raise if the two lists have different lengths. *) +val iter2_exn : 'a t -> 'b t -> f:('a -> 'b -> unit) -> unit +val iter2 : 'a t -> 'b t -> f:('a -> 'b -> unit) -> unit Or_unequal_lengths.t + +(** [rev_map2_exn l1 l2 ~f] gives the same result as [List.rev (List.map2_exn l1 l2 + ~f)], but is more efficient. 
*) +val rev_map2_exn : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t +val rev_map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t Or_unequal_lengths.t + +(** [fold2 ~f ~init:a [b1; ...; bn] [c1; ...; cn]] is [f (... (f (f a b1 c1) b2 c2) + ...) bn cn]. The exn version will raise if the two lists have different lengths. *) +val fold2_exn : 'a t -> 'b t -> init:'c -> f:('c -> 'a -> 'b -> 'c) -> 'c +val fold2 : 'a t -> 'b t -> init:'c -> f:('c -> 'a -> 'b -> 'c) -> 'c Or_unequal_lengths.t + +(** Like {!List.for_all}, but passes the index as an argument. *) +val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like {!List.for_all}, but for a two-argument predicate. The exn version will raise if + the two lists have different lengths. *) +val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool +val for_all2 : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool Or_unequal_lengths.t + +(** Like {!List.exists}, but passes the index as an argument. *) +val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like {!List.exists}, but for a two-argument predicate. The exn version will raise if + the two lists have different lengths. *) +val exists2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool +val exists2 : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool Or_unequal_lengths.t + +(** [filter l ~f] returns all the elements of the list [l] that satisfy the predicate [f]. + The order of the elements in the input list is preserved. *) +val filter : 'a t -> f:('a -> bool) -> 'a t + +(** Like [filter], but reverses the order of the input list. *) +val rev_filter : 'a t -> f:('a -> bool) -> 'a t + +(** Like [filter], but [f] also receives the 0-based index of each element. *) val filteri : 'a t -> f: (int -> 'a -> bool) -> 'a t + +(** [partition_map t ~f] partitions [t] according to [f]. 
*) +val partition_map : 'a t -> f:('a -> [ `Fst of 'b | `Snd of 'c ]) -> 'b t * 'c t + +(** Like [partition_map], but [f] selects among three output lists with [`Fst], [`Snd] and [`Trd]. *) val partition3_map + : 'a t + -> f:('a -> [ `Fst of 'b | `Snd of 'c | `Trd of 'd]) + -> 'b t * 'c t * 'd t + +(** [partition_tf l ~f] returns a pair of lists [(l1, l2)], where [l1] is the list of all + the elements of [l] that satisfy the predicate [f], and [l2] is the list of all the + elements of [l] that do not satisfy [f]. The order of the elements in the input list + is preserved. The "tf" suffix is mnemonic to remind readers at a call that the result + is (trues, falses). *) +val partition_tf : 'a t -> f:('a -> bool) -> 'a t * 'a t + +(** [partition_result l] returns a pair of lists [(l1, l2)], where [l1] is the + list of all [Ok] elements in [l] and [l2] is the list of all [Error] + elements. + The order of elements in the input list is preserved. *) +val partition_result : ('ok, 'error) Result.t t -> 'ok t * 'error t + +(** [split_n \[e1; ...; em\] n] is [(\[e1; ...; en\], \[en+1; ...; em\])]. + + - If [n > m], [(\[e1; ...; em\], \[\])] is returned. + - If [n <= 0], [(\[\], \[e1; ...; em\])] is returned. *) +val split_n : 'a t -> int -> 'a t * 'a t + +(** Sort a list in increasing order according to a comparison function. The comparison + function must return 0 if its arguments compare as equal, a positive integer if the + first is greater, and a negative integer if the first is smaller (see [Array.sort] for + a complete specification). For example, {!Poly.compare} is a suitable + comparison function. + + The current implementation uses Merge Sort. It runs in linear heap space and + logarithmic stack space. + + Presently, the sort is stable, meaning that two equal elements in the input will be in + the same order in the output. *) +val sort : 'a t -> compare:('a -> 'a -> int) -> 'a t + +(** Like [sort], but guaranteed to be stable. 
*) +val stable_sort : 'a t -> compare:('a -> 'a -> int) -> 'a t + +(** Merges two lists: assuming that [l1] and [l2] are sorted according to the comparison + function [compare], [merge l1 l2 ~compare] will return a sorted list containing all the + elements of [l1] and [l2]. If several elements compare equal, the elements of [l1] + will be before the elements of [l2]. *) +val merge : 'a t -> 'a t -> compare:('a -> 'a -> int) -> 'a t + +(** [hd t] is [Some] of the first element of [t], or [None] if [t] is empty. *) val hd : 'a t -> 'a option + +(** [tl t] is [Some] of [t] without its first element, or [None] if [t] is empty. *) val tl : 'a t -> 'a t option + +(** Returns the first element of the given list. Raises if the list is empty. *) +val hd_exn : 'a t -> 'a + +(** Returns the given list without its first element. Raises if the list is empty. *) +val tl_exn : 'a t -> 'a t + +(** [findi t ~f] returns the first element [x], paired with its 0-based index [i], for which [f i x] is true, or [None] if there is no such element. *) val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + +(** [find_exn t ~f] returns the first element of [t] that satisfies [f]. It raises + [Caml.Not_found] or [Not_found_s] if there is no such element. *) +val find_exn : 'a t -> f:('a -> bool) -> 'a + +(** Returns the first evaluation of [f] that returns [Some]. Raises [Caml.Not_found] or + [Not_found_s] if [f] always returns [None]. *) +val find_map_exn : 'a t -> f:('a -> 'b option) -> 'b + +(** Like [find_map] and [find_map_exn], but passes the index as an argument. *) + +val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option +val find_mapi_exn : 'a t -> f:(int -> 'a -> 'b option) -> 'b + +(** E.g., [append [1; 2] [3; 4; 5]] is [[1; 2; 3; 4; 5]] *) +val append : 'a t -> 'a t -> 'a t + +(** [map [a1; ...; an] ~f] applies function [f] to [a1], [a2], ..., [an], in order, + and builds the list [[f a1; ...; f an]] with the results returned by [f]. *) +val map : 'a t -> f:('a -> 'b) -> 'b t + +(** [folding_map] is a version of [map] that threads an accumulator through calls to + [f]. 
*) + +val folding_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'c t +val folding_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'c t + +(** [fold_map] is a combination of [fold] and [map] that threads an accumulator through + calls to [f]. *) + +val fold_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'b * 'c t +val fold_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'b * 'c t + +(** [concat_map t ~f] is [concat (map t ~f)], except that there is no guarantee about the + order in which [f] is applied to the elements of [t]. *) +val concat_map : 'a t -> f:('a -> 'b t) -> 'b t + +(** [concat_mapi t ~f] is like concat_map, but passes the index as an argument *) +val concat_mapi : 'a t -> f:(int -> 'a -> 'b t) -> 'b t + +(** [map2 [a1; ...; an] [b1; ...; bn] ~f] is [[f a1 b1; ...; f an bn]]. The exn + version will raise if the two lists have different lengths. *) +val map2_exn :'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t +val map2 :'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t Or_unequal_lengths.t + +(** Analogous to [rev_map2]. *) + +val rev_map3_exn : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t +val rev_map3 : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t Or_unequal_lengths.t + +(** Analogous to [map2]. *) + +val map3_exn : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t +val map3 : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t Or_unequal_lengths.t + +(** [rev_map_append l1 l2 ~f] reverses [l1] mapping [f] over each element, and appends the + result to the front of [l2]. *) +val rev_map_append : 'a t -> 'b t -> f:('a -> 'b) -> 'b t + +(** [fold_right [a1; ...; an] ~f ~init:b] is [f a1 (f a2 (... (f an b) ...))]. 
*) +val fold_right : 'a t -> f:('a -> 'b -> 'b) -> init:'b -> 'b + +(** [fold_left] is the same as {!Container.S1.fold}, and one should always use + [fold] rather than [fold_left], except in functors that are parameterized + over a more general signature where this equivalence does not hold. *) +val fold_left : 'a t -> init:'b -> f:('b -> 'a -> 'b) -> 'b + +(** Transform a list of pairs into a pair of lists: [unzip [(a1,b1); ...; (an,bn)]] is + [([a1; ...; an], [b1; ...; bn])]. *) + +val unzip : ('a * 'b ) t -> 'a t * 'b t +val unzip3 : ('a * 'b * 'c) t -> 'a t * 'b t * 'c t + +(** Transform a pair of lists into an (optional) list of pairs: [zip [a1; ...; an] [b1; + ...; bn]] is [[(a1,b1); ...; (an,bn)]]. Returns [Unequal_lengths] if the two lists + have different lengths. *) + +val zip : 'a t -> 'b t -> ('a * 'b) t Or_unequal_lengths.t +val zip_exn : 'a t -> 'b t -> ('a * 'b) t + +(** [mapi] is just like map, but it also passes in the index of each element as the first + argument to the mapped function. Tail-recursive. *) +val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +val rev_mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +(** [iteri] is just like [iter], but it also passes in the index of each element as the + first argument to the iter'd function. Tail-recursive. *) +val iteri : 'a t -> f:(int -> 'a -> unit) -> unit + +(** [foldi] is just like [fold], but it also passes in the index of each element as the + first argument to the folded function. Tail-recursive. *) +val foldi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b) -> 'b + +(** [reduce_exn [a1; ...; an] ~f] is [f (... (f (f a1 a2) a3) ...) an]. It fails on the + empty list. Tail recursive. *) +val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a +val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option + +(** [reduce_balanced] returns the same value as [reduce] when [f] is associative, but + differs in that the tree of nested applications of [f] has logarithmic depth. 
+ + This is useful when your ['a] grows in size as you reduce it and [f] becomes more + expensive with bigger inputs. For example, [reduce_balanced ~f:(^)] takes [n*log(n)] + time, while [reduce ~f:(^)] takes quadratic time. *) +val reduce_balanced : 'a t -> f:('a -> 'a -> 'a) -> 'a option +val reduce_balanced_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a + +(** [group l ~break] returns a list of lists (i.e., groups) whose concatenation is equal + to the original list. Each group is broken where [break] returns [true] on a pair of + successive elements. + + Example: + + {[ + group ~break:(<>) ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] -> + + [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] ]} *) +val group : 'a t -> break:('a -> 'a -> bool) -> 'a t t + +(** [groupi] is just like [group], except that [break] also receives, as its first argument, the index in the original list of the + current element along with the two elements. + + Example, group the chars of ["Mississippi"] into triples: + + {[ + groupi ~break:(fun i _ _ -> i mod 3 = 0) + ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] -> + + [['M'; 'i'; 's']; ['s'; 'i'; 's']; ['s'; 'i'; 'p']; ['p'; 'i']] ]} +*) +val groupi : 'a t -> break:(int -> 'a -> 'a -> bool) -> 'a t t + +(** [chunks_of l ~length] returns a list of lists whose concatenation is equal to the + original list. Every list has [length] elements, except for possibly the last list, + which may have fewer. [chunks_of] raises if [length <= 0]. *) +val chunks_of : 'a t -> length : int -> 'a t t + +(** The final element of a list. The [_exn] version raises on the empty list. *) +val last : 'a t -> 'a option +val last_exn : 'a t -> 'a + +(** [is_prefix xs ~prefix ~equal] returns [true] if [xs] starts with [prefix]; [equal] is the element equality to use. *) +val is_prefix : 'a t -> prefix:'a t -> equal:('a -> 'a -> bool) -> bool + + +(** [find_consecutive_duplicate t ~equal] returns the first pair of consecutive elements + [(a1, a2)] in [t] such that [equal a1 a2]. They are returned in the same order as + they appear in [t]. 
[equal] need not be an equivalence relation; it is simply used as + a predicate on consecutive elements. *) +val find_consecutive_duplicate : 'a t -> equal:('a -> 'a -> bool) -> ('a * 'a) option + +(** Returns the given list with consecutive duplicates removed. The relative order of the + other elements is unaffected. The element kept from a run of duplicates is determined + by [which_to_keep]. *) +val remove_consecutive_duplicates + : ?which_to_keep:[ `First | `Last ] (** default = `Last *) + -> 'a t + -> equal:('a -> 'a -> bool) + -> 'a t + +(** Returns the given list with duplicates removed and in sorted order. *) +val dedup_and_sort : compare:('a -> 'a -> int) -> 'a t -> 'a t +(* [dedup] is deprecated (see the attribute below); use [dedup_and_sort]. *) +val dedup : compare:('a -> 'a -> int) -> 'a t -> 'a t +[@@deprecated "[since 2017-04] Use [dedup_and_sort] instead"] + +(** [find_a_dup] returns a duplicate from the list (with no guarantees about which + duplicate you get), or [None] if there are no dups. *) +val find_a_dup : compare:('a -> 'a -> int) -> 'a t -> 'a option + +(** Returns [true] if there are any two elements in the list which are the same. O(n log n) + time complexity. *) +val contains_dup : compare:('a -> 'a -> int) -> 'a t -> bool + +(** [find_all_dups] returns a list of all elements that occur more than once, with + no guarantees about order. O(n log n) time complexity. *) +val find_all_dups : compare:('a -> 'a -> int) -> 'a t -> 'a list + +(** [count l ~f] is the number of elements in [l] that satisfy the predicate [f]. [counti] also passes the index of each element to [f] as its first argument. *) +val count : 'a t -> f:( 'a -> bool) -> int +val counti : 'a t -> f:(int -> 'a -> bool) -> int + +(** [range ?stride ?start ?stop start_i stop_i] is the list of integers from [start_i] to + [stop_i], stepping by [stride]. If [stride] < 0 then we need [start_i] > [stop_i] for + the result to be nonempty (or [start_i] = [stop_i] in the case where both bounds are + inclusive). 
*) +val range + : ?stride:int (** default = 1 *) + -> ?start:[`inclusive|`exclusive] (** default = `inclusive *) + -> ?stop:[`inclusive|`exclusive] (** default = `exclusive *) + -> int + -> int + -> int t + +(** [range'] is analogous to [range] for general start/stop/stride types. [range'] raises + if [stride x] returns [x] or if the direction that [stride x] moves [x] changes from + one call to the next. *) +val range' + : compare:('a -> 'a -> int) + -> stride:('a -> 'a) + -> ?start:[`inclusive|`exclusive] (** default = `inclusive *) + -> ?stop:[`inclusive|`exclusive] (** default = `exclusive *) + -> 'a + -> 'a + -> 'a t + +(** [init n ~f] is [[(f 0); (f 1); ...; (f (n-1))]]. It is an error if [n < 0]. + [init] applies [f] to values in decreasing order, starting with [n - 1] and + ending with [0]. This is the opposite order to [Array.init]. *) +val init : int -> f:(int -> 'a) -> 'a t + +(** [rev_filter_map l ~f] is the reversed sublist of [l] containing only elements for + which [f] returns [Some e]. *) +val rev_filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** [rev_filter_mapi] is just like [rev_filter_map], but it also passes in the index of each + element as the first argument to the mapped function. Tail-recursive. *) +val rev_filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** [filter_map l ~f] is the sublist of [l] containing only elements for which [f] returns + [Some e]. *) +val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** [filter_mapi] is just like [filter_map], but it also passes in the index of each element + as the first argument to the mapped function. Tail-recursive. *) +val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** [filter_opt l] is the sublist of [l] containing only elements which are [Some e]. In + other words, [filter_opt l] = [filter_map ~f:ident l]. 
*) +val filter_opt : 'a option t -> 'a t + +(** Interpret a list of (key, value) pairs as a map in which only the first occurrence of + a key affects the semantics, i.e.: + + {[List.Assoc.xxx alist ...args... ]} + + is always the same as (or at least sort of isomorphic to): + + {[ Map.xxx (alist |> Map.of_alist_multi |> Map.map ~f:List.hd) ...args... ]} *) +module Assoc : sig + (* Every key-based operation below takes an explicit [equal] on the key type ['a]; no comparison function is stored in the list itself. *) + type ('a, 'b) t = ('a * 'b) list [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] + + val add : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> 'b -> ('a, 'b) t + val find : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> 'b option + val find_exn : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> 'b + val mem : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> bool + val remove : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> ('a, 'b) t + val map : ('a, 'b) t -> f:('b -> 'c) -> ('a, 'c) t + + (** Bijectivity is not guaranteed because we allow a key to appear more than once. *) + val inverse : ('a, 'b) t -> ('b, 'a) t +end + +(** Note that [sub], unlike [slice], doesn't use Python-style indices! *) + +(** [sub l ~pos ~len] is the [len]-element sublist of [l], starting at [pos]. *) +val sub : 'a t -> pos:int -> len:int -> 'a t + +(** [take l n] returns the first [n] elements of [l], or all of [l] if [n > length l]. + [take l n = fst (split_n l n)]. *) +val take : 'a t -> int -> 'a t + +(** [drop l n] returns [l] without the first [n] elements, or the empty list if [n > + length l]. [drop l n] is equivalent to [snd (split_n l n)]. *) +val drop : 'a t -> int -> 'a t + +(** [take_while l ~f] returns the longest prefix of [l] for which [f] is [true]. *) +val take_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [drop_while l ~f] drops the longest prefix of [l] for which [f] is [true]. 
*) +val drop_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [split_while xs ~f = (take_while xs ~f, drop_while xs ~f)]. *) +val split_while : 'a t -> f : ('a -> bool) -> 'a t * 'a t + +(** Concatenates a list of lists. The elements of the argument are all concatenated + together (in the same order) to give the result. Tail-recursive over outer and inner + lists. *) +val concat : 'a t t -> 'a t + +(** Like [concat], but faster and without preserving any ordering (i.e., for lists that + are essentially viewed as multi-sets). *) +val concat_no_order : 'a t t -> 'a t +(* NOTE(review): [cons] is undocumented here; from its type, presumably [cons x l] is [x :: l] -- confirm against list.ml. *) +val cons : 'a -> 'a t -> 'a t + +(** Returns a list with all possible pairs -- if the input lists have length [len1] and + [len2], the resulting list will have length [len1 * len2]. *) +val cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** [permute ?random_state t] returns a permutation of [t]. + + [permute] side-effects [random_state] by repeated calls to [Random.State.int]. If + [random_state] is not supplied, [permute] uses [Random.State.default]. *) +val permute : ?random_state:Random.State.t -> 'a t -> 'a t + +(** [random_element ?random_state t] is [None] if [t] is empty, else it is [Some x] for + some [x] chosen uniformly at random from [t]. + + [random_element] side-effects [random_state] by calling [Random.State.int]. If + [random_state] is not supplied, [random_element] uses [Random.State.default]. + + NOTE(review): the behaviour of [random_element_exn] on an empty list is not stated here; presumably it raises -- confirm. *) +val random_element : ?random_state:Random.State.t -> 'a t -> 'a option +val random_element_exn : ?random_state:Random.State.t -> 'a t -> 'a + +(** [is_sorted t ~compare] returns [true] iff for all adjacent [a1; a2] in [t], [compare + a1 a2 <= 0]. + + [is_sorted_strictly] is similar, except it uses [<] instead of [<=]. 
*) +val is_sorted : 'a t -> compare:('a -> 'a -> int) -> bool +val is_sorted_strictly : 'a t -> compare:('a -> 'a -> int) -> bool +(* NOTE(review): [equal] is undocumented here; from its type, the first argument is presumably the element equality used to compare the two lists -- confirm against list.ml. *) +val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + +module Infix : sig + val ( @ ) : 'a t -> 'a t -> 'a t +end + +(** [transpose m] transposes the rows and columns of the matrix [m], + considered as either a row of column lists or (dually) a column of row lists. + + Example: + + {[transpose [[1;2;3];[4;5;6]] = [[1;4];[2;5];[3;6]]]} + + On non-empty rectangular matrices, [transpose] is an involution (i.e., [transpose + (transpose m) = m]). [transpose] returns [None] when called on lists of lists with + non-uniform lengths. *) +val transpose : 'a t t -> 'a t t option + +(** [transpose_exn] transposes the rows and columns of its argument, raising an exception + if the list is not rectangular. *) +val transpose_exn : 'a t t -> 'a t t + +(** [intersperse xs ~sep] places [sep] between adjacent elements of [xs]. For example, + [intersperse [1;2;3] ~sep:0 = [1;0;2;0;3]]. *) +val intersperse : 'a t -> sep:'a -> 'a t diff --git a/src/list0.ml b/src/list0.ml new file mode 100644 index 0000000..44e423a --- /dev/null +++ b/src/list0.ml @@ -0,0 +1,42 @@ +(* [List0] defines list functions that are primitives or can be simply defined in terms of + [Caml.List]. [List0] is intended to completely express the part of [Caml.List] that + [Base] uses -- no file in Base other than list0.ml should use [Caml.List]. + [List0] has few dependencies, and so is available early in Base's build order. All + Base files that need to use lists and come before [Base.List] in build order should do + [module List = List0]. Defining [module List = List0] is also necessary because it + prevents ocamldep from mistakenly causing a file to depend on [Base.List]. *) + +open! 
Import0 +(* Plain aliases for [Caml.List] functions, renamed to follow Base naming. *) +let hd_exn = Caml.List.hd +let length = Caml.List.length +let rev_append = Caml.List.rev_append +let tl_exn = Caml.List.tl +let unzip = Caml.List.split + +(* These are eta expanded in order to permute parameter order to follow Base + conventions. *) +let exists t ~f = Caml.List.exists t ~f +let exists2_ok l1 l2 ~f = Caml.List.exists2 l1 l2 ~f +let find_exn t ~f = Caml.List.find t ~f +let fold t ~init ~f = Caml.List.fold_left t ~f ~init +let fold2_ok l1 l2 ~init ~f = Caml.List.fold_left2 l1 l2 ~init ~f +let for_all t ~f = Caml.List.for_all t ~f +let for_all2_ok l1 l2 ~f = Caml.List.for_all2 l1 l2 ~f +let iter t ~f = Caml.List.iter t ~f +let iter2_ok l1 l2 ~f = Caml.List.iter2 l1 l2 ~f +let nontail_map t ~f = Caml.List.map t ~f +let nontail_mapi t ~f = Caml.List.mapi t ~f +let partition t ~f = Caml.List.partition t ~f +let rev_map t ~f = Caml.List.rev_map t ~f +let rev_map2_ok l1 l2 ~f = Caml.List.rev_map2 l1 l2 ~f +(* The comparison function is labelled [~compare] here and forwarded to [Caml.List] under its [~cmp] label. *) +let sort l ~compare = Caml.List.sort l ~cmp:compare +let stable_sort l ~compare = Caml.List.stable_sort l ~cmp:compare +(* [rev l] returns lists of length 0 or 1 as-is; longer lists are reversed by swapping the first two elements and handing the rest to [rev_append]. *) +let rev = function + | [] | [_] as res -> res + | x :: y :: rest -> rev_append rest [y; x] +;; +(* [is_empty l] is [true] only for the empty list. *) +let is_empty = function [] -> true | _ -> false diff --git a/src/list1.ml b/src/list1.ml new file mode 100644 index 0000000..d3856eb --- /dev/null +++ b/src/list1.ml @@ -0,0 +1,15 @@ +open! Import + +include List0 +(* [partition_map t ~f] applies [f] to each element of [t] in order and splits the results: [`Fst y] values go to the first list, [`Snd y] values to the second. Both accumulators are built in reverse by the tail-recursive [loop] and reversed once at the end, so the original relative order is kept. *) +let partition_map t ~f = + let rec loop t fst snd = + match t with + | [] -> (rev fst, rev snd) + | x :: t -> + match f x with + | `Fst y -> loop t (y :: fst) snd + | `Snd y -> loop t fst (y :: snd) + in + loop t [] [] +;; diff --git a/src/map.ml b/src/map.ml new file mode 100644 index 0000000..18a2ee8 --- /dev/null +++ b/src/map.ml @@ -0,0 +1,1770 @@ +(***********************************************************************) +(* *) +(* Objective Caml *) +(* *) +(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *) +(* *) +(* Copyright 1996 Institut National de Recherche en Informatique et *) +(* en Automatique. 
All rights reserved. This file is distributed *) +(* under the terms of the Apache 2.0 license. See ../THIRD-PARTY.txt *) +(* for details. *) +(* *) +(***********************************************************************) + +open! Import + +module List = List0 + +include Map_intf + +let with_return = With_return.with_return +(* [Duplicate] is raised, without a backtrace, by [Tree0.find_and_add_or_set] in [Add_exn_internal] mode when the key is already bound. The [let () = ...] block below is the ppx-expanded sexp converter. *) +exception Duplicate [@@deriving_inline sexp] +let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Duplicate]) + (function + | Duplicate -> Ppx_sexp_conv_lib.Sexp.Atom "src/map.ml.Duplicate" + | _ -> assert false) +[@@@end] + +module Tree0 = struct + (* Binary search tree. [Node (l, k, v, r, h)] holds the left subtree, one binding, the right subtree and the cached height [h] of the node; [Leaf (k, v)] is the compact form of a node with two empty subtrees and has height 1. *) + type ('k, 'v) t = + | Empty + | Leaf of 'k * 'v + | Node of ('k, 'v) t * 'k * 'v * ('k, 'v) t * int + (* [tree] is an alias for [t], usable inside submodules that define their own [t] (see [Enum]). *) + type ('k, 'v) tree = ('k, 'v) t + (* Height in constant time: 0 for [Empty], 1 for [Leaf], the cached value for [Node]. *) + let height = function + | Empty -> 0 + | Leaf _ -> 1 + | Node(_,_,_,_,h) -> h + ;; + (* [invariants t ~compare_key] checks, for every node, that the subtree heights differ by at most 2, that the cached height is 1 + the larger subtree height, and that keys are strictly ordered by [compare_key] (each key lies strictly between the bounds inherited from its ancestors). *) + let invariants = + let in_range lower upper compare_key k = + (match lower with + | None -> true + | Some lower -> compare_key lower k < 0 + ) + && (match upper with + | None -> true + | Some upper -> compare_key k upper < 0 + ) + in + let rec loop lower upper compare_key t = + match t with + | Empty -> true + | Leaf (k, _) -> in_range lower upper compare_key k + | Node (l, k, _, r, h) -> + let hl = height l and hr = height r in + abs (hl - hr) <= 2 + && h = (max hl hr) + 1 + && in_range lower upper compare_key k + && loop lower (Some k) compare_key l + && loop (Some k) upper compare_key r + in + fun t ~compare_key -> + loop None None compare_key t + ;; + + (* [create l x d r] builds the tree with left subtree [l], binding [x] to [d], and right subtree [r]: a [Leaf] when both subtrees have height 0, otherwise a [Node] whose cached height is 1 + the larger subtree height. No rebalancing is done. precondition: |height(l) - height(r)| <= 2 *) + let create l x d r = + let hl = height l and hr = height r in + if hl = 0 && hr = 0 then + Leaf (x, d) + else + Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1)) + ;; + (* A one-binding tree, always represented as a [Leaf]. *) + let singleton key data = Leaf (key, data) + + (* We must call [f] with increasing indexes, because the bin_prot reader in + Core_kernel.Map needs it. 
*) + let of_increasing_iterator_unchecked ~len ~f = (* [loop n ~f i] builds a tree from the [n] bindings [f i], ..., [f (i + n - 1)], calling [f] in that order. Sizes 0 to 3 are built directly; larger sizes build the left half, then the middle binding, then the right half. Key order is not checked. *) + let rec loop n ~f i : (_, _) t = + match n with + | 0 -> Empty + | 1 -> + let k,v = f i in + Leaf (k, v) + | 2 -> + let kl,vl = f i in + let k ,v = f (i + 1) in + Node (Leaf (kl, vl), k, v, Empty, 2) + | 3 -> + let kl,vl = f i in + let k,v = f (i + 1) in + let kr,vr = f (i + 2) in + Node (Leaf (kl, vl), k, v, Leaf (kr, vr), 2) + | n -> + let left_length = n lsr 1 in + let right_length = n - left_length - 1 in + let left = loop left_length ~f i in + let k, v = f (i + left_length) in + let right = loop right_length ~f (i + left_length + 1) in + create left k v right + in + loop len ~f 0 + (* Builds a tree from an array assumed to be sorted by key in either direction: the first two keys decide whether the array is read forwards or backwards. Neither ordering nor duplicates are checked. Returns the tree together with the array length. *) + let of_sorted_array_unchecked array ~compare_key = + let array_length = Array.length array in + let next = + if array_length < 2 + || let k0, _ = array.(0) in + let k1, _ = array.(1) in + compare_key k0 k1 < 0 + then + (fun i -> array.(i)) + else + (fun i -> array.(array_length - 1 - i)) + in + (of_increasing_iterator_unchecked ~len:array_length ~f:next, array_length) + (* Checked variant of [of_sorted_array_unchecked]: returns an error if two adjacent keys compare equal, or if some adjacent pair is ordered in the opposite direction to the first pair. Arrays of length 0 or 1 are always accepted. *) + let of_sorted_array array ~compare_key = + match array with + | [||] | [|_|] -> Result.Ok (of_sorted_array_unchecked array ~compare_key) + | _ -> + with_return (fun r -> + let increasing = + match compare_key (fst array.(0)) (fst array.(1)) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> i < 0 + in + for i = 1 to Array.length array - 2 do + match compare_key (fst array.(i)) (fst array.(i+1)) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> + if Poly.(<>) (i < 0) increasing then + r.return (Or_error.error_string "of_sorted_array: elements are not ordered") + done; + Result.Ok (of_sorted_array_unchecked array ~compare_key) + ) + + (* [bal l x d r] is the rebalancing counterpart of [create]: when one subtree is more than 2 levels taller than the other it restructures the taller side before building the node, otherwise it is just [create l x d r]. precondition: |height(l) - height(r)| <= 3 *) + let bal l x d r = + let hl = height l in + let hr = height r in + if hl > hr + 2 then begin + match l with + | Empty -> invalid_arg "Map.bal" + | Leaf _ -> assert false (* height(Leaf) = 1 && 1 
is not larger than hr + 2 *) + | Node(ll, lv, ld, lr, _) -> + if height ll >= height lr then (* left-left case: single rotation *) + create ll lv ld (create lr x d r) + else begin (* left-right case: double rotation, [lr] supplies the new root *) + match lr with + Empty -> invalid_arg "Map.bal" + | Leaf (lrv, lrd) -> + create (create ll lv ld Empty) lrv lrd (create Empty x d r) + | Node(lrl, lrv, lrd, lrr, _)-> + create (create ll lv ld lrl) lrv lrd (create lrr x d r) + end + end else if hr > hl + 2 then begin + match r with + | Empty -> invalid_arg "Map.bal" + | Leaf _ -> assert false (* height(Leaf) = 1 && 1 is not larger than hl + 2 *) + | Node(rl, rv, rd, rr, _) -> + if height rr >= height rl then (* right-right case: single rotation *) + create (create l x d rl) rv rd rr + else begin (* right-left case: double rotation, [rl] supplies the new root *) + match rl with + Empty -> invalid_arg "Map.bal" + | Leaf (rlv, rld) -> + create (create l x d Empty) rlv rld (create Empty rv rd rr) + | Node(rll, rlv, rld, rlr, _) -> + create (create l x d rll) rlv rld (create rlr rv rd rr) + end + end + else create l x d r + ;; + + let empty = Empty + + let is_empty = function Empty -> true | _ -> false + (* Raises, through [Error.raise_s], an error whose message carries the offending key rendered with [sexp_of_key]. *) + let raise_key_already_present ~key ~sexp_of_key = + Error.raise_s ( + Sexp.message + "[Map.add_exn] got key already present" + [ "key", key |> sexp_of_key ]) + ;; + (* What [find_and_add_or_set] does when the key is already bound: [Add_exn_internal] raises [Duplicate], [Add_exn] calls [raise_key_already_present], [Set] overwrites the binding. *) + module Add_or_set = struct + type t = + | Add_exn_internal + | Add_exn + | Set + end + (* Inserts the binding [x] -> [data] into [t]. Returns the new tree paired with the updated [length], which is incremented only when [x] was not already bound. On the way back up, each node on the search path is rebuilt with [bal]. *) + let rec find_and_add_or_set t ~length ~key:x ~data ~compare_key ~sexp_of_key + ~(add_or_set : Add_or_set.t) = + match t with + | Empty -> (Leaf (x, data), length + 1) + | Leaf(v, d) -> + let c = compare_key x v in + if c = 0 then + (match add_or_set with + | Add_exn_internal -> Exn.raise_without_backtrace Duplicate + | Add_exn -> raise_key_already_present ~key:x ~sexp_of_key + | Set -> (Leaf(x, data), length)) + else if c < 0 then + (Node(Leaf(x, data), v, d, Empty, 2), length + 1) + else + (Node(Empty, v, d, Leaf(x, data), 2), length + 1) + | Node(l, v, d, r, h) -> + let c = compare_key x v in + if c = 0 then + (match add_or_set with + | Add_exn_internal -> Exn.raise_without_backtrace Duplicate + | Add_exn -> 
raise_key_already_present ~key:x ~sexp_of_key + | Set -> (Node(l, x, data, r, h), length)) + else if c < 0 then + let l, length = + find_and_add_or_set ~length ~key:x ~data l ~compare_key ~sexp_of_key ~add_or_set + in + (bal l v d r, length) + else + let r, length = + find_and_add_or_set ~length ~key:x ~data r ~compare_key ~sexp_of_key ~add_or_set + in + (bal l v d r, length) + ;; + (* [add_exn] inserts a new binding and, if [key] is already bound, raises through [raise_key_already_present]. *) + let add_exn t ~length ~key ~data ~compare_key ~sexp_of_key = + find_and_add_or_set t ~length ~key ~data ~compare_key ~sexp_of_key ~add_or_set:Add_exn + ;; + (* Same as [add_exn], except that an already-bound key raises the [Duplicate] exception without a backtrace. *) + let add_exn_internal t ~length ~key ~data ~compare_key ~sexp_of_key = + find_and_add_or_set + t ~length ~key ~data ~compare_key ~sexp_of_key ~add_or_set:Add_exn_internal + ;; + (* [set] inserts or overwrites the binding for [key]. The [sexp_of_key] passed down is a dummy: in [Set] mode [find_and_add_or_set] never reaches [raise_key_already_present]. *) + let set t ~length ~key ~data ~compare_key = + find_and_add_or_set t ~length ~key ~data + ~compare_key + ~sexp_of_key:(fun _ -> List []) + ~add_or_set:Set + ;; + (* [set'] is [set] for callers that do not track the length: it passes a placeholder length of 0 and drops the returned length. *) + let set' t key data ~compare_key = fst (set t ~length:0 ~key ~data ~compare_key) + + module Build_increasing = struct + module Fragment = struct (* A fragment is one binding together with the subtree of everything to its left; its right subtree is supplied later by [collapse]. *) + type nonrec ('k, 'v) t = { + left_subtree : ('k, 'v) t; + key : 'k; + data: 'v; + } + (* Converts a fragment with an empty left subtree into a one-binding tree; fails on any other fragment. *) + let singleton_to_tree_exn = function + | { left_subtree = Empty; key; data; } -> singleton key data + | _ -> failwith "Map.singleton_to_tree_exn: not a singleton" + ;; + + let singleton ~key ~data = { left_subtree = Empty; key; data; } + + (* [collapse l r] completes fragment [l] into a tree by using [r] as its right subtree. precondition: |height(l.left_subtree) - height(r)| <= 2, + max_key(l) < min_key(r) + *) + let collapse l r = create l.left_subtree l.key l.data r + + (* [join l r] merges two fragments into one: [l] is collapsed onto the left subtree of [r] and the result becomes the new left subtree of [r]. precondition: |height(l.left_subtree) - height(r.left_subtree)| <= 2, + max_key(l) < min_key(r) + *) + let join l r = { r with left_subtree = collapse l r.left_subtree; } + + let max_key t = t.key + end + + (** Build trees from singletons in a balanced way by using skew binary encoding. + Each level contains trees of the same height, consecutive levels have consecutive + heights. There are no gaps. The first level are single keys. 
+ *) + type ('k, 'v) t = + | Zero of unit (* [unit] to make pattern matching faster *) + | One of ('k, 'v) t * ('k, 'v) Fragment.t + | Two of ('k, 'v) t * ('k, 'v) Fragment.t * ('k, 'v) Fragment.t + + let empty = Zero () + (* [add_unchecked t ~key ~data] appends one binding without comparing keys; the caller is expected to supply keys in increasing order. A level holding one fragment takes a second one; a level already holding two joins them, carries the joined fragment into the next level, and keeps only the new fragment. *) + let add_unchecked = + let rec go t x = match t with + | Zero () -> One (t, x) + | One (t, y) -> Two (t, y, x) + | Two (t, z, y) -> One (go t (Fragment.join z y), x) + in + fun t ~key ~data -> go t (Fragment.singleton ~key ~data) + ;; + (* [to_tree t] collapses all levels into a single tree, starting from the most recently added fragment (converted with [Fragment.singleton_to_tree_exn]) and using each accumulated tree as the right subtree of the next level. *) + let to_tree = + let rec go t r = match t with + | Zero () -> r + | One (t, l) -> go t (Fragment.collapse l r) + | Two (t, ll, l) -> go t (Fragment.collapse (Fragment.join ll l) r) + in + function + | Zero () -> Empty + | One (t, r) -> go t (Fragment.singleton_to_tree_exn r) + | Two (t, l, r) -> go (One (t, l)) (Fragment.singleton_to_tree_exn r) + ;; + (* Key of the most recently added binding, or [None] for the empty builder. *) + let max_key = function + | Zero () -> None + | One (_, r) | Two (_, _, r) -> Some (Fragment.max_key r) + ;; + end + (* Builds a tree from a sequence of bindings and also returns how many there were. Stops with an error as soon as a key is not strictly greater (per [compare_key]) than the key before it. *) + let of_increasing_sequence seq ~compare_key = with_return (fun { return; } -> + let builder, length = + Sequence.fold seq ~init:(Build_increasing.empty, 0) + ~f:(fun (builder, length) (key, data) -> + match Build_increasing.max_key builder with + | Some prev_key when compare_key prev_key key >= 0 -> + return (Or_error.error_string "of_increasing_sequence: non-increasing key") + | _ -> Build_increasing.add_unchecked builder ~key ~data, length + 1) + in + Ok (Build_increasing.to_tree builder, length)) + ;; + + (* Like [bal] but allows any difference in height between [l] and [r]. + When either side is [Empty] or a [Leaf], the bindings are simply inserted with [set']. + O(|height l - height r|) *) + let rec join l k d r ~compare_key = + match l, r with + | Empty, _ -> set' r k d ~compare_key + | _, Empty -> set' l k d ~compare_key + | Leaf(lk, ld), _ -> set' (set' r k d ~compare_key) lk ld ~compare_key + | _, Leaf(rk, rd) -> set' (set' l k d ~compare_key) rk rd ~compare_key + | Node(ll, lk, ld, lr, lh), Node(rl, rk, rd, rr, rh) -> + (* [bal] requires height difference <= 3. 
*) + if lh > rh + 3 + (* [height lr >= height r], + therefore [height (join lr k d r ...)] is [height rl + 1] or [height rl] + therefore the height difference with [ll] will be <= 3. NOTE(review): [height rl] above looks like a typo for [height lr] -- confirm. *) + then bal ll lk ld (join lr k d r ~compare_key) + else if rh > lh + 3 + then bal (join l k d rl ~compare_key) rk rd rr + else bal l k d r + ;; + (* [split t x] returns a triple: the bindings with key below [x], the binding for [x] itself if there is one, and the bindings with key above [x]. *) + let rec split t x ~compare_key = + match t with + | Empty -> (Empty, None, Empty) + | Leaf(k, d) -> + let cmp = compare_key x k in + if cmp = 0 then (Empty, Some (k, d), Empty) + else if cmp < 0 then (Empty, None, t) + else (t, None, Empty) + | Node(l, k, d, r, _) -> + let cmp = compare_key x k in + if cmp = 0 then (l, Some (k, d), r) + else if cmp < 0 then + let ll, maybe, lr = split l x ~compare_key in + (ll, maybe, join lr k d r ~compare_key) + else + let rl, maybe, rr = split r x ~compare_key in + (join l k d rl ~compare_key, maybe, rr) + ;; + (* Like [split], but returns only two trees: a binding found at [x] is put back into the left or the right one, as selected by [into]. *) + let split_and_reinsert_boundary t ~into x ~compare_key = + let left, boundary_opt, right = split t x ~compare_key in + match boundary_opt with + | None -> left, right + | Some (key, data) -> + let insert_into tree = fst (set tree ~key ~data ~length:0 ~compare_key) in + match into with + | `Left -> insert_into left, right + | `Right -> left, insert_into right + ;; + (* Splits [t] into (bindings below the range, bindings inside it, bindings above it). Crossed bounds give three empty trees. A binding sitting exactly on an inclusive bound stays in the middle tree; on an exclusive bound it goes to the outer tree. *) + let split_range t ~(lower_bound : 'a Maybe_bound.t) ~(upper_bound : 'a Maybe_bound.t) ~compare_key = + if Maybe_bound.bounds_crossed ~compare:compare_key ~lower:lower_bound ~upper:upper_bound + then empty, empty, empty + else + let left, mid_and_right = + match lower_bound with + | Unbounded -> empty, t + | Incl lb -> split_and_reinsert_boundary ~into:`Right t lb ~compare_key + | Excl lb -> split_and_reinsert_boundary ~into:`Left t lb ~compare_key + in + let mid, right = + match upper_bound with (* in the two arms below, [lb] is bound to the upper bound value *) + | Unbounded -> mid_and_right, empty + | Incl lb -> split_and_reinsert_boundary ~into:`Left mid_and_right lb ~compare_key + | Excl lb -> split_and_reinsert_boundary ~into:`Right mid_and_right lb ~compare_key + in + left, mid, right + ;; + 
+ let rec find t x ~compare_key = (* [find t x] is [Some data] when [x] is bound to [data] in [t], and [None] otherwise; it descends one side of the tree per comparison. *) + match t with + | Empty -> None + | Leaf (v, d) -> if compare_key x v = 0 then Some d else None + | Node(l, v, d, r, _) -> + let c = compare_key x v in + if c = 0 then Some d + else find (if c < 0 then l else r) x ~compare_key + ;; + (* Treats [t] as a map to lists: conses [data] onto the list currently bound to [key] (or onto the empty list when [key] is unbound) and stores the result with [set]. *) + let add_multi t ~length ~key ~data ~compare_key = + let data = data :: Option.value (find t key ~compare_key) ~default:[] in + set ~length ~key ~data t ~compare_key + ;; + (* The list bound to [x], or the empty list when [x] is unbound. *) + let find_multi t x ~compare_key = + match find t x ~compare_key with + | None -> [] + | Some l -> l + ;; + (* Like [find], but raises [Caml.Not_found] when [x] is unbound. *) + let find_exn t x ~compare_key = + match find t x ~compare_key with + | Some data -> data + | None -> + raise Caml.Not_found + ;; + + let mem t x ~compare_key = Option.is_some (find t x ~compare_key) + (* The leftmost binding of the tree, or [None] for [Empty]. *) + let rec min_elt = function + | Empty -> None + | Leaf (k, d) -> Some (k, d) + | Node (Empty, k, d, _, _) -> Some (k, d) + | Node (l, _, _, _, _) -> min_elt l + ;; + (* The two exceptions below are raised by [min_elt_exn] and [max_elt_exn] on an empty tree; the [let () = ...] blocks are their ppx-expanded sexp converters. *) + exception Map_min_elt_exn_of_empty_map [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Map_min_elt_exn_of_empty_map]) + (function + | Map_min_elt_exn_of_empty_map -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/map.ml.Tree0.Map_min_elt_exn_of_empty_map" + | _ -> assert false) + [@@@end] + exception Map_max_elt_exn_of_empty_map [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Map_max_elt_exn_of_empty_map]) + (function + | Map_max_elt_exn_of_empty_map -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/map.ml.Tree0.Map_max_elt_exn_of_empty_map" + | _ -> assert false) + [@@@end] + + let min_elt_exn t = + match min_elt t with + | None -> raise Map_min_elt_exn_of_empty_map + | Some v -> v + ;; + (* The rightmost binding of the tree, or [None] for [Empty]. *) + let rec max_elt = function + | Empty -> None + | Leaf (k, d) -> Some (k, d) + | Node (_, k, d, Empty, _) -> Some (k, d) + | Node (_, _, _, r, _) -> max_elt r + ;; + let max_elt_exn t = + match max_elt t with + | None -> raise Map_max_elt_exn_of_empty_map + | Some v -> v + ;; + 
+ let rec remove_min_elt t = (* Removes the leftmost binding, rebuilding each node on the left spine with [bal]. Calls [invalid_arg] on [Empty]. *) + match t with + Empty -> invalid_arg "Map.remove_min_elt" + | Leaf _ -> Empty + | Node(Empty, _, _, r, _) -> r + | Node(l, x, d, r, _) -> bal (remove_min_elt l) x d r + (* Concatenates two trees when the largest key of [lower_part] is strictly below the smallest key of [upper_part]; otherwise returns [`Overlapping_key_ranges]. If either part is empty, the other is returned unchanged. *) + let append ~lower_part ~upper_part ~compare_key = + match max_elt lower_part, min_elt upper_part with + | None, _ -> `Ok upper_part + | _, None -> `Ok lower_part + | Some (max_lower, _), Some (min_upper, v) + when compare_key max_lower min_upper < 0 -> + let upper_part_without_min = remove_min_elt upper_part in + `Ok (join ~compare_key:compare_key lower_part min_upper v upper_part_without_min) + | _ -> `Overlapping_key_ranges + ;; + + (* Folds [f], in increasing key order, over the bindings whose key lies between [min] and [max], both inclusive; subtrees that cannot contain such keys are not visited. Returns [init] unchanged when [min] is greater than [max]. *) + let fold_range_inclusive = + (* This assumes that min <= max, which is checked by the outer function. *) + let rec go t ~min ~max ~init ~f ~compare_key = + match t with + | Empty -> init + | Leaf (k, d) -> + if compare_key k min < 0 || compare_key k max > 0 then + (* k < min || k > max *) + init + else + f ~key:k ~data:d init + | Node (l, k, d, r, _) -> + let c_min = compare_key k min in + if c_min < 0 then + (* if k < min, then this node and its left branch are outside our range *) + go r ~min ~max ~init ~f ~compare_key + else if c_min = 0 then + (* if k = min, then this node's left branch is outside our range *) + go r ~min ~max ~init:(f ~key:k ~data:d init) ~f ~compare_key + else (* k > min *) + begin + let z = go l ~min ~max ~init ~f ~compare_key in + let c_max = compare_key k max in + (* if k > max, we're done *) + if c_max > 0 then z + else + let z = f ~key:k ~data:d z in + (* if k = max, then we fold in this one last value and we're done *) + if c_max = 0 then z + else go r ~min ~max ~init:z ~f ~compare_key + end + in fun t ~min ~max ~init ~f ~compare_key -> + if compare_key min max <= 0 then + go t ~min ~max ~init ~f ~compare_key + else + init + ;; + (* The bindings with key between [min] and [max] (inclusive) as an association list in increasing key order: the fold conses them in reverse and [List.rev] restores the order. *) + let range_to_alist t ~min ~max ~compare_key = + List.rev + (fold_range_inclusive t ~min ~max ~init:[] ~f:(fun ~key ~data l -> (key,data)::l) + ~compare_key) + ;; + + let 
concat_unchecked t1 t2 = (* Merges two trees without comparing any keys: the minimum binding of [t2] is removed and becomes the root over [t1] and the rest of [t2]. Within this chunk it is only applied to the two subtrees of a node being deleted. *) + match (t1, t2) with + | (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> + let (x, d) = min_elt_exn t2 in + bal t1 x d (remove_min_elt t2) + ;; + (* [remove t x] deletes the binding for [x], if any, and returns the new tree with the updated length (decremented only when a binding was removed). Nodes on the search path are rebuilt with [bal] even when [x] is absent. *) + let rec remove t x ~length ~compare_key = + match t with + | Empty -> (Empty, length) + | Leaf (v, _) -> + if compare_key x v = 0 then + (Empty, length - 1) + else (t, length) + | Node(l, v, d, r, _) -> + let c = compare_key x v in + if c = 0 then + (concat_unchecked l r, length - 1) + else if c < 0 then + let l, length = remove l x ~length ~compare_key in + (bal l v d r, length) + else + let r, length = remove r x ~length ~compare_key in + (bal l v d r, length) + ;; + + (* Use exception to avoid tree-rebuild in no-op case *) + exception Change_no_op + (* [change t key ~f] calls [f] with the data currently bound to [key] ([None] when unbound) and then binds, rebinds or removes [key] according to the result. Returns the new tree and length. When [key] is unbound and [f None] is [None], [Change_no_op] unwinds the recursion and the original [t] is returned as-is. *) + let change t key ~f ~length ~compare_key = + let rec change_core t key f = + match t with + | Empty -> + begin match (f None) with + | None -> raise Change_no_op (* equivalent to returning: Empty *) + | Some data -> (Leaf(key, data), length + 1) + end + | Leaf(v, d) -> + let c = compare_key key v in + if c = 0 then + match f (Some d) with + | None -> (Empty, length - 1) + | Some d' -> (Leaf(v, d'), length) + else if c < 0 then + let l, length = change_core Empty key f in + (bal l v d Empty, length) + else + let r, length = change_core Empty key f in + (bal Empty v d r, length) + | Node(l, v, d, r, h) -> + let c = compare_key key v in + if c = 0 then + begin match (f (Some d)) with + | None -> (concat_unchecked l r, length - 1) + | Some data -> (Node(l, key, data, r, h), length) + end + else + if c < 0 then + let l, length = change_core l key f in + (bal l v d r, length) + else + let r, length = change_core r key f in + (bal l v d r, length) + in + try change_core t key f with Change_no_op -> (t, length) + ;; + (* Multi-map removal: drops the head of the list bound to [key]. The binding is removed entirely when [key] is unbound or its list has at most one element. *) + let remove_multi t key ~length ~compare_key = + change t key ~length ~compare_key ~f:(function + | None | Some ([] | [_]) -> None + | Some (_ :: ((_ :: _) as non_empty_tail)) -> Some non_empty_tail) + ;; + + let rec iter_keys t ~f = + 
match t with + | Empty -> () + | Leaf(v, _) -> f v + | Node(l, v, _, r, _) -> iter_keys ~f l; f v; iter_keys ~f r + ;; + (* [iter t ~f] applies [f] to every data value: left subtree, then the node, then the right subtree. *) + let rec iter t ~f = + match t with + | Empty -> () + | Leaf(_, d) -> f d + | Node(l, _, d, r, _) -> iter ~f l; f d; iter ~f r + ;; + (* Same traversal as [iter], passing both the key and the data to [f]. *) + let rec iteri t ~f = + match t with + | Empty -> () + | Leaf(v, d) -> f ~key:v ~data:d + | Node(l, v, d, r, _) -> iteri ~f l; f ~key:v ~data:d; iteri ~f r + ;; + (* [map t ~f] applies [f] to every data value (left subtree, node, right subtree, in that order) and rebuilds a tree with the same keys, shape and cached heights. *) + let rec map t ~f = + match t with + | Empty -> Empty + | Leaf(v, d) -> Leaf(v, f d) + | Node(l, v, d, r, h) -> + let l' = map ~f l in + let d' = f d in + let r' = map ~f r in + Node(l', v, d', r', h) + ;; + (* Same as [map], passing both the key and the data to [f]. *) + let rec mapi t ~f = + match t with + | Empty -> Empty + | Leaf(v, d) -> Leaf(v, f ~key:v ~data:d) + | Node(l, v, d, r, h) -> + let l' = mapi ~f l in + let d' = f ~key:v ~data:d in + let r' = mapi ~f r in + Node(l', v, d', r', h) + ;; + (* [fold] threads the accumulator through the left subtree, then the node, then the right subtree. *) + let rec fold t ~init:accu ~f = + match t with + | Empty -> accu + | Leaf(v, d) -> f ~key:v ~data:d accu + | Node(l, v, d, r, _) -> fold ~f r ~init:(f ~key:v ~data:d (fold ~f l ~init:accu)) + ;; + (* [fold_right] is the mirror image of [fold]: right subtree first, then the node, then the left subtree. *) + let rec fold_right t ~init:accu ~f = + match t with + | Empty -> accu + | Leaf(v, d) -> f ~key:v ~data:d accu + | Node(l, v, d, r, _) -> + fold_right ~f l ~init:(f ~key:v ~data:d (fold_right ~f r ~init:accu)) + ;; + (* The three filters below build a fresh tree by folding over [t] and re-inserting each kept binding with [set]; each returns the new tree paired with its length. [filter_keys] tests the key, [filter] the data, [filteri] both. *) + let filter_keys t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + if f key + then set ~length ~key ~data accu ~compare_key + else (accu, length)) + ;; + + let filter t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + if f data + then set ~length ~key ~data accu ~compare_key + else (accu, length)) + ;; + + let filteri t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + if f ~key ~data + then set ~length ~key ~data accu ~compare_key + else (accu, length)) + ;; + (* [filter_map] keeps the bindings for which [f data] is [Some b], storing [b] as the new data. *) + let filter_map t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + match f data with + | None -> (accu, length) 
+ | Some b -> set ~length ~key ~data:b accu ~compare_key) + ;; + + let filter_mapi t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + match f ~key ~data with + | None -> (accu, length) + | Some b -> set ~length ~key ~data:b accu ~compare_key) + ;; + + let partition_mapi t ~f ~compare_key = + fold t ~init:((Empty, 0), (Empty, 0)) ~f:(fun ~key ~data (pair1, pair2) -> + match f ~key ~data with + | `Fst x -> + let t, length = pair1 in + (set t ~key ~data:x ~compare_key ~length, pair2) + | `Snd y -> + let t, length = pair2 in + (pair1, set t ~key ~data:y ~compare_key ~length)) + ;; + + let partition_map t ~f ~compare_key = + partition_mapi t ~compare_key ~f:(fun ~key:_ ~data -> f data) + ;; + + let partitioni_tf t ~f ~compare_key = + partition_mapi t ~compare_key ~f:(fun ~key ~data -> + if f ~key ~data + then `Fst data + else `Snd data) + ;; + + let partition_tf t ~f ~compare_key = + partition_mapi t ~compare_key ~f:(fun ~key:_ ~data -> + if f data + then `Fst data + else `Snd data) + ;; + + module Enum = struct + type increasing + type decreasing + + type ('k, 'v, 'direction) t = + | End + | More of 'k * 'v * ('k, 'v) tree * ('k, 'v, 'direction) t + + let rec cons t (e : (_, _, increasing) t) : (_, _, increasing) t = + match t with + | Empty -> e + | Leaf(v, d) -> More(v, d, Empty, e) + | Node(l, v, d, r, _) -> cons l (More(v, d, r, e)) + ;; + + let rec cons_right t (e : (_, _, decreasing) t) : (_, _, decreasing) t = + match t with + | Empty -> e + | Leaf(v, d) -> More(v, d, Empty, e) + | Node(l, v, d, r, _) -> cons_right r (More(v, d, l, e)) + ;; + + let of_tree tree : (_, _, increasing) t = cons tree End + ;; + + let of_tree_right tree : (_, _, decreasing) t = cons_right tree End + ;; + + let starting_at_increasing t key compare : (_, _, increasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf(v, d) -> loop (Node(Empty, v, d, Empty, 1)) e + | Node(_, v, _, r, _) when compare v key < 0 -> loop r e + | Node(l, v, d, 
r, _) -> loop l (More(v, d, r, e)) + in + loop t End + ;; + + let starting_at_decreasing t key compare : (_, _, decreasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf(v, d) -> loop (Node(Empty, v, d, Empty, 1)) e + | Node(l, v, _, _, _) when compare v key > 0 -> loop l e + | Node(l, v, d, r, _) -> loop r (More(v, d, l, e)) + in + loop t End + ;; + + let compare compare_key compare_data t1 t2 = + let rec loop t1 t2 = + match t1, t2 with + | (End, End) -> 0 + | (End, _) -> -1 + | (_, End) -> 1 + | (More (v1, d1, r1, e1), More (v2, d2, r2, e2)) -> + let c = compare_key v1 v2 in + if c <> 0 then c else + let c = compare_data d1 d2 in + if c <> 0 then c else + if phys_equal r1 r2 + then loop e1 e2 + else loop (cons r1 e1) (cons r2 e2) + in + loop t1 t2 + ;; + + let equal compare_key data_equal t1 t2 = + let rec loop t1 t2 = + match t1, t2 with + | (End, End) -> true + | (End, _) | (_, End) -> false + | (More (v1, d1, r1, e1), More (v2, d2, r2, e2)) -> + compare_key v1 v2 = 0 + && data_equal d1 d2 + && (if phys_equal r1 r2 + then loop e1 e2 + else loop (cons r1 e1) (cons r2 e2)) + in + loop t1 t2 + ;; + + let rec fold ~init ~f = function + | End -> init + | More (key, data, tree, enum) -> + let next = f ~key ~data init in + fold (cons tree enum) ~init:next ~f + ;; + + let fold2 compare_key t1 t2 ~init ~f = + let rec loop t1 t2 curr = + match t1, t2 with + | End, End -> curr + | End, _ -> + fold t2 ~init:curr ~f:(fun ~key ~data acc -> f ~key ~data:(`Right data) acc) + | _ , End -> + fold t1 ~init:curr ~f:(fun ~key ~data acc -> f ~key ~data:(`Left data) acc) + | More (k1, v1, tree1, enum1), More (k2, v2, tree2, enum2) -> + let compare_result = compare_key k1 k2 in + if compare_result = 0 then begin + let next = f ~key:k1 ~data:(`Both (v1, v2)) curr in + loop (cons tree1 enum1) (cons tree2 enum2) next + end else if compare_result < 0 then begin + let next = f ~key:k1 ~data:(`Left v1) curr in + loop (cons tree1 enum1) t2 next + end else begin + let next 
= f ~key:k2 ~data:(`Right v2) curr in + loop t1 (cons tree2 enum2) next + end + in + loop t1 t2 init + ;; + + let symmetric_diff t1 t2 ~compare_key ~data_equal = + let step state = + match state with + | End, End -> + Sequence.Step.Done + | End, More (key, data, tree, enum) -> + Sequence.Step.Yield ((key, `Right data), (End, cons tree enum)) + | More (key, data, tree, enum), End -> + Sequence.Step.Yield ((key, `Left data), (cons tree enum, End)) + | (More (k1, v1, tree1, enum1) as left), (More (k2, v2, tree2, enum2) as right) -> + let compare_result = compare_key k1 k2 in + if compare_result = 0 then begin + let next_state = + if phys_equal tree1 tree2 + then (enum1, enum2) + else (cons tree1 enum1, cons tree2 enum2) + in + if data_equal v1 v2 + then Sequence.Step.Skip next_state + else Sequence.Step.Yield ((k1, `Unequal (v1, v2)), next_state) + end else if compare_result < 0 then begin + Sequence.Step.Yield ((k1, `Left v1), (cons tree1 enum1, right)) + end else begin + Sequence.Step.Yield ((k2, `Right v2), (left, (cons tree2 enum2))) + end + in + Sequence.unfold_step ~init:(of_tree t1, of_tree t2) ~f:step + ;; + + + end + + let to_sequence_increasing comparator ~from_key t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More(k,v,t,e) -> Sequence.Step.Yield((k,v), Enum.cons t e) + in + let init = + match from_key with + | None -> Enum.of_tree t + | Some key -> Enum.starting_at_increasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence_decreasing comparator ~from_key t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More(k,v,t,e) -> Sequence.Step.Yield((k,v), Enum.cons_right t e) + in + let init = + match from_key with + | None -> Enum.of_tree_right t + | Some key -> Enum.starting_at_decreasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence comparator ?(order=`Increasing_key) 
?keys_greater_or_equal_to + ?keys_less_or_equal_to t = + let inclusive_bound side t bound = + let compare_key = comparator.Comparator.compare in + let l, maybe, r = split t bound ~compare_key in + let t = side (l, r) in + match maybe with + | None -> t + | Some (key, data) -> set' t key data ~compare_key + in + match order with + | `Increasing_key -> + let t = Option.fold keys_less_or_equal_to ~init:t ~f:(inclusive_bound fst) in + to_sequence_increasing comparator ~from_key:keys_greater_or_equal_to t + | `Decreasing_key -> + let t = Option.fold keys_greater_or_equal_to ~init:t ~f:(inclusive_bound snd) in + to_sequence_decreasing comparator ~from_key:keys_less_or_equal_to t + ;; + + let compare compare_key compare_data t1 t2 = + Enum.compare compare_key compare_data (Enum.of_tree t1) (Enum.of_tree t2) + ;; + + let equal compare_key compare_data t1 t2 = + Enum.equal compare_key compare_data (Enum.of_tree t1) (Enum.of_tree t2) + ;; + + let iter2 t1 t2 ~f ~compare_key = + Enum.fold2 compare_key (Enum.of_tree t1) (Enum.of_tree t2) + ~init:() + ~f:(fun ~key ~data () -> f ~key ~data) + ;; + + let fold2 t1 t2 ~init ~f ~compare_key = + Enum.fold2 compare_key (Enum.of_tree t1) (Enum.of_tree t2) ~f ~init + ;; + + let symmetric_diff = Enum.symmetric_diff + + let rec length = function + | Empty -> 0 + | Leaf _ -> 1 + | Node (l, _, _, r, _) -> length l + length r + 1 + ;; + + let hash_fold_t_ignoring_structure hash_fold_key hash_fold_data state t = + fold t ~init:(hash_fold_int state (length t)) + ~f:(fun ~key ~data state -> hash_fold_data (hash_fold_key state key) data) + ;; + + let of_alist_fold alist ~init ~f ~compare_key = + List.fold alist ~init:(empty, 0) + ~f:(fun (accum, length) (key, data) -> + let prev_data = + match find accum key ~compare_key with + | None -> init + | Some prev -> prev + in + let data = f prev_data data in + set accum ~length ~key ~data ~compare_key) + ;; + + let of_alist_reduce alist ~f ~compare_key = + List.fold alist ~init:(empty, 0) + ~f:(fun 
(accum, length) (key, data) -> + let new_data = + match find accum key ~compare_key with + | None -> data + | Some prev -> f prev data + in + set accum ~length ~key ~data:new_data ~compare_key) + ;; + + let keys t = fold_right ~f:(fun ~key ~data:_ list -> key::list) t ~init:[] + let data t = fold_right ~f:(fun ~key:_ ~data list -> data::list) t ~init:[] + + let of_alist alist ~compare_key = + with_return (fun r -> + let map = + List.fold alist ~init:(empty, 0) ~f:(fun (t, length) (key,data) -> + let ((_, length') as acc) = set ~length ~key ~data t ~compare_key in + if length = length' then r.return (`Duplicate_key key) + else acc) + in + `Ok map) + ;; + + let for_all t ~f = + with_return (fun r -> + iter t ~f:(fun data -> if not (f data) then r.return false); + true) + + let for_alli t ~f = + with_return (fun r -> + iteri t ~f:(fun ~key ~data -> if not (f ~key ~data) then r.return false); + true) + + let exists t ~f = + with_return (fun r -> + iter t ~f:(fun data -> if f data then r.return true); + false) + + let existsi t ~f = + with_return (fun r -> + iteri t ~f:(fun ~key ~data -> if f ~key ~data then r.return true); + false) + + let count t ~f = + fold t ~init:0 ~f:(fun ~key:_ ~data acc -> if f data then acc + 1 else acc) + + let counti t ~f = + fold t ~init:0 ~f:(fun ~key ~data acc -> if f ~key ~data then acc + 1 else acc) + + let of_alist_or_error alist ~comparator = + match of_alist alist ~compare_key:comparator.Comparator.compare with + | `Ok x -> Result.Ok x + | `Duplicate_key key -> + Or_error.error "Map.of_alist_or_error: duplicate key" key comparator.sexp_of_t + ;; + + let of_alist_exn alist ~comparator = + match of_alist alist ~compare_key:comparator.Comparator.compare with + | `Ok x -> x + | `Duplicate_key key -> + Error.create "Map.of_alist_exn: duplicate key" key comparator.sexp_of_t + |> Error.raise + ;; + + let of_alist_multi alist ~compare_key = + let alist = List.rev alist in + of_alist_fold alist ~init:[] ~f:(fun l x -> x :: l) ~compare_key + ;; 
+ + let to_alist ?(key_order = `Increasing) t = + match key_order with + | `Increasing -> fold_right t ~init:[] ~f:(fun ~key ~data x -> (key, data) :: x) + | `Decreasing -> fold t ~init:[] ~f:(fun ~key ~data x -> (key, data) :: x) + ;; + + let merge t1 t2 ~f ~compare_key = + let elts = Uniform_array.unsafe_create_uninitialized ~len:(length t1 + length t2) in + let i = ref 0 in + iter2 t1 t2 ~compare_key ~f:(fun ~key ~data:values -> + match f ~key values with + | Some value -> Uniform_array.set elts !i (key, value); incr i + | None -> ()); + let len = !i in + let get i = Uniform_array.get elts i in + let tree = of_increasing_iterator_unchecked ~len ~f:get in + tree, len + ;; + + module Closest_key_impl = struct + (* [marker] and [repackage] allow us to create "logical" options without actually + allocating any options. Passing [Found key value] to a function is equivalent to + passing [Some (key, value)]; passing [Missing () ()] is equivalent to passing + [None]. *) + type ('k, 'v, 'k_opt, 'v_opt) marker = + | Missing : ('k, 'v, unit, unit) marker + | Found : ('k, 'v, 'k, 'v) marker + + let repackage (type k) (type v) (type k_opt) (type v_opt) + (marker : (k, v, k_opt, v_opt) marker) (k : k_opt) (v : v_opt) + : (k * v) option = + match marker with + | Missing -> None + | Found -> Some (k, v) + ;; + + (* The type signature is explicit here to allow polymorphic recursion. *) + let rec loop + : 'k 'v 'k_opt 'v_opt. 
+ ('k, 'v) tree + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'k -> compare_key:('k -> 'k -> int) + -> ('k, 'v, 'k_opt, 'v_opt) marker -> 'k_opt -> 'v_opt + -> ('k * 'v) option + = fun t dir k ~compare_key found_marker found_key found_value -> + match t with + | Empty -> + repackage found_marker found_key found_value + | Leaf (k', v') -> + let c = compare_key k' k in + if + match dir with + | `Greater_or_equal_to -> c >= 0 + | `Greater_than -> c > 0 + | `Less_or_equal_to -> c <= 0 + | `Less_than -> c < 0 + then + Some (k', v') + else + repackage found_marker found_key found_value + | Node (l, k', v', r, _) -> + let c = compare_key k' k in + if c = 0 then begin + (* This is a base case (no recursive call). *) + match dir with + | `Greater_or_equal_to | `Less_or_equal_to -> + Some (k', v') + | `Greater_than -> + if is_empty r then + repackage found_marker found_key found_value + else + min_elt r + | `Less_than -> + if is_empty l then + repackage found_marker found_key found_value + else + max_elt l + end else begin + (* We are guaranteed here that k' <> k. *) + (* This is the only recursive case. 
*) + match dir with + | `Greater_or_equal_to | `Greater_than -> + if c > 0 + then loop l dir k ~compare_key Found k' v' + else loop r dir k ~compare_key found_marker found_key found_value + | `Less_or_equal_to | `Less_than -> + if c < 0 + then loop r dir k ~compare_key Found k' v' + else loop l dir k ~compare_key found_marker found_key found_value + end + ;; + + let closest_key t dir k ~compare_key = + loop t dir k ~compare_key Missing () () + ;; + end + let closest_key = Closest_key_impl.closest_key + + let rec rank t k ~compare_key = + match t with + | Empty -> None + | Leaf (k', _) -> if compare_key k' k = 0 then Some 0 else None + | Node (l, k', _, r, _) -> + let c = compare_key k' k in + if c = 0 + then Some (length l) + else if c > 0 + then rank l k ~compare_key + else Option.map (rank r k ~compare_key) ~f:(fun rank -> rank + 1 + (length l)) + ;; + + (* this could be implemented using [Sequence] interface but the following implementation + allocates only 2 words and doesn't require write-barrier *) + let rec nth' num_to_search = function + | Empty -> None + | Leaf (k, v) -> + if !num_to_search = 0 + then Some (k, v) + else begin + decr num_to_search; + None + end + | Node (l, k, v, r, _) -> + match nth' num_to_search l with + | (Some _) as some -> some + | None -> + if !num_to_search = 0 + then Some (k, v) + else begin + decr num_to_search; + nth' num_to_search r + end + + let nth t n = nth' (ref n) t + ;; + + type ('k, 'v) acc = + { mutable bad_key : 'k option + ; mutable map_length : (('k, 'v) t * int) + } + + let of_iteri ~iteri ~compare_key = + let acc = { bad_key = None; map_length = (empty, 0) } in + iteri ~f:(fun ~key ~data -> + let map, length = acc.map_length in + let ((_, length') as pair) = set ~length ~key ~data map ~compare_key in + if length = length' && Option.is_none acc.bad_key + then acc.bad_key <- Some key + else acc.map_length <- pair); + match acc.bad_key with + | None -> `Ok acc.map_length + | Some key -> `Duplicate_key key + ;; + + let 
t_of_sexp_direct key_of_sexp value_of_sexp sexp ~comparator = + let alist = list_of_sexp (pair_of_sexp key_of_sexp value_of_sexp) sexp in + of_alist_exn alist ~comparator + ;; + + let sexp_of_t sexp_of_key sexp_of_value t = + let f ~key ~data acc = Sexp.List [sexp_of_key key; sexp_of_value data] :: acc in + Sexp.List (fold_right ~f t ~init:[]) + ;; +end + +type ('k, 'v, 'comparator) t = + { (* [comparator] is the first field so that polymorphic equality fails on a map due + to the functional value in the comparator. + Note that this does not affect polymorphic [compare]: that still produces + nonsense. *) + comparator : ('k, 'comparator) Comparator.t; + tree : ('k, 'v) Tree0.t; + length : int; + } + +type ('k, 'v, 'comparator) tree = ('k, 'v) Tree0.t + +let compare_key t = t.comparator.Comparator.compare + +let like {tree = _; length = _; comparator} (tree, length) = {tree; length; comparator} +let like2 x (y,z) = like x y, like x z +let with_same_length { tree = _; comparator; length } tree = + { tree; comparator; length } + +let of_tree ~comparator tree = { tree; comparator; length = Tree0.length tree} + +(* Exposing this function would make it very easy for the invariants + of this module to be broken. 
*) +let of_tree_unsafe ~comparator ~length tree = {tree; comparator; length} + +module Accessors = struct + let comparator t = t.comparator + let to_tree t = t.tree + let invariants t = Tree0.invariants t.tree ~compare_key:(compare_key t) + let is_empty t = Tree0.is_empty t.tree + let length t = t.length + let set t ~key ~data = + like t (Tree0.set t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t)) + ;; + let add_exn t ~key ~data = + like t (Tree0.add_exn t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t) + ~sexp_of_key:t.comparator.sexp_of_t) + ;; + let add_exn_internal t ~key ~data = + like t (Tree0.add_exn_internal + t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t) + ~sexp_of_key:t.comparator.sexp_of_t) + ;; + let add t ~key ~data = + match add_exn_internal t ~key ~data with + | result -> `Ok result + | exception Duplicate -> `Duplicate + ;; + let add_multi t ~key ~data = + like t + (Tree0.add_multi t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t)) + ;; + let remove_multi t key = + like t (Tree0.remove_multi t.tree ~length:t.length key ~compare_key:(compare_key t)) + ;; + let find_multi t key = Tree0.find_multi t.tree key ~compare_key:(compare_key t) + let change t key ~f = + like t (Tree0.change t.tree key ~f ~length:t.length ~compare_key:(compare_key t)) + ;; + let update t key ~f = change t key ~f:(fun data -> Some (f data)) + let find_exn t key = Tree0.find_exn t.tree key ~compare_key:(compare_key t) + let find t key = Tree0.find t.tree key ~compare_key:(compare_key t) + let remove t key = + like t (Tree0.remove t.tree key ~length:t.length ~compare_key:(compare_key t)) + ;; + let mem t key = Tree0.mem t.tree key ~compare_key:(compare_key t) + ;; + let iter_keys t ~f = Tree0.iter_keys t.tree ~f + let iter t ~f = Tree0.iter t.tree ~f + let iteri t ~f = Tree0.iteri t.tree ~f + let iter2 t1 t2 ~f = Tree0.iter2 t1.tree t2.tree ~f ~compare_key:(compare_key t1) + ;; + let map t ~f = with_same_length t 
(Tree0.map t.tree ~f) + let mapi t ~f = with_same_length t (Tree0.mapi t.tree ~f) + let fold t ~init ~f = Tree0.fold t.tree ~f ~init + let fold_right t ~init ~f = Tree0.fold_right t.tree ~f ~init + let fold2 t1 t2 ~init ~f = + Tree0.fold2 t1.tree t2.tree ~init ~f ~compare_key:(compare_key t1) + ;; + let filter_keys t ~f = like t (Tree0.filter_keys t.tree ~f ~compare_key:(compare_key t)) + let filter t ~f = like t (Tree0.filter t.tree ~f ~compare_key:(compare_key t)) + let filteri t ~f = like t (Tree0.filteri t.tree ~f ~compare_key:(compare_key t)) + let filter_map t ~f = like t (Tree0.filter_map t.tree ~f ~compare_key:(compare_key t)) + let filter_mapi t ~f = like t (Tree0.filter_mapi t.tree ~f ~compare_key:(compare_key t)) + ;; + let partition_mapi t ~f = + like2 t (Tree0.partition_mapi t.tree ~f ~compare_key:(compare_key t)) + ;; + let partition_map t ~f = + like2 t (Tree0.partition_map t.tree ~f ~compare_key:(compare_key t)) + ;; + let partitioni_tf t ~f = + like2 t (Tree0.partitioni_tf t.tree ~f ~compare_key:(compare_key t)) + ;; + let partition_tf t ~f = + like2 t (Tree0.partition_tf t.tree ~f ~compare_key:(compare_key t)) + ;; + + let compare_direct compare_data t1 t2 = + Tree0.compare (compare_key t1) compare_data t1.tree t2.tree + ;; + let equal compare_data t1 t2 = + Tree0.equal (compare_key t1) compare_data t1.tree t2.tree + ;; + let keys t = Tree0.keys t.tree + let data t = Tree0.data t.tree + let to_alist ?key_order t = Tree0.to_alist ?key_order t.tree + let validate ~name f t = Validate.alist ~name f (to_alist t) + let symmetric_diff t1 t2 ~data_equal = + Tree0.symmetric_diff t1.tree t2.tree ~compare_key:(compare_key t1) ~data_equal + ;; + let merge t1 t2 ~f = + like t1 (Tree0.merge t1.tree t2.tree ~f ~compare_key:(compare_key t1)) + ;; + let min_elt t = Tree0.min_elt t.tree + let min_elt_exn t = Tree0.min_elt_exn t.tree + let max_elt t = Tree0.max_elt t.tree + let max_elt_exn t = Tree0.max_elt_exn t.tree + let for_all t ~f = Tree0.for_all t.tree ~f + 
let for_alli t ~f = Tree0.for_alli t.tree ~f + let exists t ~f = Tree0.exists t.tree ~f + let existsi t ~f = Tree0.existsi t.tree ~f + let count t ~f = Tree0.count t.tree ~f + let counti t ~f = Tree0.counti t.tree ~f + let split t k = + let l, maybe, r = Tree0.split t.tree k ~compare_key:(compare_key t) in + let comparator = comparator t in + (* Try to traverse the least amount possible to calculate the length, + using height as a heuristic. *) + let both_len = if Option.is_some maybe then t.length - 1 else t.length in + if Tree0.height l < Tree0.height r then + let l = of_tree l ~comparator in + l, maybe, of_tree_unsafe r ~comparator ~length:(both_len - length l) + else + let r = of_tree r ~comparator in + of_tree_unsafe l ~comparator ~length:(both_len - length r), maybe, r + ;; + let subrange t ~lower_bound ~upper_bound = + let left, mid, right = + Tree0.split_range t.tree ~lower_bound ~upper_bound ~compare_key:(compare_key t) + in + (* Try to traverse the least amount possible to calculate the length, + using height as a heuristic. 
*) + let outer_joined_height = + let h_l = Tree0.height left + and h_r = Tree0.height right in + if h_l = h_r then + h_l + 1 + else + max h_l h_r + in + if outer_joined_height < Tree0.height mid then + let mid_length = t.length - (Tree0.length left + Tree0.length right) in + of_tree_unsafe mid ~comparator:(comparator t) ~length:mid_length + else + of_tree mid ~comparator:(comparator t) + ;; + let append ~lower_part ~upper_part = + match Tree0.append ~compare_key:(compare_key lower_part) ~lower_part:lower_part.tree ~upper_part:upper_part.tree with + | `Ok tree -> `Ok (of_tree_unsafe tree ~comparator:(comparator lower_part) ~length:(lower_part.length + upper_part.length)) + | `Overlapping_key_ranges -> `Overlapping_key_ranges + ;; + let fold_range_inclusive t ~min ~max ~init ~f = + Tree0.fold_range_inclusive t.tree ~min ~max ~init ~f ~compare_key:(compare_key t) + ;; + let range_to_alist t ~min ~max = + Tree0.range_to_alist t.tree ~min ~max ~compare_key:(compare_key t) + ;; + let closest_key t dir key = Tree0.closest_key t.tree dir key ~compare_key:(compare_key t) + let nth t n = Tree0.nth t.tree n + let nth_exn t n = Option.value_exn (nth t n) + let rank t key = Tree0.rank t.tree key ~compare_key:(compare_key t) + let sexp_of_t sexp_of_k sexp_of_v _ t = Tree0.sexp_of_t sexp_of_k sexp_of_v t.tree + let to_sequence ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to t = + Tree0.to_sequence t.comparator ?order ?keys_greater_or_equal_to + ?keys_less_or_equal_to t.tree + + let hash_fold_direct hash_fold_key hash_fold_data state t = + Tree0.hash_fold_t_ignoring_structure hash_fold_key hash_fold_data state t.tree +end + +(* [0] is used as the [length] argument everywhere in this module, since trees do not + have their lengths stored at the root, unlike maps. The values are discarded always. 
*) +module Tree = struct + type ('k, 'v, 'comparator) t = ('k, 'v, 'comparator) tree + + let empty_without_value_restriction = Tree0.empty + let empty ~comparator:_ = empty_without_value_restriction + let of_tree ~comparator:_ tree = tree + let singleton ~comparator:_ k v = Tree0.singleton k v + let of_sorted_array_unchecked ~comparator array = + fst (Tree0.of_sorted_array_unchecked array ~compare_key:comparator.Comparator.compare) + ;; + let of_sorted_array ~comparator array = + Tree0.of_sorted_array array ~compare_key:comparator.Comparator.compare + |> Or_error.map ~f:fst + ;; + let of_alist ~comparator alist = + match Tree0.of_alist alist ~compare_key:comparator.Comparator.compare with + | `Duplicate_key _ as d -> d + | `Ok (tree, _size) -> `Ok tree + ;; + let of_alist_or_error ~comparator alist = + Tree0.of_alist_or_error alist ~comparator + |> Or_error.map ~f:fst + let of_alist_exn ~comparator alist = fst (Tree0.of_alist_exn alist ~comparator) + let of_alist_multi ~comparator alist = + fst (Tree0.of_alist_multi alist ~compare_key:comparator.Comparator.compare) + ;; + let of_alist_fold ~comparator alist ~init ~f = + fst (Tree0.of_alist_fold alist ~init ~f ~compare_key:comparator.Comparator.compare) + ;; + let of_alist_reduce ~comparator alist ~f = + fst (Tree0.of_alist_reduce alist ~f ~compare_key:comparator.Comparator.compare) + ;; + let of_iteri ~comparator ~iteri = + match Tree0.of_iteri ~iteri ~compare_key:comparator.Comparator.compare with + | `Ok (tree, _size) -> `Ok tree + | `Duplicate_key _ as d -> d + ;; + let of_increasing_iterator_unchecked ~comparator:_required_by_intf ~len ~f = + Tree0.of_increasing_iterator_unchecked ~len ~f + ;; + let of_increasing_sequence ~comparator seq = + Or_error.map ~f:fst (Tree0.of_increasing_sequence seq ~compare_key:comparator.Comparator.compare) + ;; + + let to_tree t = t + let invariants ~comparator t = + Tree0.invariants t ~compare_key:comparator.Comparator.compare + ;; + let is_empty t = Tree0.is_empty t + let 
length t = Tree0.length t + let set ~comparator t ~key ~data = + fst (Tree0.set t ~key ~data ~length:0 ~compare_key:comparator.Comparator.compare) + ;; + let add_exn ~comparator t ~key ~data = + fst (Tree0.add_exn t ~key ~data ~length:0 ~compare_key:comparator.Comparator.compare + ~sexp_of_key:comparator.sexp_of_t) + ;; + let add ~comparator t ~key ~data = (* Binds [key] to [data] when [key] is absent. Returns [`Duplicate] only when the internal insert raises [Duplicate], the same way Accessors.add does; any other exception, e.g. one raised by the comparator, now propagates instead of being misreported as a duplicate. [~length:0] is a dummy and the returned length is dropped by [fst]. *) + try + `Ok (fst (Tree0.add_exn_internal t ~key ~data ~length:0 ~compare_key:comparator.Comparator.compare ~sexp_of_key:comparator.sexp_of_t)) + with Duplicate -> + `Duplicate + ;; + let add_multi ~comparator t ~key ~data = + Tree0.add_multi t ~key ~data ~length:0 ~compare_key:comparator.Comparator.compare + |> fst + ;; + let remove_multi ~comparator t key = + Tree0.remove_multi t key ~length:0 ~compare_key:comparator.Comparator.compare + |> fst + ;; + let find_multi ~comparator t key = + Tree0.find_multi t key ~compare_key:comparator.Comparator.compare + ;; + let change ~comparator t key ~f = + fst (Tree0.change t key ~f ~length:0 ~compare_key:comparator.Comparator.compare) + ;; + let update ~comparator t key ~f = change ~comparator t key ~f:(fun data -> Some (f data)) + let find_exn ~comparator t key = + Tree0.find_exn t key ~compare_key:comparator.Comparator.compare + ;; + let find ~comparator t key = + Tree0.find t key ~compare_key:comparator.Comparator.compare + ;; + let remove ~comparator t key = + fst (Tree0.remove t key ~length:0 ~compare_key:comparator.Comparator.compare) + ;; + let mem ~comparator t key = Tree0.mem t key ~compare_key:comparator.Comparator.compare + ;; + let iter_keys t ~f = Tree0.iter_keys t ~f + let iter t ~f = Tree0.iter t ~f + let iteri t ~f = Tree0.iteri t ~f + let iter2 ~comparator t1 t2 ~f = + Tree0.iter2 t1 t2 ~f ~compare_key:comparator.Comparator.compare + ;; + let map t ~f = Tree0.map t ~f + let mapi t ~f = Tree0.mapi t ~f + let fold t ~init ~f = Tree0.fold t ~f ~init + let fold_right t ~init ~f = Tree0.fold_right t ~f ~init + let fold2 ~comparator t1 t2 ~init ~f = + Tree0.fold2 t1 t2 ~init ~f ~compare_key:comparator.Comparator.compare + ;; + let filter_keys ~comparator t ~f = + 
fst (Tree0.filter_keys t ~f ~compare_key:comparator.Comparator.compare) + let filter ~comparator t ~f = + fst (Tree0.filter t ~f ~compare_key:comparator.Comparator.compare) + let filteri ~comparator t ~f = + fst (Tree0.filteri t ~f ~compare_key:comparator.Comparator.compare) + let filter_map ~comparator t ~f = + fst (Tree0.filter_map t ~f ~compare_key:comparator.Comparator.compare) + let filter_mapi ~comparator t ~f = + fst (Tree0.filter_mapi t ~f ~compare_key:comparator.Comparator.compare) + ;; + let partition_mapi ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partition_mapi t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let partition_map ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partition_map t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let partitioni_tf ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partitioni_tf t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let partition_tf ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partition_tf t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let compare_direct ~comparator compare_data t1 t2 = + Tree0.compare comparator.Comparator.compare compare_data t1 t2 + ;; + let equal ~comparator compare_data t1 t2 = + Tree0.equal comparator.Comparator.compare compare_data t1 t2 + ;; + let keys t = Tree0.keys t + let data t = Tree0.data t + let to_alist ?key_order t = Tree0.to_alist ?key_order t + let validate ~name f t = Validate.alist ~name f (to_alist t) + let symmetric_diff ~comparator t1 t2 ~data_equal = + Tree0.symmetric_diff t1 t2 ~compare_key:comparator.Comparator.compare ~data_equal + ;; + let merge ~comparator t1 t2 ~f = + fst (Tree0.merge t1 t2 ~f ~compare_key:comparator.Comparator.compare) + ;; + let min_elt t = Tree0.min_elt t + let min_elt_exn t = Tree0.min_elt_exn t + let max_elt t = Tree0.max_elt t + let max_elt_exn t = Tree0.max_elt_exn t + let for_all t ~f = Tree0.for_all t ~f + let for_alli t ~f = 
Tree0.for_alli t ~f + let exists t ~f = Tree0.exists t ~f + let existsi t ~f = Tree0.existsi t ~f + let count t ~f = Tree0.count t ~f + let counti t ~f = Tree0.counti t ~f + let split ~comparator t k = Tree0.split t k ~compare_key:comparator.Comparator.compare + let append ~comparator ~lower_part ~upper_part = + Tree0.append ~lower_part ~upper_part ~compare_key:comparator.Comparator.compare + let subrange ~comparator t ~lower_bound ~upper_bound = + let (_, ret, _) = + Tree0.split_range t ~lower_bound ~upper_bound + ~compare_key:comparator.Comparator.compare in + ret + ;; + let fold_range_inclusive ~comparator t ~min ~max ~init ~f = + Tree0.fold_range_inclusive t ~min ~max ~init ~f + ~compare_key:comparator.Comparator.compare + ;; + let range_to_alist ~comparator t ~min ~max = + Tree0.range_to_alist t ~min ~max ~compare_key:comparator.Comparator.compare + ;; + let closest_key ~comparator t dir key = + Tree0.closest_key t dir key ~compare_key:comparator.Comparator.compare + ;; + let nth ~comparator:_ t n = Tree0.nth t n + let nth_exn ~comparator t n = Option.value_exn (nth ~comparator t n) + let rank ~comparator t key = Tree0.rank t key ~compare_key:comparator.Comparator.compare + let sexp_of_t sexp_of_k sexp_of_v _ t = Tree0.sexp_of_t sexp_of_k sexp_of_v t + let t_of_sexp_direct ~comparator k_of_sexp v_of_sexp sexp = + fst (Tree0.t_of_sexp_direct k_of_sexp v_of_sexp sexp ~comparator) + ;; + + let to_sequence ~comparator ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to t + = + Tree0.to_sequence comparator ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to + t +end + +module Using_comparator = struct + type nonrec ('k, 'v, 'cmp) t = ('k, 'v, 'cmp) t + + include Accessors + + let empty ~comparator = { tree = Tree0.empty; comparator; length = 0 } + + let singleton ~comparator k v = { comparator; tree = Tree0.singleton k v; length = 1 } + + let of_tree0 ~comparator (tree, length) = + { comparator; tree; length } + + let of_tree ~comparator tree = of_tree0 
~comparator (tree, Tree0.length tree) + let to_tree = to_tree + + let of_sorted_array_unchecked ~comparator array = + of_tree0 ~comparator + (Tree0.of_sorted_array_unchecked array ~compare_key:comparator.Comparator.compare) + ;; + + let of_sorted_array ~comparator array = + Or_error.map (Tree0.of_sorted_array array ~compare_key:comparator.Comparator.compare) + ~f:(fun tree -> of_tree0 ~comparator tree) + ;; + + let of_alist ~comparator alist = + match Tree0.of_alist alist ~compare_key:comparator.Comparator.compare with + | `Ok (tree, length) -> `Ok { comparator; tree; length } + | `Duplicate_key _ as z -> z + ;; + + let of_alist_or_error ~comparator alist = + Result.map (Tree0.of_alist_or_error alist ~comparator) + ~f:(fun tree -> of_tree0 ~comparator tree) + ;; + + let of_alist_exn ~comparator alist = + of_tree0 ~comparator (Tree0.of_alist_exn alist ~comparator) + ;; + + let of_alist_multi ~comparator alist = + of_tree0 ~comparator + (Tree0.of_alist_multi alist ~compare_key:comparator.Comparator.compare) + ;; + + let of_alist_fold ~comparator alist ~init ~f = + of_tree0 ~comparator + (Tree0.of_alist_fold alist ~init ~f ~compare_key:comparator.Comparator.compare) + ;; + + let of_alist_reduce ~comparator alist ~f = + of_tree0 ~comparator + (Tree0.of_alist_reduce alist ~f ~compare_key:comparator.Comparator.compare) + ;; + + let of_iteri ~comparator ~iteri = + match Tree0.of_iteri ~compare_key:comparator.Comparator.compare ~iteri with + | `Ok tree_length -> `Ok (of_tree0 ~comparator tree_length) + | `Duplicate_key _ as z -> z + ;; + + let of_increasing_iterator_unchecked ~comparator ~len ~f = + of_tree0 ~comparator (Tree0.of_increasing_iterator_unchecked ~len ~f, len) + + let of_increasing_sequence ~comparator seq = + Or_error.map ~f:(of_tree0 ~comparator) + (Tree0.of_increasing_sequence seq ~compare_key:comparator.Comparator.compare) + ;; + + let t_of_sexp_direct ~comparator k_of_sexp v_of_sexp sexp = + of_tree0 ~comparator (Tree0.t_of_sexp_direct k_of_sexp v_of_sexp 
sexp ~comparator) + ;; + + (* Functor producing an [empty] map record: empty tree, the comparator of [K], length 0. *) module Empty_without_value_restriction(K : Comparator.S1) = struct + let empty = { tree = Tree0.empty; comparator = K.comparator; length = 0 } + end + + module Tree = Tree +end + +include Accessors + (* A comparator packaged as a first-class module of type [Comparator.S], with its key type and witness type exposed. *) +type ('k, 'cmp) comparator = + (module Comparator.S with type t = 'k and type comparator_witness = 'cmp) + (* Packs the [comparator] field of the map [t] into a first-class module. *) +let comparator_s (type k cmp) t : (k, cmp) comparator = + (module struct + type t = k + type comparator_witness = cmp + let comparator = t.comparator + end) + (* Unpacks the first-class module and returns its [comparator] value. *) +let to_comparator (type k cmp) ((module M) : (k, cmp) comparator) = M.comparator + (* The constructors below take the comparator as a first-class module [m], unpack it with [to_comparator] and forward to the same-named function in [Using_comparator]. *) +let empty m = Using_comparator.empty ~comparator:(to_comparator m) +let singleton m a = Using_comparator.singleton ~comparator:(to_comparator m) a +let of_alist m a = Using_comparator.of_alist ~comparator:(to_comparator m) a +let of_alist_or_error m a = Using_comparator.of_alist_or_error ~comparator:(to_comparator m) a +let of_alist_exn m a = Using_comparator.of_alist_exn ~comparator:(to_comparator m) a +let of_alist_multi m a = Using_comparator.of_alist_multi ~comparator:(to_comparator m) a +let of_alist_fold m a ~init ~f = Using_comparator.of_alist_fold ~comparator:(to_comparator m) a ~init ~f +let of_alist_reduce m a ~f = Using_comparator.of_alist_reduce ~comparator:(to_comparator m) a ~f +let of_sorted_array_unchecked m a = Using_comparator.of_sorted_array_unchecked ~comparator:(to_comparator m) a +let of_sorted_array m a = Using_comparator.of_sorted_array ~comparator:(to_comparator m) a +let of_iteri m ~iteri = Using_comparator.of_iteri ~iteri ~comparator:(to_comparator m) +let of_increasing_iterator_unchecked m ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~len ~f ~comparator:(to_comparator m) +let of_increasing_sequence m seq = Using_comparator.of_increasing_sequence ~comparator:(to_comparator m) seq + (* [M(K).t] is the map type with the key type and comparator witness fixed to those of [K], leaving only the value type as a parameter. *) +module M(K : sig type t type comparator_witness end) = struct + type nonrec 'v t = (K.t, 'v, K.comparator_witness) t +end +module type Sexp_of_m = sig type t
[@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t +end + (* [Compare_m] is an empty signature; [Hash_fold_m] is an alias for [Hasher.S]. *) +module type Compare_m = sig end +module type Hash_fold_m = Hasher.S + (* Calls [sexp_of_t] with the key converter of [K], the given value converter, and a third converter that ignores its argument and returns a fixed atom. *) +let sexp_of_m__t (type k) (module K : Sexp_of_m with type t = k) sexp_of_v t = + sexp_of_t K.sexp_of_t sexp_of_v (fun _ -> Sexp.Atom "_") t + (* Parses a map using both the comparator and the key parser supplied by [K]. *) +let m__t_of_sexp (type k cmp) + (module K : M_of_sexp with type t = k and type comparator_witness = cmp) + v_of_sexp sexp = + Using_comparator.t_of_sexp_direct ~comparator:K.comparator K.t_of_sexp v_of_sexp sexp + (* Nothing from [K] is used in the body: [Compare_m] has no members and the call goes straight to [compare_direct]. *) +let compare_m__t (module K : Compare_m) compare_v t1 t2 = + compare_direct compare_v t1 t2 + +let hash_fold_m__t (type k) (module K : Hash_fold_m with type t = k) hash_fold_v state = + hash_fold_direct K.hash_fold_t hash_fold_v state + (* Folds over whichever map has the smaller [length], applying [change] to the other one. When the two maps are swapped, [combine] is wrapped with its value arguments flipped, so the caller-supplied [combine] still receives the value from the original [t1] first. Keys present in only one map keep their value. *) +let merge_skewed t1 t2 ~combine = + let t1, t2, combine = + if length t2 <= length t1 + then t1, t2, combine + else t2, t1, (fun ~key v1 v2 -> combine ~key v2 v1) + in + fold t2 ~init:t1 ~f:(fun ~key ~data:v2 t1 -> + change t1 key ~f:(function + | None -> Some v2 + | Some v1 -> Some (combine ~key v1 v2))) + (* Map API specialised to [Comparator.Poly.comparator]; the constructors forward to [Using_comparator] with that comparator. *) +module Poly = struct + type nonrec ('k, 'v) t = ('k, 'v, Comparator.Poly.comparator_witness) t + type nonrec ('k, 'v) tree = ('k, 'v) Tree0.t + + include Accessors + + let comparator = Comparator.Poly.comparator + + (* Wraps a bare tree, computing the stored length with [Tree0.length]. *) let of_tree tree = { tree; comparator; length = Tree0.length tree} + + include Using_comparator.Empty_without_value_restriction(Comparator.Poly) + + let singleton a = Using_comparator.singleton ~comparator a + let of_alist a = Using_comparator.of_alist ~comparator a + let of_alist_or_error a = Using_comparator.of_alist_or_error ~comparator a + let of_alist_exn a = Using_comparator.of_alist_exn ~comparator a
+ let of_alist_multi a = Using_comparator.of_alist_multi ~comparator a + let of_alist_fold a ~init ~f = Using_comparator.of_alist_fold ~comparator a ~init ~f + let of_alist_reduce a ~f = Using_comparator.of_alist_reduce ~comparator a ~f + let of_sorted_array_unchecked a = Using_comparator.of_sorted_array_unchecked ~comparator a + let of_sorted_array a = Using_comparator.of_sorted_array ~comparator a + let of_iteri ~iteri = Using_comparator.of_iteri ~iteri ~comparator + let of_increasing_iterator_unchecked ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~len ~f ~comparator + let of_increasing_sequence seq = Using_comparator.of_increasing_sequence ~comparator seq +end diff --git a/src/map.mli b/src/map.mli new file mode 100644 index 0000000..2bee42d --- /dev/null +++ b/src/map.mli @@ -0,0 +1 @@ +include Map_intf.Map (** @inline *) diff --git a/src/map_intf.ml b/src/map_intf.ml new file mode 100644 index 0000000..58ebd99 --- /dev/null +++ b/src/map_intf.ml @@ -0,0 +1,1936 @@ +open! Import +open! T + (* Either [`Ok] carrying a payload or the bare tag [`Duplicate]; only a sexp serializer is derived. *) +module Or_duplicate = struct + type 'a t = [ `Ok of 'a | `Duplicate ] + [@@deriving_inline sexp_of] + let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | `Ok v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Ok"; _of_a v0] + | `Duplicate -> Ppx_sexp_conv_lib.Sexp.Atom "Duplicate" + [@@@end] +end + (* The next three modules each define a three-parameter type [t] describing how a signature z receives its comparator: [Without_comparator] leaves z unchanged, [With_comparator] prefixes a labelled [~comparator] argument, and [With_first_class_module] prefixes a first-class [Comparator.S] module argument. *) +module Without_comparator = struct + type ('key, 'cmp, 'z) t = 'z +end + +module With_comparator = struct + type ('key, 'cmp, 'z) t = comparator:('key, 'cmp) Comparator.t -> 'z +end + +module With_first_class_module = struct + type ('key, 'cmp, 'z) t = + (module Comparator.S with type t = 'key and type comparator_witness = 'cmp) -> 'z +end + (* A key paired with a tag: one value under [`Left] or [`Right], or a pair of values under [`Unequal]. *) +module Symmetric_diff_element = struct + type ('k, 'v) t = 'k * [ `Left of 'v | `Right of 'v | `Unequal of 'v * 'v ] + [@@deriving_inline compare, sexp] + let compare : + 'k 'v .
+ ('k -> 'k -> int) -> ('v -> 'v -> int) -> ('k, 'v) t -> ('k, 'v) t -> int + = + (* NOTE(review): this code sits between [@@deriving_inline] and [@@@end] markers, so it is presumably machine-generated - confirm before hand-editing. Keys are compared first; only when they compare equal are the tagged values compared, with a physical-equality shortcut, payload comparison for matching tags, and polymorphic compare for differing tags. *) fun _cmp__k -> + fun _cmp__v -> + fun a__001_ -> + fun b__002_ -> + let (t__003_, t__004_) = a__001_ in + let (t__005_, t__006_) = b__002_ in + match _cmp__k t__003_ t__005_ with + | 0 -> + if Ppx_compare_lib.phys_equal t__004_ t__006_ + then 0 + else + (match (t__004_, t__006_) with + | (`Left _left__007_, `Left _right__008_) -> + _cmp__v _left__007_ _right__008_ + | (`Right _left__009_, `Right _right__010_) -> + _cmp__v _left__009_ _right__010_ + | (`Unequal _left__011_, `Unequal _right__012_) -> + let (t__013_, t__014_) = _left__011_ in + let (t__015_, t__016_) = _right__012_ in + (match _cmp__v t__013_ t__015_ with + | 0 -> _cmp__v t__014_ t__016_ + | n -> n) + | (x, y) -> Ppx_compare_lib.polymorphic_compare x y) + | n -> n + (* Accepts a two-element sexp list: the key, then a tagged value. A bare atom naming one of the tags is rejected because every tag takes an argument. *) let t_of_sexp : + 'k 'v . + (Ppx_sexp_conv_lib.Sexp.t -> 'k) -> + (Ppx_sexp_conv_lib.Sexp.t -> 'v) -> + Ppx_sexp_conv_lib.Sexp.t -> ('k, 'v) t + = + let _tp_loc = "src/map_intf.ml.Symmetric_diff_element.t" in + fun _of_k -> + fun _of_v -> + function + | Ppx_sexp_conv_lib.Sexp.List (v0::v1::[]) -> + let v0 = _of_k v0 + and v1 = + (fun sexp -> + try + match sexp with + | Ppx_sexp_conv_lib.Sexp.Atom atom as _sexp -> + (match atom with + | "Left" -> + Ppx_sexp_conv_lib.Conv_error.ptag_takes_args + _tp_loc _sexp + | "Right" -> + Ppx_sexp_conv_lib.Conv_error.ptag_takes_args + _tp_loc _sexp + | "Unequal" -> + Ppx_sexp_conv_lib.Conv_error.ptag_takes_args + _tp_loc _sexp + | _ -> Ppx_sexp_conv_lib.Conv_error.no_variant_match ()) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + atom)::sexp_args) as _sexp -> + (match atom with + | "Left" as _tag -> + (match sexp_args with + | v0::[] -> let v0 = _of_v v0 in `Left v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.ptag_incorrect_n_args + _tp_loc _tag _sexp) + | "Right" as _tag -> + (match sexp_args with + | v0::[] -> let v0 = _of_v v0 in `Right v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.ptag_incorrect_n_args +
_tp_loc _tag _sexp) + (* The Unequal tag takes a single argument that must itself be a two-element list holding the two values. *) | "Unequal" as _tag -> + (match sexp_args with + | v0::[] -> + let v0 = + match v0 with + | Ppx_sexp_conv_lib.Sexp.List (v0::v1::[]) + -> + let v0 = _of_v v0 + and v1 = _of_v v1 in (v0, v1) + | sexp -> + Ppx_sexp_conv_lib.Conv_error.tuple_of_size_n_expected + _tp_loc 2 sexp in + `Unequal v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.ptag_incorrect_n_args + _tp_loc _tag _sexp) + | _ -> Ppx_sexp_conv_lib.Conv_error.no_variant_match ()) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List + _)::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_poly_var + _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_poly_var + _tp_loc sexp + with + | Ppx_sexp_conv_lib.Conv_error.No_variant_match -> + Ppx_sexp_conv_lib.Conv_error.no_matching_variant_found + _tp_loc sexp) v1 in + (v0, v1) + | sexp -> + Ppx_sexp_conv_lib.Conv_error.tuple_of_size_n_expected _tp_loc 2 + sexp + (* Emits a two-element list: the converted key followed by a list made of the tag atom and its payload; the [`Unequal] payload is written as a nested two-element list. *) let sexp_of_t : + 'k 'v . + ('k -> Ppx_sexp_conv_lib.Sexp.t) -> + ('v -> Ppx_sexp_conv_lib.Sexp.t) -> + ('k, 'v) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_k -> + fun _of_v -> + function + | (v0, v1) -> + let v0 = _of_k v0 + and v1 = + match v1 with + | `Left v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Left"; _of_v v0] + | `Right v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Right"; _of_v v0] + | `Unequal v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Unequal"; + (let (v0, v1) = v0 in + let v0 = _of_v v0 + and v1 = _of_v v1 in Ppx_sexp_conv_lib.Sexp.List [v0; v1])] in + Ppx_sexp_conv_lib.Sexp.List [v0; v1] + [@@@end] +end + (* Accessor signature written against four abstract types: the map [t], the [tree], a [key] wrapper applied to every key position, and [options], which wraps many of the signatures below. The [Check_accessors] functor later in this file requires its argument to match this signature. *) +module type Accessors_generic = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + type 'a key + type ('a, 'cmp, 'z) options + + val invariants + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> bool + ) options + + val is_empty : (_, _, _) t -> bool + + val length : (_, _, _) t -> int + + val add + : ('k, 'cmp, + ('k, 'v, 'cmp) t ->
key:'k key -> data:'v -> ('k, 'v, 'cmp) t Or_duplicate.t + ) options + + (* [add_exn] and [set] have the same argument types as [add] but return the map directly instead of an [Or_duplicate.t]. *) val add_exn + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> key:'k key -> data:'v -> ('k, 'v, 'cmp) t + ) options + + val set + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> key:'k key -> data:'v -> ('k, 'v, 'cmp) t + ) options + + (* The three [*_multi] functions operate on maps whose values are lists. *) val add_multi + : ('k, 'cmp, + ('k, 'v list, 'cmp) t + -> key:'k key + -> data:'v + -> ('k, 'v list, 'cmp) t + ) options + + val remove_multi + : ('k, 'cmp, + ('k, 'v list, 'cmp) t + -> 'k key + -> ('k, 'v list, 'cmp) t + ) options + + val find_multi + : ('k, 'cmp, + ('k, 'v list, 'cmp) t + -> 'k key -> 'v list + ) options + + (* In [change] the callback [f] maps an option to an option; in [update] it maps an option to a plain value. *) val change + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> 'k key + -> f:('v option -> 'v option) + -> ('k, 'v, 'cmp) t + ) options + + val update + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> 'k key + -> f:('v option -> 'v) + -> ('k, 'v, 'cmp) t + ) options + + val find : ('k, 'cmp, ('k, 'v, 'cmp) t -> 'k key -> 'v option) options + val find_exn : ('k, 'cmp, ('k, 'v, 'cmp) t -> 'k key -> 'v ) options + + val remove + : ('k, 'cmp, ('k, 'v, 'cmp) t -> 'k key -> ('k, 'v, 'cmp) t + ) options + + val mem : ('k, 'cmp, ('k, _, 'cmp) t -> 'k key -> bool) options + + (* The single-map iterators, maps and folds below are not wrapped in [options]; the two-map forms [iter2] and [fold2] are. *) val iter_keys : ('k, _, _) t -> f:('k key -> unit) -> unit + val iter : ( _, 'v, _) t -> f:('v -> unit) -> unit + val iteri : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> unit) -> unit + + val iter2 + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k key + -> data:[ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> unit) + -> unit + ) options + + val map : ('k, 'v1, 'cmp) t -> f:('v1 -> 'v2) -> ('k, 'v2, 'cmp) t + + val mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k key -> data:'v1 -> 'v2) + -> ('k, 'v2, 'cmp) t + + val fold : ('k, 'v, _) t -> init:'a -> f:(key:'k key -> data:'v -> 'a -> 'a) -> 'a + val fold_right : ('k, 'v, _) t -> init:'a -> f:(key:'k key -> data:'v -> 'a -> 'a) -> 'a + + val fold2 + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> init:'a + -> f:(key:'k key + -> data:[
`Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> 'a + -> 'a) + -> 'a + ) options + + (* Every filter and partition signature below is wrapped in [options]. *) val filter_keys + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:('k key -> bool) + -> ('k, 'v, 'cmp) t + ) options + + val filter + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:('v -> bool) + -> ('k, 'v, 'cmp) t + ) options + + val filteri + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:(key:'k key -> data:'v -> bool) + -> ('k, 'v, 'cmp) t + ) options + + val filter_map + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:('v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + ) options + + val filter_mapi + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:(key:'k key -> data:'v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + ) options + + (* In [partition_mapi] and [partition_map] the callback returns a [`Fst] or [`Snd] tag and the result is a pair of maps, one per payload type; the [*_tf] forms take a boolean callback and return two maps of the input value type. *) val partition_mapi + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:(key:'k key -> data:'v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + ) options + + val partition_map + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:('v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + ) options + + val partitioni_tf + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:(key:'k key -> data:'v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + ) options + + val partition_tf + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:('v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + ) options + + (* [compare_direct] and [equal] take the value comparison or equality function as their first argument. *) val compare_direct + : ('k, 'cmp, + ('v -> 'v -> int) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> int + ) options + + val equal + : ('k, 'cmp, + ('v -> 'v -> bool) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> bool + ) options + + val keys : ('k, _, _) t -> 'k key list + + val data : (_, 'v, _) t -> 'v list + + val to_alist + : ?key_order:[`Increasing|`Decreasing] + -> ('k, 'v, _) t + -> ('k key * 'v) list + + val validate + : name:('k key -> string) + -> 'v Validate.check + -> ('k, 'v, _) t Validate.check + + (* Unlike [iter2] and [fold2], the callback of [merge] receives the tagged value as an unlabelled argument. *) val merge + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k key + -> [ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> 'v3 option) + -> ('k, 'v3, 'cmp)
t + ) options + + (* Takes two maps of the same type plus a caller-supplied value equality and returns a [Sequence.t] of [Symmetric_diff_element.t]. *) val symmetric_diff + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> data_equal:('v -> 'v -> bool) + -> ('k key, 'v) Symmetric_diff_element.t Sequence.t + ) options + + (* [min_elt] and [max_elt] return an optional key-value pair; the [_exn] forms return the pair directly. *) val min_elt : ('k, 'v, _) t -> ('k key * 'v) option + val min_elt_exn : ('k, 'v, _) t -> 'k key * 'v + + val max_elt : ('k, 'v, _) t -> ('k key * 'v) option + val max_elt_exn : ('k, 'v, _) t -> 'k key * 'v + + val for_all : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val for_alli : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> bool) -> bool + val exists : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val existsi : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> bool) -> bool + val count : ('k, 'v, _) t -> f:( 'v -> bool) -> int + val counti : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> bool) -> int + + (* Returns a triple: a map, an optional key-value pair, and another map. *) val split + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> 'k key + -> ('k, 'v, 'cmp) t * ('k key * 'v) option * ('k, 'v, 'cmp) t + ) options + + (* The result is either [`Ok] carrying a map or the bare tag [`Overlapping_key_ranges]. *) val append + : ('k, 'cmp, + lower_part:('k, 'v, 'cmp) t + -> upper_part:('k, 'v, 'cmp) t + -> [ `Ok of ('k, 'v, 'cmp) t | `Overlapping_key_ranges ] + ) options + + val subrange + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> lower_bound:'k key Maybe_bound.t + -> upper_bound:'k key Maybe_bound.t + -> ('k, 'v, 'cmp) t + ) options + + val fold_range_inclusive + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> min:'k key + -> max:'k key + -> init:'a + -> f:(key:'k key -> data:'v -> 'a -> 'a) + -> 'a + ) options + + val range_to_alist + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> min:'k key -> max:'k key -> ('k key * 'v) list + ) options + + (* The second argument selects one of four relations between the returned key and the probe key. *) val closest_key + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'k key -> ('k key * 'v) option + ) options + + val nth + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> int -> ('k key * 'v) option + ) options + + val nth_exn + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> int -> ('k key * 'v) + ) options + + val rank + : ('k, 'cmp, + ('k, _, 'cmp) t -> 'k key -> int option + ) options + +
val to_tree : ('k, 'v, 'cmp) t -> ('k key, 'v, 'cmp) tree + + (* All three arguments before the map are optional: an ordering tag and two key bounds. *) val to_sequence + : ('k, 'cmp, + ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'k key + -> ?keys_less_or_equal_to:'k key + -> ('k, 'v, 'cmp) t + -> ('k key * 'v) Sequence.t + ) options + +end + (* Accessor signature for a map type with a single type parameter (the value type); keys have the fixed abstract type [key] and no function takes a comparator argument. *) +module type Accessors1 = sig + type 'a t + type 'a tree + type key + val invariants : _ t -> bool + val is_empty : _ t -> bool + val length : _ t -> int + val add : 'a t -> key:key -> data:'a -> 'a t Or_duplicate.t + val add_exn : 'a t -> key:key -> data:'a -> 'a t + val set : 'a t -> key:key -> data:'a -> 'a t + val add_multi : 'a list t -> key:key -> data:'a -> 'a list t + val remove_multi : 'a list t -> key -> 'a list t + val find_multi : 'a list t -> key -> 'a list + val change : 'a t -> key -> f:('a option -> 'a option) -> 'a t + val update : 'a t -> key -> f:('a option -> 'a) -> 'a t + val find : 'a t -> key -> 'a option + val find_exn : 'a t -> key -> 'a + val remove : 'a t -> key -> 'a t + val mem : _ t -> key -> bool + + val iter_keys : _ t -> f:(key -> unit) -> unit + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit + val iter2 + : 'a t + -> 'b t + -> f:(key:key -> data:[ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> unit) + -> unit + val map : 'a t -> f:('a -> 'b) -> 'b t + val mapi : 'a t -> f:(key:key -> data:'a -> 'b) -> 'b t + val fold : 'a t -> init:'b -> f:(key:key -> data:'a -> 'b -> 'b) -> 'b + val fold_right : 'a t -> init:'b -> f:(key:key -> data:'a -> 'b -> 'b) -> 'b + val fold2 + : 'a t + -> 'b t + -> init:'c + -> f:(key:key -> data:[ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> 'c -> 'c) + -> 'c + + val filter_keys : 'a t -> f:(key -> bool) -> 'a t + val filter : 'a t -> f:('a -> bool) -> 'a t + val filteri : 'a t -> f:(key:key -> data:'a -> bool) -> 'a t + val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t + val
partition_mapi + : 'a t + -> f:(key:key -> data:'a -> [`Fst of 'b | `Snd of 'c]) + -> 'b t * 'c t + val partition_map + : 'a t + -> f:('a -> [`Fst of 'b | `Snd of 'c]) + -> 'b t * 'c t + val partitioni_tf + : 'a t + -> f:(key:key -> data:'a -> bool) + -> 'a t * 'a t + val partition_tf + : 'a t + -> f:('a -> bool) + -> 'a t * 'a t + (* The value comparison or equality function comes first, followed by the two maps. *) val compare_direct : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool)-> 'a t -> 'a t -> bool + val keys : _ t -> key list + val data : 'a t -> 'a list + val to_alist : ?key_order:[`Increasing|`Decreasing] -> 'a t -> (key * 'a) list + val validate : name:(key -> string) -> 'a Validate.check -> 'a t Validate.check + (* The callback receives the key and a [`Left], [`Right] or [`Both] tag and returns an option. *) val merge + : 'a t + -> 'b t + -> f:(key:key -> [ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> 'c option) + -> 'c t + val symmetric_diff + : 'a t + -> 'a t + -> data_equal:('a -> 'a -> bool) + -> (key, 'a) Symmetric_diff_element.t Sequence.t + val min_elt : 'a t -> (key * 'a) option + val min_elt_exn : 'a t -> key * 'a + val max_elt : 'a t -> (key * 'a) option + val max_elt_exn : 'a t -> key * 'a + val for_all : 'a t -> f:( 'a -> bool) -> bool + val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val exists : 'a t -> f:( 'a -> bool) -> bool + val existsi : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val count : 'a t -> f:( 'a -> bool) -> int + val counti : 'a t -> f:(key:key -> data:'a -> bool) -> int + val split : 'a t -> key -> 'a t * (key * 'a) option * 'a t + (* Both maps are passed by label; the result is [`Ok] with a map or [`Overlapping_key_ranges]. *) val append + : lower_part:'a t + -> upper_part:'a t + -> [ `Ok of 'a t | `Overlapping_key_ranges ] + val subrange + : 'a t + -> lower_bound:key Maybe_bound.t + -> upper_bound:key Maybe_bound.t + -> 'a t + val fold_range_inclusive + : 'a t + -> min:key + -> max:key + -> init:'b + -> f:(key:key -> data:'a -> 'b -> 'b) + -> 'b + val range_to_alist : 'a t -> min:key -> max:key -> (key * 'a) list + val closest_key : 'a t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> key -> (key *
'a) option + val nth : 'a t -> int -> (key * 'a) option + val nth_exn : 'a t -> int -> (key * 'a) + val rank : _ t -> key -> int option + val to_tree : 'a t -> 'a tree + val to_sequence + : ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:key + -> ?keys_less_or_equal_to:key + -> 'a t + -> (key * 'a) Sequence.t +end + (* Accessor signature for a map type with two type parameters, key then value; no function takes a comparator argument. *) +module type Accessors2 = sig + type ('a, 'b) t + type ('a, 'b) tree + val invariants : (_, _) t -> bool + val is_empty : (_, _) t -> bool + val length : (_, _) t -> int + val add : ('a, 'b) t -> key:'a -> data:'b -> ('a, 'b) t Or_duplicate.t + val add_exn : ('a, 'b) t -> key:'a -> data:'b -> ('a, 'b) t + val set : ('a, 'b) t -> key:'a -> data:'b -> ('a, 'b) t + val add_multi : ('a, 'b list) t -> key:'a -> data:'b -> ('a, 'b list) t + val remove_multi : ('a, 'b list) t -> 'a -> ('a, 'b list) t + val find_multi : ('a, 'b list) t -> 'a -> 'b list + val change : ('a, 'b) t -> 'a -> f:('b option -> 'b option) -> ('a, 'b) t + val update : ('a, 'b) t -> 'a -> f:('b option -> 'b) -> ('a, 'b) t + val find : ('a, 'b) t -> 'a -> 'b option + val find_exn : ('a, 'b) t -> 'a -> 'b + val remove : ('a, 'b) t -> 'a -> ('a, 'b) t + val mem : ('a, 'b) t -> 'a -> bool + + val iter_keys : ('a, _) t -> f:('a -> unit) -> unit + val iter : ( _, 'b) t -> f:('b -> unit) -> unit + val iteri : ('a, 'b) t -> f:(key:'a -> data:'b -> unit) -> unit + (* The two maps share a key type but may differ in value type. *) val iter2 + : ('a, 'b) t + -> ('a, 'c) t + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> unit) + -> unit + val map : ('a, 'b) t -> f:('b -> 'c) -> ('a, 'c) t + val mapi : ('a, 'b) t -> f:(key:'a -> data:'b -> 'c) -> ('a, 'c) t + val fold : ('a, 'b) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold_right : ('a, 'b) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold2 + : ('a, 'b) t + -> ('a, 'c) t + -> init:'d + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd -> 'd) + -> 'd + + val filter_keys : ('a, 'b) t -> f:('a -> bool)
-> ('a, 'b) t + val filter : ('a, 'b) t -> f:('b -> bool) -> ('a, 'b) t + val filteri : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> ('a, 'b) t + val filter_map : ('a, 'b) t -> f:('b -> 'c option) -> ('a, 'c) t + val filter_mapi : ('a, 'b) t -> f:(key:'a -> data:'b -> 'c option) -> ('a, 'c) t + (* The two [partition_map*] forms return maps with the same key type and two possibly different value types. *) val partition_mapi + : ('a, 'b) t + -> f:(key:'a -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + val partition_map + : ('a, 'b) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + val partitioni_tf + : ('a, 'b) t + -> f:(key:'a -> data:'b -> bool) + -> ('a, 'b) t * ('a, 'b) t + val partition_tf + : ('a, 'b) t + -> f:('b -> bool) + -> ('a, 'b) t * ('a, 'b) t + val compare_direct : ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val equal : ('b -> 'b -> bool)-> ('a, 'b) t -> ('a, 'b) t -> bool + val keys : ('a, _) t -> 'a list + val data : (_, 'b) t -> 'b list + val to_alist : ?key_order:[`Increasing|`Decreasing] -> ('a, 'b) t -> ('a * 'b) list + val validate + : name:('a -> string) -> 'b Validate.check -> ('a, 'b) t Validate.check + (* The callback receives the key and a [`Left], [`Right] or [`Both] tag and returns an option; the result map has a third value type. *) val merge + : ('a, 'b) t + -> ('a, 'c) t + -> f:(key:'a -> [ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd option) + -> ('a, 'd) t + val symmetric_diff + : ('a, 'b) t + -> ('a, 'b) t + -> data_equal:('b -> 'b -> bool) + -> ('a, 'b) Symmetric_diff_element.t Sequence.t + val min_elt : ('a, 'b) t -> ('a * 'b) option + val min_elt_exn : ('a, 'b) t -> 'a * 'b + val max_elt : ('a, 'b) t -> ('a * 'b) option + val max_elt_exn : ('a, 'b) t -> 'a * 'b + val for_all : ( _, 'b) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> bool + val exists : ( _, 'b) t -> f:( 'b -> bool) -> bool + val existsi : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> bool + val count : ( _, 'b) t -> f:( 'b -> bool) -> int + val counti : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> int + val split : ('a, 'b) t -> 'a -> ('a, 'b) t * ('a * 'b) option * ('a, 'b) t + val append +
: lower_part:('a, 'b) t + -> upper_part:('a, 'b) t + -> [ `Ok of ('a, 'b) t | `Overlapping_key_ranges ] + (* Both bounds are given as [Maybe_bound.t] values over the key type. *) val subrange + : ('a, 'b) t + -> lower_bound:'a Maybe_bound.t + -> upper_bound:'a Maybe_bound.t + -> ('a, 'b) t + val fold_range_inclusive + : ('a, 'b) t -> min:'a -> max:'a -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val range_to_alist : ('a, 'b) t -> min:'a -> max:'a -> ('a * 'b) list + val closest_key : ('a, 'b) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'a -> ('a * 'b) option + val nth : ('a, 'b) t -> int -> ('a * 'b) option + val nth_exn : ('a, 'b) t -> int -> ('a * 'b) + val rank : ('a, _) t -> 'a -> int option + val to_tree : ('a, 'b) t -> ('a, 'b) tree + val to_sequence + : ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'a + -> ?keys_less_or_equal_to:'a + -> ('a, 'b) t + -> ('a * 'b) Sequence.t +end + (* Accessor signature for a map type with three type parameters: key, value and a third parameter named cmp. No function takes a comparator argument. *) +module type Accessors3 = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + val invariants : (_, _, _) t -> bool + val is_empty : (_, _, _) t -> bool + val length : (_, _, _) t -> int + val add : ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b , 'cmp) t Or_duplicate.t + val add_exn : ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b , 'cmp) t + val set : ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b , 'cmp) t + val add_multi : ('a, 'b list, 'cmp) t -> key:'a -> data:'b -> ('a, 'b list, 'cmp) t + val remove_multi : ('a, 'b list, 'cmp) t -> 'a -> ('a, 'b list, 'cmp) t + val find_multi : ('a, 'b list, 'cmp) t -> 'a -> 'b list + val change : ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b option) -> ('a, 'b, 'cmp) t + val update : ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b) -> ('a, 'b, 'cmp) t + val find : ('a, 'b, 'cmp) t -> 'a -> 'b option + val find_exn : ('a, 'b, 'cmp) t -> 'a -> 'b + val remove : ('a, 'b, 'cmp) t -> 'a -> ('a, 'b, 'cmp) t + val mem : ('a, 'b, 'cmp) t -> 'a -> bool + + val iter_keys : ('a, _, 'cmp) t -> f:('a -> unit) -> unit + val iter : (
_, 'b, 'cmp) t -> f:('b -> unit) -> unit + val iteri : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> unit) -> unit + (* Both maps must share the key type and the third type parameter; only the value types may differ. *) val iter2 + : ('a, 'b, 'cmp) t + -> ('a, 'c, 'cmp) t + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> unit) + -> unit + val map : ('a, 'b, 'cmp) t -> f:('b -> 'c) -> ('a, 'c, 'cmp) t + val mapi : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c) -> ('a, 'c, 'cmp) t + val fold : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold_right : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold2 + : ('a, 'b, 'cmp) t + -> ('a, 'c, 'cmp) t + -> init:'d + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd -> 'd) + -> 'd + + (* The filter functions keep the third type parameter of the input map in their result type. *) val filter_keys : ('a, 'b, 'cmp) t -> f:('a -> bool) -> ('a, 'b, 'cmp) t + val filter : ('a, 'b, 'cmp) t -> f:('b -> bool) -> ('a, 'b, 'cmp) t + val filteri : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> ('a, 'b, 'cmp) t + val filter_map : ('a, 'b, 'cmp) t -> f:('b -> 'c option) -> ('a, 'c, 'cmp) t + val filter_mapi + : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c option) -> ('a, 'c, 'cmp) t + val partition_mapi + : ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partition_map + : ('a, 'b, 'cmp) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partitioni_tf + : ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + val partition_tf + : ('a, 'b, 'cmp) t + -> f:('b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + val compare_direct : ('b -> 'b -> int) -> ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) t -> int + val equal : ('b -> 'b -> bool)-> ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) t -> bool + val keys : ('a, _, _) t -> 'a list + val data : (_, 'b, _) t -> 'b list + val to_alist : ?key_order:[`Increasing|`Decreasing] -> ('a, 'b, _) t -> ('a * 'b) list + val validate + : name:('a -> string) ->
'b Validate.check -> ('a, 'b, _) t Validate.check + val merge + : ('a, 'b, 'cmp) t -> ('a, 'c, 'cmp) t + -> f:(key:'a -> [ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd option) + -> ('a, 'd, 'cmp) t + val symmetric_diff + : ('a, 'b, 'cmp) t + -> ('a, 'b, 'cmp) t + -> data_equal:('b -> 'b -> bool) + -> ('a, 'b) Symmetric_diff_element.t Sequence.t + val min_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val min_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val max_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val max_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val for_all : ( _, 'b, _) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b, _) t -> f:(key:'a -> data:'b -> bool) -> bool + val exists : ( _, 'b, _) t -> f:( 'b -> bool) -> bool + val existsi : ('a, 'b, _) t -> f:(key:'a -> data:'b -> bool) -> bool + val count : ( _, 'b, _) t -> f:( 'b -> bool) -> int + val counti : ('a, 'b, _) t -> f:(key:'a -> data:'b -> bool) -> int + (* The next three signatures name their type variables k and v instead of a and b. *) val split + : ('k, 'v, 'cmp) t + -> 'k + -> ('k, 'v, 'cmp) t * ('k * 'v) option * ('k, 'v, 'cmp) t + val append + : lower_part:('k, 'v, 'cmp) t + -> upper_part:('k, 'v, 'cmp) t + -> [ `Ok of ('k, 'v, 'cmp) t | `Overlapping_key_ranges ] + val subrange + : ('k, 'v, 'cmp) t + -> lower_bound:'k Maybe_bound.t + -> upper_bound:'k Maybe_bound.t + -> ('k, 'v, 'cmp) t + (* From here to the end of the signature, every function except [to_tree] leaves the third type parameter of the map as a wildcard. *) val fold_range_inclusive + : ('a, 'b, _) t + -> min:'a + -> max:'a + -> init:'c + -> f:(key:'a -> data:'b -> 'c -> 'c) + -> 'c + val range_to_alist : ('a, 'b, _) t -> min:'a -> max:'a -> ('a * 'b) list + val closest_key : ('a, 'b, _) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'a -> ('a * 'b) option + val nth : ('a, 'b, _) t -> int -> ('a * 'b) option + val nth_exn : ('a, 'b, _) t -> int -> ('a * 'b) + val rank : ('a, _, _) t -> 'a -> int option + val to_tree : ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) tree + val to_sequence + : ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'a + -> ?keys_less_or_equal_to:'a + ->
('a, 'b, _) t + -> ('a * 'b) Sequence.t +end + (* Accessor signature for a three-parameter map type in which many functions take an explicit labelled [~comparator] argument of type [Comparator.t] before the map. *) +module type Accessors3_with_comparator = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + val invariants : comparator:('a, 'cmp) Comparator.t -> ('a, 'b, 'cmp) t -> bool + (* [is_empty] and [length] take no comparator. *) val is_empty : ('a, 'b, 'cmp) t -> bool + val length : ('a, 'b, 'cmp) t -> int + val add + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b, 'cmp) t Or_duplicate.t + val add_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b, 'cmp) t + val set + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b, 'cmp) t + val add_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b list, 'cmp) t -> key:'a -> data:'b -> ('a, 'b list, 'cmp) t + val remove_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b list, 'cmp) t -> 'a -> ('a, 'b list, 'cmp) t + val find_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b list, 'cmp) t -> 'a -> 'b list + val change + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b option) -> ('a, 'b, 'cmp) t + val update + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b) -> ('a, 'b, 'cmp) t + val find + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> 'b option + val find_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> 'b + val remove + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> ('a, 'b, 'cmp) t + val mem + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> bool + + (* The single-map iterators take no comparator; [iter2] does. *) val iter_keys : ('a, _, 'cmp) t -> f:('a -> unit) -> unit + val iter : ( _, 'b, 'cmp) t -> f:('b -> unit) -> unit + val iteri : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> unit) -> unit + val iter2 + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> ('a, 'c, 'cmp) t + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ]-> unit) + -> unit +
val map : ('a, 'b, 'cmp) t -> f:('b -> 'c) -> ('a, 'c, 'cmp) t + val mapi : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c) -> ('a, 'c, 'cmp) t + val fold : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold_right : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + (* Unlike [fold] and [fold_right], [fold2] takes a [~comparator]. *) val fold2 + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> ('a, 'c, 'cmp) t + -> init:'d + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c] -> 'd -> 'd) + -> 'd + + (* Every filter and partition function below takes a [~comparator] as its first argument. *) val filter_keys + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:('a -> bool) -> ('a, 'b, 'cmp) t + val filter + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:('b -> bool) -> ('a, 'b, 'cmp) t + val filteri + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> ('a, 'b, 'cmp) t + val filter_map + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:('b -> 'c option) -> ('a, 'c, 'cmp) t + val filter_mapi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c option) -> ('a, 'c, 'cmp) t + val partition_mapi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partition_map + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partitioni_tf + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + val partition_tf + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:('b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + (* The comparator comes first, then the value comparison function, then the two maps. *) val compare_direct + : comparator:('a, 'cmp) Comparator.t + -> ('b -> 'b -> int) + -> ('a, 'b, 'cmp) t + -> ('a, 'b, 'cmp) t + -> int + val equal + : comparator:('a, 'cmp) Comparator.t + -> ('b -> 'b -> bool) -> ('a, 'b, 'cmp) t
-> ('a, 'b, 'cmp) t -> bool + val keys : ('a, _, _) t -> 'a list + val data : (_ , 'b, _) t -> 'b list + val to_alist + : ?key_order:[`Increasing|`Decreasing] -> ('a, 'b, _) t -> ('a * 'b) list + val validate + : name:('a -> string) -> 'b Validate.check -> ('a, 'b, _) t Validate.check + val merge + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> ('a, 'c, 'cmp) t + -> f:(key:'a -> [ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd option) + -> ('a, 'd, 'cmp) t + val symmetric_diff + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> ('a, 'b, 'cmp) t + -> data_equal:('b -> 'b -> bool) + -> ('a, 'b) Symmetric_diff_element.t Sequence.t + val min_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val min_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val max_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val max_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val for_all : ('a, 'b, 'cmp) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> bool + val exists : ('a, 'b, 'cmp) t -> f:( 'b -> bool) -> bool + val existsi : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> bool + val count : ('a, 'b, 'cmp) t -> f:( 'b -> bool) -> int + val counti : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> int + val split + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> 'a + -> ('a, 'b, 'cmp) t * ('a * 'b) option * ('a, 'b, 'cmp) t + val append + : comparator:('a, 'cmp) Comparator.t + -> lower_part:('a, 'b, 'cmp) t + -> upper_part:('a, 'b, 'cmp) t + -> [ `Ok of ('a, 'b, 'cmp) t | `Overlapping_key_ranges ] + val subrange + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> lower_bound:'a Maybe_bound.t + -> upper_bound:'a Maybe_bound.t + -> ('a, 'b, 'cmp) t + val fold_range_inclusive + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> min:'a -> max:'a -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val range_to_alist + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 
'cmp) t -> min:'a -> max:'a -> ('a * 'b) list + val closest_key + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'a -> ('a * 'b) option + val nth + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> int -> ('a * 'b) option + val nth_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> int -> ('a * 'b) + val rank + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> int option + val to_tree : ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) tree + val to_sequence + : comparator:('a, 'cmp) Comparator.t + -> ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'a + -> ?keys_less_or_equal_to:'a + -> ('a, 'b, 'cmp) t + -> ('a * 'b) Sequence.t +end + +(** Consistency checks (same as in [Container]). *) +module Check_accessors (T : T3) (Tree : T3) (Key : T1) (Options : T3) + (M : Accessors_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b, 'c) t := ('a, 'b, 'c) T.t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + with type 'a key := 'a Key.t) += struct end + +module Check_accessors1 (M : Accessors1) = + Check_accessors + (struct type ('a, 'b, 'c) t = 'b M.t end) + (struct type ('a, 'b, 'c) t = 'b M.tree end) + (struct type 'a t = M.key end) + (Without_comparator) + (M) + +module Check_accessors2 (M : Accessors2) = + Check_accessors + (struct type ('a, 'b, 'c) t = ('a, 'b) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_accessors3 (M : Accessors3) = + Check_accessors + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.tree end) + (struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_accessors3_with_comparator (M : Accessors3_with_comparator) = + Check_accessors + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.t end) + (struct 
type ('a, 'b, 'c) t = ('a, 'b, 'c) M.tree end) + (struct type 'a t = 'a end) + (With_comparator) + (M) + +module type Creators_generic = sig + type ('k, 'v, 'cmp) t + type ('k, 'v, 'cmp) tree + type 'k key + type ('a, 'cmp, 'z) options + + val empty : ('k, 'cmp, ('k, _, 'cmp) t) options + + val singleton : ('k, 'cmp, 'k key -> 'v -> ('k, 'v, 'cmp) t) options + + val of_sorted_array + : ('k, 'cmp, ('k key * 'v) array -> ('k, 'v, 'cmp) t Or_error.t) options + + val of_sorted_array_unchecked + : ('k, 'cmp, ('k key * 'v) array -> ('k, 'v, 'cmp) t) options + + val of_increasing_iterator_unchecked + : ('k, 'cmp, len:int -> f:(int -> 'k key * 'v) -> ('k, 'v, 'cmp) t) options + + val of_increasing_sequence + : ('k, 'cmp, ('k key * 'v) Sequence.t -> ('k, 'v, 'cmp) t Or_error.t) options + + val of_alist + : ('k, + 'cmp, + ('k key * 'v) list -> [ `Ok of ('k, 'v, 'cmp) t | `Duplicate_key of 'k key ] + ) options + + val of_alist_or_error + : ('k, 'cmp, ('k key * 'v) list -> ('k, 'v, 'cmp) t Or_error.t) options + + val of_alist_exn : ('k, 'cmp, ('k key * 'v) list -> ('k, 'v, 'cmp) t) options + + val of_alist_multi : ('k, 'cmp, ('k key * 'v) list -> ('k, 'v list, 'cmp) t) options + + val of_alist_fold + : ('k, 'cmp, + ('k key * 'v1) list + -> init:'v2 + -> f:('v2 -> 'v1 -> 'v2) + -> ('k, 'v2, 'cmp) t + ) options + + val of_alist_reduce + : ('k, 'cmp, + ('k key * 'v) list + -> f:('v -> 'v -> 'v) + -> ('k, 'v, 'cmp) t + ) options + + val of_iteri + : ('k, 'cmp, + iteri:(f:(key:'k key + -> data:'v + -> unit) + -> unit) + -> [ `Ok of ('k, 'v, 'cmp) t + | `Duplicate_key of 'k key ] + ) options + + val of_tree + : ('k, 'cmp, + ('k key, 'v, 'cmp) tree -> ('k, 'v, 'cmp) t + ) options +end + +module type Creators1 = sig + type 'a t + type 'a tree + type key + val empty : _ t + val singleton : key -> 'a -> 'a t + val of_alist : (key * 'a) list -> [ `Ok of 'a t | `Duplicate_key of key ] + val of_alist_or_error : (key * 'a) list -> 'a t Or_error.t + val of_alist_exn : (key * 'a) list -> 'a t 
+ val of_alist_multi : (key * 'a) list -> 'a list t + val of_alist_fold : (key * 'a) list -> init:'b -> f:('b -> 'a -> 'b) -> 'b t + val of_alist_reduce : (key * 'a) list -> f:('a -> 'a -> 'a) -> 'a t + val of_sorted_array : (key * 'a) array -> 'a t Or_error.t + val of_sorted_array_unchecked : (key * 'a) array -> 'a t + val of_increasing_iterator_unchecked : len:int -> f:(int -> key * 'a) -> 'a t + val of_increasing_sequence : (key * 'a) Sequence.t -> 'a t Or_error.t + val of_iteri : iteri:(f:(key:key -> data:'v -> unit) -> unit) + -> [ `Ok of 'v t | `Duplicate_key of key ] + val of_tree : 'a tree -> 'a t +end + +module type Creators2 = sig + type ('a, 'b) t + type ('a, 'b) tree + val empty : (_, _) t + val singleton : 'a -> 'b -> ('a, 'b) t + val of_alist : ('a * 'b) list -> [ `Ok of ('a, 'b) t | `Duplicate_key of 'a ] + val of_alist_or_error : ('a * 'b) list -> ('a, 'b) t Or_error.t + val of_alist_exn : ('a * 'b) list -> ('a, 'b) t + val of_alist_multi : ('a * 'b) list -> ('a, 'b list) t + val of_alist_fold : ('a * 'b) list -> init:'c -> f:('c -> 'b -> 'c) -> ('a, 'c) t + val of_alist_reduce : ('a * 'b) list -> f:('b -> 'b -> 'b) -> ('a, 'b) t + val of_sorted_array : ('a * 'b) array -> ('a, 'b) t Or_error.t + val of_sorted_array_unchecked : ('a * 'b) array -> ('a, 'b) t + val of_increasing_iterator_unchecked : len:int -> f:(int -> 'a * 'b) -> ('a, 'b) t + val of_increasing_sequence : ('a * 'b) Sequence.t -> ('a, 'b) t Or_error.t + val of_iteri : iteri:(f:(key:'a -> data:'b -> unit) -> unit) + -> [ `Ok of ('a, 'b) t | `Duplicate_key of 'a ] + val of_tree : ('a, 'b) tree -> ('a, 'b) t +end + +module type Creators3_with_comparator = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + val empty : comparator:('a, 'cmp) Comparator.t -> ('a, _, 'cmp) t + val singleton : comparator:('a, 'cmp) Comparator.t -> 'a -> 'b -> ('a, 'b, 'cmp) t + val of_alist + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> [ `Ok of ('a, 'b, 'cmp) t | `Duplicate_key of 'a ] + 
val of_alist_or_error + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> ('a, 'b, 'cmp) t Or_error.t + val of_alist_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> ('a, 'b, 'cmp) t + val of_alist_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> ('a, 'b list, 'cmp) t + val of_alist_fold + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> init:'c -> f:('c -> 'b -> 'c) -> ('a, 'c, 'cmp) t + val of_alist_reduce + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> f:('b -> 'b -> 'b) -> ('a, 'b, 'cmp) t + val of_sorted_array + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) array -> ('a, 'b, 'cmp) t Or_error.t + val of_sorted_array_unchecked + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) array -> ('a, 'b, 'cmp) t + val of_increasing_iterator_unchecked + : comparator:('a, 'cmp) Comparator.t + -> len:int -> f:(int -> 'a * 'b) -> ('a, 'b, 'cmp) t + val of_increasing_sequence + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) Sequence.t -> ('a, 'b, 'cmp) t Or_error.t + val of_iteri + : comparator:('a, 'cmp) Comparator.t + -> iteri:(f:(key:'a -> data:'b -> unit) -> unit) + -> [ `Ok of ('a, 'b, 'cmp) t | `Duplicate_key of 'a ] + val of_tree + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) tree -> ('a, 'b, 'cmp) t +end + +module Check_creators (T : T3) (Tree : T3) (Key : T1) (Options : T3) + (M : Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b, 'c) t := ('a, 'b, 'c) T.t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + with type 'a key := 'a Key.t) += struct end + +module Check_creators1 (M : Creators1) = + Check_creators + (struct type ('a, 'b, 'c) t = 'b M.t end) + (struct type ('a, 'b, 'c) t = 'b M.tree end) + (struct type 'a t = M.key end) + (Without_comparator) + (M) + +module Check_creators2 (M : Creators2) = + Check_creators + (struct type ('a, 'b, 'c) t = ('a, 'b) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b) M.tree end) + 
(struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_creators3_with_comparator (M : Creators3_with_comparator) = + Check_creators + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.tree end) + (struct type 'a t = 'a end) + (With_comparator) + (M) + +module type Creators_and_accessors_generic = sig + include Creators_generic + include Accessors_generic + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) tree + with type 'a key := 'a key + with type ('a, 'b, 'c) options := ('a, 'b, 'c) options +end + +module type Creators_and_accessors1 = sig + include Creators1 + include Accessors1 + with type 'a t := 'a t + with type 'a tree := 'a tree + with type key := key +end + +module type Creators_and_accessors2 = sig + include Creators2 + include Accessors2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree +end + +module type Creators_and_accessors3_with_comparator = sig + include Creators3_with_comparator + include Accessors3_with_comparator + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) tree +end + +module type S_poly = Creators_and_accessors2 + +module type For_deriving = sig + type ('a, 'b, 'c) t + + module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t + end + module type Compare_m = sig end + module type Hash_fold_m = Hasher.S + + val sexp_of_m__t + : (module Sexp_of_m with type t = 'k) + -> ('v -> Sexp.t) + -> ('k, 'v, 'cmp) t + -> Sexp.t + + val m__t_of_sexp + : (module M_of_sexp with type t = 'k and type 
comparator_witness = 'cmp) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v, 'cmp) t + + val compare_m__t + : (module Compare_m) + -> ('v -> 'v -> int) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> int + + val hash_fold_m__t + : (module Hash_fold_m with type t = 'k) + -> (Hash.state -> 'v -> Hash.state) + -> (Hash.state -> ('k, 'v, _) t -> Hash.state) +end + +module type Map = sig + (** [Map] is a functional data structure (balanced binary tree) implementing finite maps + over a totally-ordered domain, called a "key". *) + + type ('key, +'value, 'cmp) t + + module Or_duplicate = Or_duplicate + + type ('k, 'cmp) comparator = + (module Comparator.S with type t = 'k and type comparator_witness = 'cmp) + + (** Test if the invariants of the internal AVL search tree hold. *) + val invariants : (_, _, _) t -> bool + + (** Returns a first-class module that can be used to build other map/set/etc. + with the same notion of comparison. *) + val comparator_s : ('a, _, 'cmp) t -> ('a, 'cmp) comparator + + val comparator : ('a, _, 'cmp) t -> ('a, 'cmp) Comparator.t + + (** The empty map. *) + val empty : ('a, 'cmp) comparator -> ('a, 'b, 'cmp) t + + (** A map with one (key, data) pair. *) + val singleton : ('a, 'cmp) comparator -> 'a -> 'b -> ('a, 'b, 'cmp) t + + (** Creates a map from an association list with unique keys. *) + val of_alist + : ('a, 'cmp) comparator + -> ('a * 'b) list + -> [ `Ok of ('a, 'b, 'cmp) t | `Duplicate_key of 'a ] + + (** Creates a map from an association list with unique keys, returning an error if + duplicate ['a] keys are found. *) + val of_alist_or_error + : ('a, 'cmp) comparator + -> ('a * 'b) list -> ('a, 'b, 'cmp) t Or_error.t + + (** Creates a map from an association list with unique keys, raising an exception if + duplicate ['a] keys are found. *) + val of_alist_exn + : ('a, 'cmp) comparator + -> ('a * 'b) list -> ('a, 'b, 'cmp) t + + (** Creates a map from an association list with possibly repeated keys. 
The values in + the map for a given key appear in the same order as they did in the association + list. *) + val of_alist_multi + : ('a, 'cmp) comparator + -> ('a * 'b) list -> ('a, 'b list, 'cmp) t + + (** Combines an association list into a map, folding together bound values with common + keys. *) + val of_alist_fold + : ('a, 'cmp) comparator + -> ('a * 'b) list -> init:'c -> f:('c -> 'b -> 'c) -> ('a, 'c, 'cmp) t + + (** Combines an association list into a map, reducing together bound values with common + keys. *) + val of_alist_reduce + : ('a, 'cmp) comparator + -> ('a * 'b) list -> f:('b -> 'b -> 'b) -> ('a, 'b, 'cmp) t + + (** [of_iteri ~iteri] behaves like [of_alist], except that instead of taking a concrete + data structure, it takes an iteration function. For instance, to convert a string table + into a map: [of_iteri (module String) ~iteri:(Hashtbl.iteri table)]. It is faster than + adding the elements one by one. *) + val of_iteri + : ('a, 'cmp) comparator + -> iteri:(f:(key:'a -> data:'b -> unit) -> unit) + -> [ `Ok of ('a, 'b, 'cmp) t + | `Duplicate_key of 'a ] + + (** Creates a map from a sorted array of key-data pairs. The input array must be sorted + (either in ascending or descending order), as given by the relevant comparator, and + must not contain duplicate keys. If either of these conditions does not hold, + an error is returned. *) + val of_sorted_array + : ('a, 'cmp) comparator + -> ('a * 'b) array -> ('a, 'b, 'cmp) t Or_error.t + + (** Like [of_sorted_array] except that it returns a map with broken invariants when an + [Error] would have been returned. *) + val of_sorted_array_unchecked + : ('a, 'cmp) comparator + -> ('a * 'b) array -> ('a, 'b, 'cmp) t + + (** [of_increasing_iterator_unchecked c ~len ~f] behaves like [of_sorted_array_unchecked c + (Array.init len ~f)], with the additional restriction that a decreasing order is not + supported. The advantage is not requiring you to allocate an intermediate array.
[f] + will be called with 0, 1, ... [len - 1], in order. *) + val of_increasing_iterator_unchecked + : ('a, 'cmp) comparator + -> len:int + -> f:(int -> 'a * 'b) + -> ('a, 'b, 'cmp) t + + (** [of_increasing_sequence c seq] behaves like [of_sorted_array c (Sequence.to_array + seq)], but does not allocate the intermediate array. + + The sequence will be folded over once, and the additional time complexity is {e O(n)}. + *) + val of_increasing_sequence + : ('k, 'cmp) comparator + -> ('k * 'v) Sequence.t + -> ('k, 'v, 'cmp) t Or_error.t + + (** Tests whether a map is empty. *) + val is_empty : (_, _, _) t -> bool + + (** [length map] returns the number of elements in [map]. O(1), but [Tree.length] is + O(n). *) + val length : (_, _, _) t -> int + + (** Returns a new map with the specified new binding; if the key was already bound, its + previous binding disappears. *) + val set : ('k, 'v, 'cmp) t -> key:'k -> data:'v -> ('k, 'v, 'cmp) t + + (** [add t ~key ~data] adds a new entry to [t] mapping [key] to [data] and returns [`Ok] + with the new map, or if [key] is already present in [t], returns [`Duplicate]. *) + val add : ('k, 'v, 'cmp) t -> key:'k -> data:'v -> ('k, 'v, 'cmp) t Or_duplicate.t + val add_exn : ('k, 'v, 'cmp) t -> key:'k -> data:'v -> ('k, 'v, 'cmp) t + + (** If [key] is not present then add a singleton list, otherwise, cons data onto the + head of the existing list. *) + val add_multi + : ('k, 'v list, 'cmp) t + -> key:'k + -> data:'v + -> ('k, 'v list, 'cmp) t + + (** If the key is present, then remove its head element; if the result is empty, remove + the key. *) + val remove_multi + : ('k, 'v list, 'cmp) t + -> 'k + -> ('k, 'v list, 'cmp) t + + (** Returns the value bound to the given key, or the empty list if there is none. 
*) + val find_multi + : ('k, 'v list, 'cmp) t + -> 'k + -> 'v list + + (** [change t key ~f] returns a new map [m] that is the same as [t] on all keys except + for [key], and whose value for [key] is defined by [f], i.e., [find m key = f (find + t key)]. *) + val change + : ('k, 'v, 'cmp) t + -> 'k + -> f:('v option -> 'v option) + -> ('k, 'v, 'cmp) t + + (** [update t key ~f] is [change t key ~f:(fun o -> Some (f o))]. *) + val update + : ('k, 'v, 'cmp) t + -> 'k + -> f:('v option -> 'v) + -> ('k, 'v, 'cmp) t + + (** Returns [Some value] bound to the given key, or [None] if none exists. *) + val find : ('k, 'v, 'cmp) t -> 'k -> 'v option + + (** Returns the value bound to the given key, raising [Caml.Not_found] or [Not_found_s] + if none exists. *) + val find_exn : ('k, 'v, 'cmp) t -> 'k -> 'v + + (** Returns a new map with any binding for the key in question removed. *) + val remove : ('k, 'v, 'cmp) t -> 'k -> ('k, 'v, 'cmp) t + + (** [mem map key] tests whether [map] contains a binding for [key]. *) + val mem : ('k, _, 'cmp) t -> 'k -> bool + + val iter_keys : ('k, _, _) t -> f:('k -> unit) -> unit + val iter : (_, 'v, _) t -> f:('v -> unit) -> unit + val iteri : ('k, 'v, _) t -> f:(key:'k -> data:'v -> unit) -> unit + + (** Iterates two maps side by side. The complexity of this function is O(M + N). If two + inputs are [[(0, a); (1, a)]] and [[(1, b); (2, b)]], [f] will be called with [[(0, + `Left a); (1, `Both (a, b)); (2, `Right b)]]. *) + val iter2 + : ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k -> data:[ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] -> unit) + -> unit + + (** Returns a new map with bound values replaced by [f] applied to the bound values.*) + val map : ('k, 'v1, 'cmp) t -> f:('v1 -> 'v2) -> ('k, 'v2, 'cmp) t + + (** Like [map], but the passed function takes both [key] and [data] as arguments.
*) + val mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k -> data:'v1 -> 'v2) + -> ('k, 'v2, 'cmp) t + + (** Folds over keys and data in the map in increasing order of [key]. *) + val fold : ('k, 'v, _) t -> init:'a -> f:(key:'k -> data:'v -> 'a -> 'a) -> 'a + + (** Folds over keys and data in the map in decreasing order of [key]. *) + val fold_right : ('k, 'v, _) t -> init:'a -> f:(key:'k -> data:'v -> 'a -> 'a) -> 'a + + (** Folds over two maps side by side, like [iter2]. *) + val fold2 + : ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> init:'a + -> f:(key:'k -> data:[ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] -> 'a -> 'a) + -> 'a + + (** [filter], [filteri], [filter_keys], [filter_map], and [filter_mapi] run in O(n * lg + n) time; they simply accumulate each key & data pair retained by [f] into a new map + using [add]. *) + val filter_keys : ('k, 'v, 'cmp) t -> f:('k -> bool) -> ('k, 'v, 'cmp) t + val filter : ('k, 'v, 'cmp) t -> f:('v -> bool) -> ('k, 'v, 'cmp) t + val filteri : ('k, 'v, 'cmp) t -> f:(key:'k -> data:'v -> bool) -> ('k, 'v, 'cmp) t + + (** Returns a new map with bound values filtered by [f] applied to the bound values. *) + val filter_map + : ('k, 'v1, 'cmp) t + -> f:('v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + + (** Like [filter_map], but the passed function takes both [key] and [data] as + arguments. *) + val filter_mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k -> data:'v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + + (** [partition_mapi t ~f] returns two new [t]s, with each key in [t] appearing in + exactly one of the resulting maps depending on its mapping in [f]. 
*) + val partition_mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k -> data:'v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + + (** [partition_map t ~f = partition_mapi t ~f:(fun ~key:_ ~data -> f data)] *) + val partition_map + : ('k, 'v1, 'cmp) t + -> f:('v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + + (** + {[ + partitioni_tf t ~f + = + partition_mapi t ~f:(fun ~key ~data -> + if f ~key ~data + then `Fst data + else `Snd data) + ]} *) + val partitioni_tf + : ('k, 'v, 'cmp) t + -> f:(key:'k -> data:'v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + + (** [partition_tf t ~f = partitioni_tf t ~f:(fun ~key:_ ~data -> f data)] *) + val partition_tf + : ('k, 'v, 'cmp) t + -> f:('v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + + (** Returns a total ordering between maps. The first argument is a total ordering used + to compare data associated with equal keys in the two maps. *) + val compare_direct + : ('v -> 'v -> int) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> int + + (** Hash function: a building block to use when hashing data structures containing maps in + them. [hash_fold_direct hash_fold_key] is compatible with [compare_direct] iff + [hash_fold_key] is compatible with [(comparator m).compare] of the map [m] being + hashed. *) + val hash_fold_direct + : 'k Hash.folder + -> 'v Hash.folder + -> ('k, 'v, 'cmp) t Hash.folder + + (** [equal cmp m1 m2] tests whether the maps [m1] and [m2] are equal, that is, contain + the same keys and associate each key with the same value. [cmp] is the equality + predicate used to compare the values associated with the keys. *) + val equal + : ('v -> 'v -> bool) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> bool + + (** Returns a list of the keys in the given map. *) + val keys : ('k, _, _) t -> 'k list + + (** Returns a list of the data in the given map. *) + val data : (_, 'v, _) t -> 'v list + + (** Creates an association list from the given map. 
*) + val to_alist + : ?key_order : [ `Increasing | `Decreasing ] (** default is [`Increasing] *) + -> ('k, 'v, _) t + -> ('k * 'v) list + + val validate : name:('k -> string) -> 'v Validate.check -> ('k, 'v, _) t Validate.check + + (** {2 Additional operations on maps} *) + + (** Merges two maps. The runtime is O(length(t1) + length(t2)). You shouldn't use this + function to merge a list of maps; consider using [merge_skewed] instead. *) + val merge + : ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k + -> [ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> 'v3 option) + -> ('k, 'v3, 'cmp) t + + (** A special case of [merge], [merge_skewed t1 t2] is a map containing all the + bindings of [t1] and [t2]. Bindings that appear in both [t1] and [t2] are + combined into a single value using the [combine] function. In a call + [combine ~key v1 v2], the value [v1] comes from [t1] and [v2] from [t2]. + + The runtime of [merge_skewed] is [O(l1 * log(l2))], where [l1] is the length + of the smaller map and [l2] the length of the larger map. This is likely to + be faster than [merge] when one of the maps is a lot smaller, or when you + merge a list of maps. *) + val merge_skewed + : ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> combine:(key:'k -> 'v -> 'v -> 'v) + -> ('k, 'v, 'cmp) t + + module Symmetric_diff_element : sig + type ('k, 'v) t = 'k * [ `Left of 'v | `Right of 'v | `Unequal of 'v * 'v ] + [@@deriving_inline compare, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('k -> 'k -> int) -> + ('v -> 'v -> int) -> ('k, 'v) t -> ('k, 'v) t -> int + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('k,'v) t := ('k, 'v) t + end[@@ocaml.doc "@inline"] + [@@@end] + end + + (** [symmetric_diff t1 t2 ~data_equal] returns a sequence of changes between [t1] and [t2]. + It is intended to be efficient in the case where [t1] and [t2] share a large amount of + structure. The keys in the output sequence will be in sorted order.
*) + val symmetric_diff + : ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> data_equal:('v -> 'v -> bool) + -> ('k, 'v) Symmetric_diff_element.t Sequence.t + + (** [min_elt map] returns [Some (key, data)] pair corresponding to the minimum key in + [map], or [None] if empty. *) + val min_elt : ('k, 'v, _) t -> ('k * 'v) option + val min_elt_exn : ('k, 'v, _) t -> 'k * 'v + + (** [max_elt map] returns [Some (key, data)] pair corresponding to the maximum key in + [map], or [None] if [map] is empty. *) + val max_elt : ('k, 'v, _) t -> ('k * 'v) option + val max_elt_exn : ('k, 'v, _) t -> 'k * 'v + + (** These functions have the same semantics as similar functions in [List]. *) + + val for_all : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val for_alli : ('k, 'v, _) t -> f:(key:'k -> data:'v -> bool) -> bool + val exists : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val existsi : ('k, 'v, _) t -> f:(key:'k -> data:'v -> bool) -> bool + val count : ('k, 'v, _) t -> f:( 'v -> bool) -> int + val counti : ('k, 'v, _) t -> f:(key:'k -> data:'v -> bool) -> int + + (** [split t key] returns a map of keys strictly less than [key], the mapping of [key] if + any, and a map of keys strictly greater than [key]. + + Runtime is O(m + log n), where n is the size of the input map and m is the size of + the smaller of the two output maps. The O(m) term is due to the need to calculate + the length of the output maps. *) + val split + : ('k, 'v, 'cmp) t + -> 'k + -> ('k, 'v, 'cmp) t * ('k * 'v) option * ('k, 'v, 'cmp) t + + (** [append ~lower_part ~upper_part] returns [`Ok map] where [map] contains all the + [(key, value)] pairs from the two input maps if all the keys from [lower_part] are + less than all the keys from [upper_part]. Otherwise it returns + [`Overlapping_key_ranges]. + + Runtime is O(log n) where n is the size of the larger input map. This can be + significantly faster than [Map.merge] or repeated [Map.add]. 
+ + {[ + assert (match Map.append ~lower_part ~upper_part with + | `Ok whole_map -> + Map.to_alist whole_map + = List.append (to_alist lower_part) (to_alist upper_part) + | `Overlapping_key_ranges -> true); + ]} *) + val append + : lower_part:('k, 'v, 'cmp) t + -> upper_part:('k, 'v, 'cmp) t + -> [ `Ok of ('k, 'v, 'cmp) t + | `Overlapping_key_ranges ] + + (** [subrange t ~lower_bound ~upper_bound] returns a map containing all the entries from + [t] whose keys lie inside the interval indicated by [~lower_bound] and + [~upper_bound]. If this interval is empty, an empty map is returned. + + Runtime is O(m + log n), where n is the size of the input map and m is the size of + the output map. The O(m) term is due to the need to calculate the length of the + output map. *) + val subrange + : ('k, 'v, 'cmp) t + -> lower_bound:'k Maybe_bound.t + -> upper_bound:'k Maybe_bound.t + -> ('k, 'v, 'cmp) t + + (** [fold_range_inclusive t ~min ~max ~init ~f] folds [f] (with initial value [~init]) + over all keys (and their associated values) that are in the range [[min, max]] + (inclusive). *) + val fold_range_inclusive + : ('k, 'v, 'cmp) t + -> min:'k + -> max:'k + -> init:'a + -> f:(key:'k -> data:'v -> 'a -> 'a) + -> 'a + + (** [range_to_alist t ~min ~max] returns an associative list of the elements whose keys + lie in [[min, max]] (inclusive), with the smallest key being at the head of the + list. *) + val range_to_alist : ('k, 'v, 'cmp) t -> min:'k -> max:'k -> ('k * 'v) list + + (** [closest_key t dir k] returns the [(key, value)] pair in [t] with [key] closest to + [k] that satisfies the given inequality bound. + + For example, [closest_key t `Less_than k] would be the pair with the closest key to + [k] where [key < k]. + + [to_sequence] can be used to get the same results as [closest_key]. It is less + efficient for individual lookups but more efficient for finding many elements starting + at some value. 
*) + val closest_key + : ('k, 'v, 'cmp) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'k + -> ('k * 'v) option + + (** [nth t n] finds the (key, value) pair of rank n (i.e., such that there are exactly n + keys strictly less than the found key), if one exists. O(log(length t) + n) time. *) + val nth : ('k, 'v, _) t -> int -> ('k * 'v) option + val nth_exn : ('k, 'v, _) t -> int -> ('k * 'v) + + (** [rank t k] If [k] is in [t], returns the number of keys strictly less than [k] in + [t], and [None] otherwise. *) + val rank : ('k, 'v, 'cmp) t -> 'k -> int option + + (** [to_sequence ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to t] gives a + sequence of key-value pairs between [keys_less_or_equal_to] and + [keys_greater_or_equal_to] inclusive, presented in [order]. If + [keys_greater_or_equal_to > keys_less_or_equal_to], the sequence is empty. Cost is + O(log n) up front and amortized O(1) to produce each element. *) + val to_sequence + : ?order : [ `Increasing_key (** default *) | `Decreasing_key ] + -> ?keys_greater_or_equal_to : 'k + -> ?keys_less_or_equal_to : 'k + -> ('k, 'v, 'cmp) t + -> ('k * 'v) Sequence.t + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_to_int_map = int Map.M(String).t + ]} + + which stands for: + + {[ + type string_to_int_map = (String.t, int, String.comparator_witness) Map.t + ]} + + The point is that [int Map.M(String).t] supports deriving, whereas the second syntax + doesn't (because there is no such thing as, say, [String.sexp_of_comparator_witness] + -- instead you would want to pass the comparator directly). + + In addition, the requirements to use [@@deriving_inline][@@@end] on the key module are only those + needed to satisfy what you are trying to derive on the map itself. 
Say you write: + + {[ + type t = int Map.M(X).t [@@deriving_inline hash][@@@end] + ]} + + then this will be well typed exactly if [X] contains at least: + - a type [t] with no parameters + - a comparator witness + - a [hash_fold_t] function with the right type *) + module M (K : sig type t type comparator_witness end) : sig + type nonrec 'v t = (K.t, 'v, K.comparator_witness) t + end + + include For_deriving + with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) t + + (** A polymorphic Map. *) + module Poly : S_poly + with type ('key, +'value) t = ('key, 'value, Comparator.Poly.comparator_witness) t + + (** [Using_comparator] is a similar interface as the toplevel of [Map], except the + functions take a [~comparator:('k, 'cmp) Comparator.t], whereas the functions at the + toplevel of [Map] take a [('k, 'cmp) comparator]. *) + module Using_comparator : sig + type nonrec ('k, +'v, 'cmp) t = ('k, 'v, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('k -> Ppx_sexp_conv_lib.Sexp.t) -> + ('v -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('k, 'v, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('k, 'cmp) Comparator.t + -> (Sexp.t -> 'k) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v, 'cmp) t + + module Tree : sig + type ('k, +'v, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('k -> Ppx_sexp_conv_lib.Sexp.t) -> + ('v -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('k, 'v, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('k, 'cmp) Comparator.t + -> (Sexp.t -> 'k) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v, 'cmp) t + + include Creators_and_accessors3_with_comparator + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) t + + val 
empty_without_value_restriction : (_, _, _) t + end + + include Accessors3 + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + include Creators3_with_comparator + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + + val comparator : ('a, _, 'cmp) t -> ('a, 'cmp) Comparator.t + + val hash_fold_direct + : 'k Hash.folder + -> 'v Hash.folder + -> ('k, 'v, 'cmp) t Hash.folder + + (** To get around the value restriction, apply the functor and include it. You + can see an example of this in the [Poly] submodule below. *) + module Empty_without_value_restriction (K : Comparator.S1) : sig + val empty : ('a K.t, 'v, K.comparator_witness) t + end + end + + + (** {2 Modules and module types for extending [Map]} + + For use in extensions of Base, like [Core_kernel]. *) + + module With_comparator = With_comparator + module With_first_class_module = With_first_class_module + module Without_comparator = Without_comparator + + module type For_deriving = For_deriving + + module type S_poly = S_poly + module type Accessors1 = Accessors1 + module type Accessors2 = Accessors2 + module type Accessors3 = Accessors3 + module type Accessors3_with_comparator = Accessors3_with_comparator + module type Accessors_generic = Accessors_generic + module type Creators1 = Creators1 + module type Creators2 = Creators2 + module type Creators3_with_comparator = Creators3_with_comparator + module type Creators_and_accessors1 = Creators_and_accessors1 + module type Creators_and_accessors2 = Creators_and_accessors2 + module type Creators_and_accessors3_with_comparator = Creators_and_accessors3_with_comparator + module type Creators_and_accessors_generic = Creators_and_accessors_generic + module type Creators_generic = Creators_generic +end diff --git a/src/maybe_bound.ml b/src/maybe_bound.ml new file mode 100644 index 0000000..16637ba --- /dev/null +++ b/src/maybe_bound.ml @@ -0,0 +1,164 @@ +open! 
Import + +type 'a t = Incl of 'a | Excl of 'a | Unbounded [@@deriving_inline enumerate, sexp] +let all : 'a . 'a list -> 'a t list = + fun _all_of_a -> + Ppx_enumerate_lib.List.append + (let rec map l acc = + match l with + | [] -> Ppx_enumerate_lib.List.rev acc + | enumerate__001_::l -> map l ((Incl enumerate__001_) :: acc) in + map _all_of_a []) + (Ppx_enumerate_lib.List.append + (let rec map l acc = + match l with + | [] -> Ppx_enumerate_lib.List.rev acc + | enumerate__002_::l -> map l ((Excl enumerate__002_) :: acc) in + map _all_of_a []) [Unbounded]) +let t_of_sexp : type a. + (Ppx_sexp_conv_lib.Sexp.t -> a) -> Ppx_sexp_conv_lib.Sexp.t -> a t = + let _tp_loc = "src/maybe_bound.ml.t" in + fun _of_a -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("incl"|"Incl" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Incl v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc _tag + _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("excl"|"Excl" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Excl v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc _tag + _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("unbounded"|"Unbounded") -> Unbounded + | Ppx_sexp_conv_lib.Sexp.Atom ("incl"|"Incl") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("excl"|"Excl") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("unbounded"|"Unbounded"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp 
-> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : type a. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | Incl v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Incl"; v0] + | Excl v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Excl"; v0] + | Unbounded -> Ppx_sexp_conv_lib.Sexp.Atom "Unbounded" +[@@@end] + +type interval_comparison = + | Below_lower_bound + | In_range + | Above_upper_bound +[@@deriving_inline sexp, compare, hash] +let interval_comparison_of_sexp : + Ppx_sexp_conv_lib.Sexp.t -> interval_comparison = + let _tp_loc = "src/maybe_bound.ml.interval_comparison" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("below_lower_bound"|"Below_lower_bound") -> + Below_lower_bound + | Ppx_sexp_conv_lib.Sexp.Atom ("in_range"|"In_range") -> In_range + | Ppx_sexp_conv_lib.Sexp.Atom ("above_upper_bound"|"Above_upper_bound") -> + Above_upper_bound + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("below_lower_bound"|"Below_lower_bound"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("in_range"|"In_range"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("above_upper_bound"|"Above_upper_bound"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_interval_comparison : + interval_comparison -> Ppx_sexp_conv_lib.Sexp.t = + function + | Below_lower_bound -> 
Ppx_sexp_conv_lib.Sexp.Atom "Below_lower_bound" + | In_range -> Ppx_sexp_conv_lib.Sexp.Atom "In_range" + | Above_upper_bound -> Ppx_sexp_conv_lib.Sexp.Atom "Above_upper_bound" +let compare_interval_comparison : + interval_comparison -> interval_comparison -> int = + Ppx_compare_lib.polymorphic_compare +let (hash_fold_interval_comparison : + Ppx_hash_lib.Std.Hash.state -> + interval_comparison -> Ppx_hash_lib.Std.Hash.state) + = + (fun hsv -> + fun arg -> + match arg with + | Below_lower_bound -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | In_range -> Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Above_upper_bound -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 : + Ppx_hash_lib.Std.Hash.state -> + interval_comparison -> Ppx_hash_lib.Std.Hash.state) +let (hash_interval_comparison : + interval_comparison -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in + hash_fold_interval_comparison hsv arg) in + fun x -> func x +[@@@end] + +let map t ~f = + match t with + | Incl incl -> Incl (f incl) + | Excl excl -> Excl (f excl) + | Unbounded -> Unbounded + +let is_lower_bound t ~of_:a ~compare = + match t with + | Incl incl -> compare incl a <= 0 + | Excl excl -> compare excl a < 0 + | Unbounded -> true + +let is_upper_bound t ~of_:a ~compare = + match t with + | Incl incl -> compare a incl <= 0 + | Excl excl -> compare a excl < 0 + | Unbounded -> true + +let bounds_crossed ~lower ~upper ~compare = + match lower with + | Unbounded -> false + | (Incl lower | Excl lower) -> + match upper with + | Unbounded -> false + | (Incl upper | Excl upper) -> + compare lower upper > 0 + +let check_interval_exn ~lower ~upper ~compare = + if bounds_crossed ~lower ~upper ~compare then + failwith "Maybe_bound.compare_to_interval_exn: lower bound > upper bound" + +let compare_to_interval_exn ~lower ~upper a ~compare = + check_interval_exn ~lower ~upper ~compare; + if not (is_lower_bound lower ~of_:a ~compare) then 
Below_lower_bound else + if not (is_upper_bound upper ~of_:a ~compare) then Above_upper_bound else + In_range + +let interval_contains_exn ~lower ~upper a ~compare = + match compare_to_interval_exn ~lower ~upper a ~compare with + | In_range -> true + | Below_lower_bound + | Above_upper_bound -> false + diff --git a/src/maybe_bound.mli b/src/maybe_bound.mli new file mode 100644 index 0000000..3553e56 --- /dev/null +++ b/src/maybe_bound.mli @@ -0,0 +1,63 @@ +(** Used for specifying a bound (either upper or lower) as inclusive, exclusive, or + unbounded. *) + +open! Import + +type 'a t = Incl of 'a | Excl of 'a | Unbounded [@@deriving_inline enumerate, sexp] +include +sig + [@@@ocaml.warning "-32"] + val all : 'a list -> 'a t list + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +val map : 'a t -> f:('a -> 'b) -> 'b t + +val is_lower_bound : 'a t -> of_:'a -> compare:('a -> 'a -> int) -> bool +val is_upper_bound : 'a t -> of_:'a -> compare:('a -> 'a -> int) -> bool + +(** [interval_contains_exn ~lower ~upper x ~compare] raises if [lower] and [upper] are + crossed. *) +val interval_contains_exn + : lower : 'a t + -> upper : 'a t + -> 'a + -> compare : ('a -> 'a -> int) + -> bool + +(** [bounds_crossed ~lower ~upper ~compare] returns true if [lower > upper]. + + It ignores whether the bounds are [Incl] or [Excl]. 
*) +val bounds_crossed: lower:'a t -> upper: 'a t -> compare:('a -> 'a -> int) -> bool + +type interval_comparison = + | Below_lower_bound + | In_range + | Above_upper_bound +[@@deriving_inline sexp, compare, hash] +include +sig + [@@@ocaml.warning "-32"] + val sexp_of_interval_comparison : + interval_comparison -> Ppx_sexp_conv_lib.Sexp.t + val interval_comparison_of_sexp : + Ppx_sexp_conv_lib.Sexp.t -> interval_comparison + val compare_interval_comparison : + interval_comparison -> interval_comparison -> int + val hash_fold_interval_comparison : + Ppx_hash_lib.Std.Hash.state -> + interval_comparison -> Ppx_hash_lib.Std.Hash.state + val hash_interval_comparison : + interval_comparison -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +(** [compare_to_interval_exn ~lower ~upper x ~compare] raises if [lower] and [upper] are + crossed. *) +val compare_to_interval_exn + : lower : 'a t + -> upper : 'a t + -> 'a + -> compare : ('a -> 'a -> int) + -> interval_comparison diff --git a/src/monad.ml b/src/monad.ml new file mode 100644 index 0000000..d7ccf25 --- /dev/null +++ b/src/monad.ml @@ -0,0 +1,107 @@ +open! 
Import + +module List = List0 + +include Monad_intf + +module type Basic_general = sig + type ('a, 'i, 'j, 'd, 'e) t + + val bind + : ('a, 'i, 'j, 'd, 'e) t + -> f:('a -> ('b, 'j, 'k, 'd, 'e) t) + -> ('b, 'i, 'k, 'd, 'e) t + + val map + : [ `Define_using_bind + | `Custom of (('a, 'i, 'j, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'i, 'j, 'd, 'e) t) + ] + + val return : 'a -> ('a, 'i, 'i, 'd, 'e) t +end + +module Make_general (M : Basic_general) = struct + + let bind = M.bind + let return = M.return + + let map_via_bind ma ~f = M.bind ma ~f:(fun a -> M.return (f a)) + + let map = + match M.map with + | `Define_using_bind -> map_via_bind + | `Custom x -> x + + module Monad_infix = struct + let (>>=) t f = bind t ~f + let (>>|) t f = map t ~f + end + include Monad_infix + + module Let_syntax = struct + + let return = return + include Monad_infix + + module Let_syntax = struct + let return = return + let bind = bind + let map = map + let both a b = a >>= fun a -> b >>| fun b -> (a, b) + module Open_on_rhs = struct end + end + end + + let join t = t >>= fun t' -> t' + + let ignore_m t = map t ~f:(fun _ -> ()) + + let all = + let rec loop vs = function + | [] -> return (List.rev vs) + | t :: ts -> t >>= fun v -> loop (v :: vs) ts + in + fun ts -> loop [] ts + + let rec all_unit = function + | [] -> return () + | t :: ts -> t >>= fun () -> all_unit ts + + let all_ignore = all_unit + +end + +module Make_indexed (M : Basic_indexed) + : S_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) M.t = + Make_general (struct + type ('a, 'i, 'j, 'd, 'e) t = ('a, 'i, 'j) M.t + include (M : Basic_indexed with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t) + end) + +module Make3 (M : Basic3) : S3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) M.t = + Make_general (struct + type ('a, 'i, 'j, 'd, 'e) t = ('a, 'd, 'e) M.t + include (M : Basic3 with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t) + end) + +module Make2 (M : Basic2) : S2 with type ('a, 'd) t := ('a, 'd) M.t = + Make_general (struct + type ('a, 'i, 'j, 
'd, 'e) t = ('a, 'd) M.t + include (M : Basic2 with type ('a, 'b) t := ('a, 'b) M.t) + end) + +module Make (M : Basic) : S with type 'a t := 'a M.t = + Make_general (struct + type ('a, 'i, 'j, 'd, 'e) t = 'a M.t + include (M : Basic with type 'a t := 'a M.t) + end) + +module Ident = struct + type 'a t = 'a + include Make (struct + type nonrec 'a t = 'a t + let bind a ~f = f a + let return a = a + let map = `Custom (fun a ~f -> f a) + end) +end diff --git a/src/monad.mli b/src/monad.mli new file mode 100644 index 0000000..65be892 --- /dev/null +++ b/src/monad.mli @@ -0,0 +1 @@ +include Monad_intf.Monad (** @inline *) diff --git a/src/monad_intf.ml b/src/monad_intf.ml new file mode 100644 index 0000000..34ae035 --- /dev/null +++ b/src/monad_intf.ml @@ -0,0 +1,365 @@ +open! Import + +module type Basic = sig + type 'a t + val bind : 'a t -> f:('a -> 'b t) -> 'b t + val return : 'a -> 'a t + (** The following identities ought to hold (for some value of =): + + - [return x >>= f = f x] + - [t >>= fun x -> return x = t] + - [(t >>= f) >>= g = t >>= fun x -> (f x >>= g)] + + Note: [>>=] is the infix notation for [bind]. *) + + (** The [map] argument to [Monad.Make] says how to implement the monad's [map] function. + [`Define_using_bind] means to define [map t ~f = bind t ~f:(fun a -> return (f a))]. + [`Custom] overrides the default implementation, presumably with something more + efficient. + + Some other functions returned by [Monad.Make] are defined in terms of [map], so + passing in a more efficient [map] will improve their efficiency as well. *) + val map : [ `Define_using_bind + | `Custom of ('a t -> f:('a -> 'b) -> 'b t) + ] +end + +module type Infix = sig + type 'a t + + (** [t >>= f] returns a computation that sequences the computations represented by two + monad elements. The resulting computation first does [t] to yield a value [v], and + then runs the computation returned by [f v]. 
*) + val (>>=) : 'a t -> ('a -> 'b t) -> 'b t + + (** [t >>| f] is [t >>= (fun a -> return (f a))]. *) + val (>>|) : 'a t -> ('a -> 'b) -> 'b t +end + +(** Opening a module of this type allows one to use the [%bind] and [%map] syntax + extensions defined by ppx_let, and brings [return] into scope. *) +module type Syntax = sig + type 'a t + module Let_syntax : sig + + (** These are convenient to have in scope when programming with a monad: *) + + val return : 'a -> 'a t + include Infix with type 'a t := 'a t + + module Let_syntax : sig + val return : 'a -> 'a t + val bind : 'a t -> f:('a -> 'b t) -> 'b t + val map : 'a t -> f:('a -> 'b) -> 'b t + val both : 'a t -> 'b t -> ('a * 'b) t + module Open_on_rhs : sig end + end + end +end + +module type S_without_syntax = sig + type 'a t + include Infix with type 'a t := 'a t + + module Monad_infix : Infix with type 'a t := 'a t + + (** [bind t ~f] = [t >>= f] *) + val bind : 'a t -> f:('a -> 'b t) -> 'b t + + (** [return v] returns the (trivial) computation that returns v. *) + val return : 'a -> 'a t + + (** [map t ~f] is t >>| f. *) + val map : 'a t -> f:('a -> 'b) -> 'b t + + (** [join t] is [t >>= (fun t' -> t')]. *) + val join : 'a t t -> 'a t + + (** [ignore_m t] is [map t ~f:(fun _ -> ())]. [ignore_m] used to be called [ignore], + but we decided that was a bad name, because it shadowed the widely used + [Caml.ignore]. Some monads still do [let ignore = ignore_m] for historical + reasons. *) + val ignore_m : 'a t -> unit t + + + val all : 'a t list -> 'a list t + + (** Like [all], but ensures that every monadic value in the list produces a unit value, + all of which are discarded rather than being collected into a list. *) + val all_unit : unit t list -> unit t + + val all_ignore : unit t list -> unit t [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +module type S = sig + type 'a t + include S_without_syntax with type 'a t := 'a t + include Syntax with type 'a t := 'a t +end + +(** Multi parameter monad. 
The second parameter gets unified across all the computation. + This is used to encode monads working on a multi parameter data structure like + ([('a,'b) result]). *) +module type Basic2 = sig + type ('a, 'e) t + val bind : ('a, 'e) t -> f:('a -> ('b, 'e) t) -> ('b, 'e) t + val map : [ `Define_using_bind + | `Custom of (('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t) + ] + val return : 'a -> ('a, _) t +end + +(** Same as Infix, except the monad type has two arguments. The second is always just + passed through. *) +module type Infix2 = sig + type ('a, 'e) t + val (>>=) : ('a, 'e) t -> ('a -> ('b, 'e) t) -> ('b, 'e) t + val (>>|) : ('a, 'e) t -> ('a -> 'b) -> ('b, 'e) t +end + +module type Syntax2 = sig + type ('a, 'e) t + + module Let_syntax : sig + val return : 'a -> ('a, _) t + include Infix2 with type ('a,'e) t := ('a,'e) t + module Let_syntax : sig + val return : 'a -> ('a, _) t + val bind : ('a, 'e) t -> f:('a -> ('b, 'e) t) -> ('b, 'e) t + val map : ('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t + val both : ('a, 'e) t -> ('b, 'e) t -> ('a * 'b, 'e) t + module Open_on_rhs : sig end + end + end +end + +(** The same as S except the monad type has two arguments. The second is always just + passed through. *) +module type S2 = sig + type ('a, 'e) t + include Infix2 with type ('a, 'e) t := ('a, 'e) t + include Syntax2 with type ('a, 'e) t := ('a, 'e) t + + module Monad_infix : Infix2 with type ('a, 'e) t := ('a, 'e) t + + val bind : ('a, 'e) t -> f:('a -> ('b, 'e) t) -> ('b, 'e) t + + val return : 'a -> ('a, _) t + + val map : ('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t + + val join : (('a, 'e) t, 'e) t -> ('a, 'e) t + + val ignore_m : (_, 'e) t -> (unit, 'e) t + + val all : ('a, 'e) t list -> ('a list, 'e) t + + val all_unit : (unit, 'e) t list -> (unit, 'e) t + + val all_ignore : (unit, 'e) t list -> (unit, 'e) t + [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +(** Multi parameter monad. The second and third parameters get unified across all the + computation. 
*) +module type Basic3 = sig + type ('a, 'd, 'e) t + val bind : ('a, 'd, 'e) t -> f:('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + val map : [ `Define_using_bind + | `Custom of (('a, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'd, 'e) t) + ] + val return : 'a -> ('a, _, _) t +end + +(** Same as Infix, except the monad type has three arguments. The second and third are + always just passed through. *) +module type Infix3 = sig + type ('a, 'd, 'e) t + val (>>=) : ('a, 'd, 'e) t -> ('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + val (>>|) : ('a, 'd, 'e) t -> ('a -> 'b) -> ('b, 'd, 'e) t +end + +module type Syntax3 = sig + type ('a, 'd, 'e) t + + module Let_syntax : sig + val return : 'a -> ('a, _, _) t + include Infix3 with type ('a,'d,'e) t := ('a,'d,'e) t + module Let_syntax : sig + val return : 'a -> ('a, _, _) t + val bind : ('a, 'd, 'e) t -> f:('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + val map : ('a, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'd, 'e) t + val both : ('a, 'd, 'e) t -> ('b, 'd, 'e) t -> ('a * 'b, 'd, 'e) t + module Open_on_rhs : sig end + end + end +end + +(** The same as S except the monad type has three arguments. The second and third are + always just passed through. *) +module type S3 = sig + type ('a, 'd, 'e) t + include Infix3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) t + include Syntax3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) t + + module Monad_infix : Infix3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) t + + val bind : ('a, 'd, 'e) t -> f:('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + + val return : 'a -> ('a, _, _) t + + val map : ('a, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'd, 'e) t + + val join : (('a, 'd, 'e) t, 'd, 'e) t -> ('a, 'd, 'e) t + + val ignore_m : (_, 'd, 'e) t -> (unit, 'd, 'e) t + + val all : ('a, 'd, 'e) t list -> ('a list, 'd, 'e) t + + val all_unit : (unit, 'd, 'e) t list -> (unit, 'd, 'e) t + + val all_ignore : (unit, 'd, 'e) t list -> (unit, 'd, 'e) t + [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +(** Indexed monad, in the style of Atkey. 
The second and third parameters are composed + across all computation. To see this more clearly, you can look at the type of bind: + + {[ + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + ]} + + and isolate some of the type variables to see their individual behaviors: + + {[ + val bind : 'a -> f:('a -> 'b ) -> 'b + val bind : 'i, 'j -> 'j, 'k -> 'i, 'k + ]} + + For more information on Atkey-style indexed monads, see: + + {v + Parameterised Notions of Computation + Robert Atkey + http://bentnib.org/paramnotions-jfp.pdf + v} *) +module type Basic_indexed = sig + type ('a, 'i, 'j) t + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + val map : [ `Define_using_bind + | `Custom of (('a, 'i, 'j) t -> f:('a -> 'b) -> ('b, 'i, 'j) t) + ] + val return : 'a -> ('a, 'i, 'i) t +end + +(** Same as Infix, except the monad type has three arguments. The second and third are + composed across all computation. *) +module type Infix_indexed = sig + type ('a, 'i, 'j) t + val (>>=) : ('a, 'i, 'j) t -> ('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + val (>>|) : ('a, 'i, 'j) t -> ('a -> 'b) -> ('b, 'i, 'j) t +end + +module type Syntax_indexed = sig + type ('a, 'i, 'j) t + + module Let_syntax : sig + val return : 'a -> ('a, 'i, 'i) t + include Infix_indexed with type ('a,'i,'j) t := ('a,'i,'j) t + module Let_syntax : sig + val return : 'a -> ('a, 'i, 'i) t + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + val map : ('a, 'i, 'j) t -> f:('a -> 'b) -> ('b, 'i, 'j) t + val both : ('a, 'i, 'j) t -> ('b, 'j, 'k) t -> ('a * 'b, 'i, 'k) t + module Open_on_rhs : sig end + end + end +end + +(** The same as S except the monad type has three arguments. The second and third are + composed across all computation. 
*) +module type S_indexed = sig + type ('a, 'i, 'j) t + include Infix_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) t + include Syntax_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) t + + module Monad_infix : Infix_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) t + + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + + val return : 'a -> ('a, 'i, 'i) t + + val map : ('a, 'i, 'j) t -> f:('a -> 'b) -> ('b, 'i, 'j) t + + val join : (('a, 'j, 'k) t, 'i, 'j) t -> ('a, 'i, 'k) t + + val ignore_m : (_, 'i, 'j) t -> (unit, 'i, 'j) t + + val all : ('a, 'i, 'i) t list -> ('a list, 'i, 'i) t + + val all_unit : (unit, 'i, 'i) t list -> (unit, 'i, 'i) t + + val all_ignore : (unit, 'i, 'i) t list -> (unit, 'i, 'i) t + [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +module S_to_S2 (X : S) : (S2 with type ('a, 'e) t = 'a X.t) = struct + type ('a, 'e) t = 'a X.t + include (X : S with type 'a t := 'a X.t) +end + +module S2_to_S3 (X : S2) : (S3 with type ('a, 'd, 'e) t = ('a, 'd) X.t) = struct + type ('a, 'd, 'e) t = ('a, 'd) X.t + include (X : S2 with type ('a, 'd) t := ('a, 'd) X.t) +end + +module S_to_S_indexed (X : S) : (S_indexed with type ('a, 'i, 'j) t = 'a X.t) = struct + type ('a, 'i, 'j) t = 'a X.t + include (X : S with type 'a t := 'a X.t) +end + +module S2_to_S (X : S2) : (S with type 'a t = ('a, unit) X.t) = struct + type 'a t = ('a, unit) X.t + include (X : S2 with type ('a, 'e) t := ('a, 'e) X.t) +end + +module S3_to_S2 (X : S3) : (S2 with type ('a, 'e) t = ('a, 'e, unit) X.t) = struct + type ('a, 'e) t = ('a, 'e, unit) X.t + include (X : S3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) X.t) +end + +module S_indexed_to_S2 (X : S_indexed) : (S2 with type ('a, 'e) t = ('a, 'e, 'e) X.t) = +struct + type ('a, 'e) t = ('a, 'e, 'e) X.t + include (X : S_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) X.t) +end + +module type Monad = sig + (** A monad is an abstraction of the concept of sequencing of computations. 
A value of + type ['a monad] represents a computation that returns a value of type ['a]. *) + + module type Basic = Basic + module type Basic2 = Basic2 + module type Basic3 = Basic3 + module type Basic_indexed = Basic_indexed + module type Infix = Infix + module type Infix2 = Infix2 + module type Infix3 = Infix3 + module type Infix_indexed = Infix_indexed + module type Syntax = Syntax + module type Syntax2 = Syntax2 + module type Syntax3 = Syntax3 + module type Syntax_indexed = Syntax_indexed + module type S_without_syntax = S_without_syntax + module type S = S + module type S2 = S2 + module type S3 = S3 + module type S_indexed = S_indexed + + module Make (X : Basic ) : S with type 'a t := 'a X.t + module Make2 (X : Basic2) : S2 with type ('a, 'e) t := ('a, 'e) X.t + module Make3 (X : Basic3) : S3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) X.t + module Make_indexed (X : Basic_indexed) : S_indexed with type ('a, 'd, 'e) t := ('a, 'd, 'e) X.t + + module Ident : S with type 'a t = 'a +end diff --git a/src/nativeint.ml b/src/nativeint.ml new file mode 100644 index 0000000..914fb03 --- /dev/null +++ b/src/nativeint.ml @@ -0,0 +1,258 @@ +open! Import +open! 
Caml.Nativeint +include Nativeint_replace_polymorphic_compare + +module T = struct + type t = nativeint [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_nativeint + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_nativeint in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = nativeint_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_nativeint + [@@@end] + let compare = Nativeint_replace_polymorphic_compare.compare + + let to_string = to_string + let of_string = of_string +end + +include T +include Comparator.Make(T) +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + + +module Conv = Int_conversions +include Conv.Make (T) +include Conv.Make_hex(struct + open Nativeint_replace_polymorphic_compare + type t = nativeint [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_nativeint + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_nativeint + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_nativeint in fun x -> func x + [@@@end] + + let zero = zero + let neg = neg + let (<) = (<) + let to_string i = Printf.sprintf "%nx" i + let of_string s = Caml.Scanf.sscanf s "%nx" Fn.id + + let module_name = "Base.Nativeint.Hex" + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Nativeint" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! 
Nativeint_replace_polymorphic_compare + +let num_bits = Word_size.num_bits Word_size.word_size +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = lognot +let bit_xor = logxor +let bit_or = logor +let bit_and = logand +let min_value = min_int +let max_value = max_int +let abs = abs +let pred = pred +let succ = succ +let rem = rem +let neg = neg +let minus_one = minus_one +let one = one +let zero = zero +let to_float = to_float +let of_float_unchecked = of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + of_float f + else + Printf.invalid_argf "Nativeint.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () + +module Pow2 = struct + open! Import + open Nativeint_replace_polymorphic_compare + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + let ( lor ) = Caml.Nativeint.logor;; + let ( lsr ) = Caml.Nativeint.shift_right_logical;; + let ( land ) = Caml.Nativeint.logand;; + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 (x : nativeint) = + if x <= 0n then non_positive_argument (); + let x = Caml.Nativeint.pred x in + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + (* The next line is superfluous on 32-bit architectures, but it's faster to do it + anyway than to branch *) + let x = x lor (x lsr 32) in + Caml.Nativeint.succ x + ;; + + (** "floor power of 2" - Largest power of 2 less than or equal to x. 
*) + let floor_pow2 x = + if x <= 0n then non_positive_argument (); + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + let x = x lor (x lsr 32) in + Caml.Nativeint.sub x (x lsr 1) + ;; + + let is_pow2 x = + if x <= 0n then non_positive_argument (); + (x land (Caml.Nativeint.pred x)) = 0n + ;; + + (* C stub for nativeint clz to use the CLZ/BSR instruction where possible *) + external nativeint_clz : nativeint -> int = "Base_int_math_nativeint_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if Poly.( <= ) i Caml.Nativeint.zero then + raise_s (Sexp.message "[Nativeint.floor_log2] got invalid input" + ["", sexp_of_nativeint i]); + Sys.word_size_in_bits - 1 - nativeint_clz i + ;; + + (** Hacker's Delight Second Edition p106 *) + let ceil_log2 i = + if Poly.( <= ) i Caml.Nativeint.zero then + raise_s (Sexp.message "[Nativeint.ceil_log2] got invalid input" + ["", sexp_of_nativeint i]); + if Caml.Nativeint.equal i Caml.Nativeint.one + then 0 + else Sys.word_size_in_bits - nativeint_clz (Caml.Nativeint.pred i) + ;; +end +include Pow2 + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let incr r = r := !r + one +let decr r = r := !r - one + +let of_nativeint t = t +let of_nativeint_exn = of_nativeint +let to_nativeint t = t +let to_nativeint_exn = to_nativeint + +let popcount = Popcount.nativeint_popcount + +let of_int = Conv.int_to_nativeint +let of_int_exn = of_int 
+let to_int = Conv.nativeint_to_int +let to_int_exn = Conv.nativeint_to_int_exn +let to_int_trunc = Conv.nativeint_to_int_trunc +let of_int32 = Conv.int32_to_nativeint +let of_int32_exn = of_int32 +let to_int32 = Conv.nativeint_to_int32 +let to_int32_exn = Conv.nativeint_to_int32_exn +let to_int32_trunc = Conv.nativeint_to_int32_trunc +let of_int64 = Conv.int64_to_nativeint +let of_int64_exn = Conv.int64_to_nativeint_exn +let of_int64_trunc = Conv.int64_to_nativeint_trunc +let to_int64 = Conv.nativeint_to_int64 + +let pow b e = of_int_exn (Int_math.int_pow (to_int_exn b) (to_int_exn e)) +let ( ** ) b e = pow b e + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Nativeint_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Nativeint] and [Nativeint.O] agree value-wise *) + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Nativeint_replace_polymorphic_compare diff --git a/src/nativeint.mli b/src/nativeint.mli new file mode 100644 index 0000000..0ba4d53 --- /dev/null +++ b/src/nativeint.mli @@ -0,0 +1,27 @@ +(** Processor-native integers. *) + +open! 
Import + +include Int_intf.S with type t = nativeint + +(** {2 Conversion functions} *) + +val of_int : int -> t +val to_int : t -> int option + +val of_int32 : int32 -> t +val to_int32 : t -> int32 option + +val of_nativeint : nativeint -> t +val to_nativeint : t -> nativeint + +val of_int64 : int64 -> t option + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x]. *) + +val to_int_trunc : t -> int +val to_int32_trunc : t -> int32 +val of_int64_trunc : int64 -> t diff --git a/src/obj_array.ml b/src/obj_array.ml new file mode 100644 index 0000000..0989345 --- /dev/null +++ b/src/obj_array.ml @@ -0,0 +1,179 @@ +open! Import + +module Int = Int0 +module String = String0 +module Array = Array0 + +(* We maintain the property that all values of type [t] do not have the tag + [double_array_tag]. Some functions below assume this in order to avoid testing the + tag, and will segfault if this property doesn't hold. *) +type t = Caml.Obj.t array + +let invariant t = + assert (Caml.Obj.tag (Caml.Obj.repr t) <> Caml.Obj.double_array_tag); +;; + +let length = Array.length + +let swap t i j = Array.swap t i j + +let sexp_of_t t = + Sexp.Atom (String.concat ~sep:"" + [ "" + ]) +;; + +let zero_obj = Caml.Obj.repr (0 : int) + +(* We call [Array.create] with a value that is not a float so that the array doesn't get + tagged with [Double_array_tag]. *) +let create_zero ~len = Array.create ~len zero_obj + +let create ~len x = + (* If we can, use [Array.create] directly. 
*) + if (Caml.Obj.tag x) <> Caml.Obj.double_tag then begin + Array.create ~len x + end else begin + (* Otherwise use [create_zero] and set the contents *) + let t = create_zero ~len in + let x = Sys.opaque_identity x in + for i = 0 to (len - 1) do + Array.unsafe_set t i x + done; + t + end + +let empty = [||] + +type not_a_float = Not_a_float_0 | Not_a_float_1 of int +let _not_a_float_0 = Not_a_float_0 +let _not_a_float_1 = Not_a_float_1 42 + +let get t i = + (* Make the compiler believe [t] is an array not containing floats so it does not check + if [t] is tagged with [Double_array_tag]. It is NOT ok to use [int array] since (if + this function is inlined and the array contains in-heap boxed values) wrong register + typing may result, leading to a failure to register necessary GC roots. *) + Caml.Obj.repr (Array.get (Caml.Obj.magic (t : t) : not_a_float array) i : not_a_float) +;; + +let [@inline always] unsafe_get t i = + (* Make the compiler believe [t] is an array not containing floats so it does not check + if [t] is tagged with [Double_array_tag]. *) + Caml.Obj.repr + (Array.unsafe_get (Caml.Obj.magic (t : t) : not_a_float array) i : not_a_float) + +let [@inline always] unsafe_set_with_caml_modify t i obj = + (* Same comment as [unsafe_get]. Sys.opaque_identity prevents the compiler from + potentially wrongly guessing the type of the array based on the type of element, that + is prevent the implication: (Obj.tag obj = Obj.double_tag) => (Obj.tag t = + Obj.double_array_tag) which flambda has tried in the past (at least that's assuming + the compiler respects Sys.opaque_identity, which is not always the case). *) + Array.unsafe_set + (Caml.Obj.magic (t : t) : not_a_float array) + i + (Caml.Obj.obj (Sys.opaque_identity obj) : not_a_float) +;; + +let [@inline always] unsafe_set_int_assuming_currently_int t i int = + (* This skips [caml_modify], which is OK if both the old and new values are integers. 
*) + Array.unsafe_set (Caml.Obj.magic (t : t) : int array) i (Sys.opaque_identity int) +;; + +(* For [set] and [unsafe_set], if a pointer is involved, we first do a physical-equality + test to see if the pointer is changing. If not, we don't need to do the [set], which + saves a call to [caml_modify]. We think this physical-equality test is worth it + because it is very cheap (both values are already available from the [is_int] test) + and because [caml_modify] is expensive. *) + +let set t i obj = + (* We use [get] first but then we use [Array.unsafe_set] since we know that [i] is + valid. *) + let old_obj = get t i in + if Caml.Obj.is_int old_obj && Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else if not (phys_equal old_obj obj) + then unsafe_set_with_caml_modify t i obj +;; + +let [@inline always] unsafe_set t i obj = + let old_obj = unsafe_get t i in + if Caml.Obj.is_int old_obj && Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else if not (phys_equal old_obj obj) + then unsafe_set_with_caml_modify t i obj +;; + +let [@inline always] unsafe_set_omit_phys_equal_check t i obj = + let old_obj = unsafe_get t i in + if Caml.Obj.is_int old_obj && Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else unsafe_set_with_caml_modify t i obj +;; + +let singleton obj = + create ~len:1 obj +;; + +(* Pre-condition: t.(i) is an integer. *) +let unsafe_set_assuming_currently_int t i obj = + if Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else + (* [t.(i)] is an integer and [obj] is not, so we do not need to check if they are + equal. 
*) + unsafe_set_with_caml_modify t i obj +;; + +let unsafe_set_int t i int = + let old_obj = unsafe_get t i in + if Caml.Obj.is_int old_obj + then unsafe_set_int_assuming_currently_int t i int + else unsafe_set_with_caml_modify t i (Caml.Obj.repr int) +;; + +let unsafe_clear_if_pointer t i = + let old_obj = unsafe_get t i in + if not (Caml.Obj.is_int old_obj) then unsafe_set_with_caml_modify t i (Caml.Obj.repr 0); +;; + +(** [unsafe_blit] is like [Array.blit], except it uses our own for-loop to avoid + caml_modify when possible. Its performance is still not comparable to a memcpy. *) +let unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len = + (* When [phys_equal src dst], we need to check whether [dst_pos < src_pos] and have the + for loop go in the right direction so that we don't overwrite data that we still need + to read. When [not (phys_equal src dst)], doing this is harmless. From a + memory-performance perspective, it doesn't matter whether one loops up or down. + Constant-stride access, forward or backward, should be indistinguishable (at least on + an intel i7). So, we don't do a check for [phys_equal src dst] and always loop up in + that case. *) + if dst_pos < src_pos + then + for i = 0 to len - 1 do + unsafe_set dst (dst_pos + i) (unsafe_get src (src_pos + i)) + done + else + for i = len - 1 downto 0 do + unsafe_set dst (dst_pos + i) (unsafe_get src (src_pos + i)) + done; +;; + +include + Blit.Make + (struct + type nonrec t = t + let create = create_zero + let length = length + let unsafe_blit = unsafe_blit + end) +;; + +let copy src = + let dst = create_zero ~len:(length src) in + blito ~src ~dst (); + dst +;; + +let truncate t ~len = Caml.Obj.truncate (Caml.Obj.repr (t : t)) len diff --git a/src/obj_array.mli b/src/obj_array.mli new file mode 100644 index 0000000..2481bb9 --- /dev/null +++ b/src/obj_array.mli @@ -0,0 +1,72 @@ +(** This module is deprecated for external use. 
Users should replace occurrences of + [Obj_array.t] in their code with [Obj.t Uniform_array.t]. + + This module is here for implementing [Uniform_array] internally, and is exposed + through [Not_exposed_properly] to ease the transition for users. +*) + +open! Import + +type t [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Blit. S with type t := t +include Invariant.S with type t := t + +(** [create ~len x] returns an obj-array of length [len], all of whose indices have value + [x]. *) +val create : len:int -> Caml.Obj.t -> t + +(** [create_zero ~len] returns an obj-array of length [len], all of whose indices have + value [Caml.Obj.repr 0]. *) +val create_zero : len:int -> t + +(** [copy t] returns a new array with the same elements as [t]. *) +val copy : t -> t + +val singleton : Caml.Obj.t -> t + +val empty : t + +val length : t -> int + +(** [get t i] and [unsafe_get t i] return the object at index [i]. [set t i o] and + [unsafe_set t i o] set index [i] to [o]. In no case is the object copied. The + [unsafe_*] variants omit the bounds check of [i]. *) +val get : t -> int -> Caml.Obj.t +val unsafe_get : t -> int -> Caml.Obj.t +val set : t -> int -> Caml.Obj.t -> unit +val unsafe_set : t -> int -> Caml.Obj.t -> unit + +val swap : t -> int -> int -> unit + +(** [unsafe_set_assuming_currently_int t i obj] sets index [i] of [t] to [obj], but only + works correctly if [Caml.Obj.is_int (get t i)]. This precondition saves a dynamic + check. + + [unsafe_set_int_assuming_currently_int] is similar, except the value being set is an + int. + + [unsafe_set_int] is similar but does not assume anything about the target.
*) +val unsafe_set_assuming_currently_int : t -> int -> Caml.Obj.t -> unit + +val unsafe_set_int_assuming_currently_int : t -> int -> int -> unit +val unsafe_set_int : t -> int -> int -> unit + +(** [unsafe_set_omit_phys_equal_check] is like [unsafe_set], except it doesn't do a + [phys_equal] check to try to skip [caml_modify]. It is safe to call this even if the + values are [phys_equal]. *) +val unsafe_set_omit_phys_equal_check : t -> int -> Caml.Obj.t -> unit + +(** [unsafe_clear_if_pointer t i] prevents [t.(i)] from pointing to anything to prevent + space leaks. It does this by setting [t.(i)] to [Caml.Obj.repr 0]. As a performance hack, + it only does this when [not (Caml.Obj.is_int t.(i))]. *) +val unsafe_clear_if_pointer : t -> int -> unit + +(** [truncate t ~len] shortens [t]'s length to [len]. It is an error if [len <= 0] or + [len > length t].*) +val truncate : t -> len:int -> unit + diff --git a/src/option.ml b/src/option.ml new file mode 100644 index 0000000..591dcc0 --- /dev/null +++ b/src/option.ml @@ -0,0 +1,200 @@ +open! Import + +type 'a t = 'a option [@@deriving_inline sexp, compare, hash] +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + option_of_sexp +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_option +let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = compare_option +let hash_fold_t : + 'a . 
+ (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + = hash_fold_option +[@@@end] + +let is_none = function None -> true | _ -> false + +let is_some = function Some _ -> true | _ -> false + +let value_map o ~default ~f = + match o with + | Some x -> f x + | None -> default + +let iter o ~f = + match o with + | None -> () + | Some a -> f a +;; + +let invariant f t = iter t ~f + +let map2 o1 o2 ~f = + match o1, o2 with + | Some a1, Some a2 -> Some (f a1 a2) + | _ -> None + +let call x ~f = + match f with + | None -> () + | Some f -> f x + +let value t ~default = + match t with + | None -> default + | Some x -> x +;; + +let value_exn ?here ?error ?message t = + match t with + | Some x -> x + | None -> + let error = + match here, error, message with + | None , None , None -> Error.of_string "Option.value_exn None" + | None , None , Some m -> Error.of_string m + | None , Some e, None -> e + | None , Some e, Some m -> Error.tag e ~tag:m + | Some p, None , None -> + Error.create "Option.value_exn" p Source_code_position0.sexp_of_t + | Some p, None , Some m -> + Error.create m p Source_code_position0.sexp_of_t + | Some p, Some e, _ -> + Error.create (value message ~default:"") (e, p) + (sexp_of_pair Error.sexp_of_t Source_code_position0.sexp_of_t) + in + Error.raise error +;; + +let to_array t = + match t with + | None -> [||] + | Some x -> [|x|] +;; + +let to_list t = + match t with + | None -> [] + | Some x -> [x] +;; + +let min_elt t ~compare:_ = t +let max_elt t ~compare:_ = t + +let sum (type a) (module M : Container.Summable with type t = a) t ~f = + match t with + | None -> M.zero + | Some x -> f x +;; + +let for_all t ~f = + match t with + | None -> true + | Some x -> f x +;; + +let exists t ~f = + match t with + | None -> false + | Some x -> f x +;; + +let mem t a ~equal = + match t with + | None -> false + | Some a' -> equal a a' +;; + +let length t = + match t with + | None 
-> 0 + | Some _ -> 1 +;; + +let is_empty = is_none + +let fold t ~init ~f = + match t with + | None -> init + | Some x -> f init x +;; + +let count t ~f = + match t with + | None -> 0 + | Some a -> if f a then 1 else 0 +;; + +let find t ~f = + match t with + | None -> None + | Some x -> if f x then Some x else None +;; + +let find_map t ~f = + match t with + | None -> None + | Some a -> f a +;; + +let equal f t t' = + match t, t' with + | None, None -> true + | Some x, Some x' -> f x x' + | _ -> false + +let some x = Some x + +let both x y = + match x,y with + | Some a, Some b -> Some (a,b) + | _ -> None + +let first_some x y = + match x with + | Some _ -> x + | None -> y + +let some_if cond x = if cond then Some x else None + +let merge a b ~f = + match a, b with + | None, x | x, None -> x + | Some a, Some b -> Some (f a b) + +let filter t ~f = + match t with + | Some v as o when f v -> o + | _ -> None + +let try_with f = + try Some (f ()) + with _ -> None + +include Monad.Make (struct + type 'a t = 'a option + let return x = Some x + let map t ~f = + match t with + | None -> None + | Some a -> Some (f a) + ;; + let map = `Custom map + let bind o ~f = + match o with + | None -> None + | Some x -> f x + end) + +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t + +let validate ~none ~some t = + let module V = Validate in + match t with + | None -> V.name "none" (V.protect none ()) + | Some x -> V.name "some" (V.protect some x ) +;; diff --git a/src/option.mli b/src/option.mli new file mode 100644 index 0000000..cb4de45 --- /dev/null +++ b/src/option.mli @@ -0,0 +1,75 @@ +(** Option type. *) + +open! 
Import + +type 'a t = 'a option [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Container.S1 with type 'a t := 'a t +include Equal.S1 with type 'a t := 'a t +include Invariant.S1 with type 'a t := 'a t + +(** Options form a monad, where [return x = Some x], [(None >>= f) = None], and [(Some x + >>= f) = f x]. *) +include Monad.S with type 'a t := 'a t + +(** [is_none t] returns true iff [t = None]. *) +val is_none : 'a t -> bool + +(** [is_some t] returns true iff [t = Some x]. *) +val is_some : 'a t -> bool + +(** [value_map ~default ~f] is the same as [function Some x -> f x | None -> default]. *) +val value_map : 'a t -> default:'b -> f:('a -> 'b) -> 'b + +(** [map2 o f] maps ['a option] and ['b option] to a ['c option] using [~f]. *) +val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + +(** [call x f] runs an optional function [~f] on the argument. *) +val call : 'a -> f:('a -> unit) t -> unit + +(** [value None ~default] = [default] + + [value (Some x) ~default] = [x] *) +val value : 'a t -> default:'a -> 'a + +(** [value_exn (Some x)] = [x]. [value_exn None] raises an error whose contents contain + the supplied [~here], [~error], and [~message], or a default message if none are + supplied. *) +val value_exn + : ?here:Source_code_position0.t + -> ?error:Error.t + -> ?message:string + -> 'a t + -> 'a + +val some : 'a -> 'a t + +val both : 'a t -> 'b t -> ('a * 'b) t + +val first_some : 'a t -> 'a t -> 'a t + +val some_if : bool -> 'a -> 'a t + +(** [merge a b ~f] merges together the values from [a] and [b] using [f]. If both [a] and + [b] are [None], returns [None].
If only one is [Some], returns that one, and if both + are [Some], returns [Some] of the result of applying [f] to the contents of [a] and + [b]. *) +val merge : 'a t -> 'a t -> f:('a -> 'a -> 'a) -> 'a t + +val filter : 'a t -> f:('a -> bool) -> 'a t + +(** [try_with f] returns [Some x] if [f] returns [x] and [None] if [f] raises an + exception. See [Result.try_with] if you'd like to know which exception. *) +val try_with : (unit -> 'a) -> 'a t + +val validate : none:unit Validate.check -> some:'a Validate.check -> 'a t Validate.check diff --git a/src/option_array.ml b/src/option_array.ml new file mode 100644 index 0000000..569f54c --- /dev/null +++ b/src/option_array.ml @@ -0,0 +1,171 @@ +open! Import + + +(** ['a Cheap_option.t] is like ['a option], but it doesn't box [some _] values. + + There are several things that are unsafe about it: + + - [float t array] (or any array-backed container) is not memory-safe + because float array optimization is incompatible with unboxed option + optimization. You have to use [Uniform_array.t] instead of [array]. + + - Nested options (['a t t]) don't work. They are believed to be + memory-safe, but not parametric. + + - A record with [float t]s in it should be safe, but it's only [t] being + abstract that gives you safety. If the compiler was smart enough to peek + through the module signature then it could decide to construct a float + array instead. *) +module Cheap_option = struct + + (* This is taken from core_kernel. Rather than expose it in the public + interface of base, just keep a copy around here. *) + let phys_same (type a) (type b) (a : a) (b : b) = + phys_equal a (Caml.Obj.magic b : a) + + module T0 : sig + type 'a t + val none : _ t + val some : 'a -> 'a t + val is_none : _ t -> bool + val is_some : _ t -> bool + val value_exn : 'a t -> 'a + val value_unsafe : 'a t -> 'a + end + = struct + type +'a t + + (* Being a pointer, no one outside this module can construct a value that is + [phys_same] as this one. 
+ + It would be simpler to use this value as [none], but we use an immediate instead + because it lets us avoid caml_modify when setting to [none], making certain + benchmarks significantly faster (e.g. ../bench/array_queue.exe). + + this code is duplicated in Moption, and if we find yet another place where we want + it we should reconsider making it shared. *) + let none_substitute : _ t = Caml.Obj.obj (Caml.Obj.new_block Caml.Obj.abstract_tag 1) + + let none : _ t = + (* The number was produced by + [< /dev/urandom tr -c -d '1234567890abcdef' | head -c 16]. + + The idea is that a random number will have lower probability to collide with + anything than any number we can choose ourselves. + + We are using a polymorphic variant instead of an integer constant because there + is a compiler bug where it wrongly assumes that the result of [if _ then c else + y] is not a pointer if [c] is an integer compile-time constant. This is being + fixed in https://github.com/ocaml/ocaml/pull/555. The "memory corruption" test + below demonstrates the issue. *) + Caml.Obj.magic `x6e8ee3478e1d7449 + + let is_none x = phys_equal x none + + let is_some x = not (phys_equal x none) + + let some (type a) (x : a) : a t = + if phys_same x none + then none_substitute + else (Caml.Obj.magic x) + + let value_unsafe (type a) (x : a t) : a = + if phys_equal x none_substitute + then Caml.Obj.magic none + else Caml.Obj.magic x + + let value_exn x = + if is_some x + then value_unsafe x + else failwith "Option_array.get_some_exn: the element is [None]" + + end + + module T1 = struct + include T0 + let of_option = function + | None -> none + | Some x -> some x + let to_option x = + if is_some x + then Some (value_unsafe x) + else None + + let to_sexpable = to_option + let of_sexpable = of_option + end + + include T1 + include Sexpable.Of_sexpable1(Option)(T1) +end + +type 'a t = 'a Cheap_option.t Uniform_array.t [@@deriving_inline sexp] +let t_of_sexp : + 'a . 
(Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + let _tp_loc = "src/option_array.ml.t" in + fun _of_a -> + fun t -> Uniform_array.t_of_sexp (Cheap_option.t_of_sexp _of_a) t +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + fun v -> Uniform_array.sexp_of_t (Cheap_option.sexp_of_t _of_a) v +[@@@end] + +let empty = Uniform_array.empty + +let create ~len = Uniform_array.create ~len Cheap_option.none + +let init n ~f = Uniform_array.init n ~f:(fun i -> Cheap_option.of_option (f i)) + +let init_some n ~f = Uniform_array.init n ~f:(fun i -> Cheap_option.some (f i)) + +let length = Uniform_array.length + +let get t i = Cheap_option.to_option (Uniform_array.get t i) + +let get_some_exn t i = Cheap_option.value_exn (Uniform_array.get t i) + +let is_none t i = Cheap_option.is_none (Uniform_array.get t i) + +let is_some t i = Cheap_option.is_some (Uniform_array.get t i) + +let set t i x = Uniform_array.set t i (Cheap_option.of_option x) + +let set_some t i x = Uniform_array.set t i (Cheap_option.some x) + +let set_none t i = Uniform_array.set t i Cheap_option.none + +let swap t i j = Uniform_array.swap t i j + +let unsafe_get t i = Cheap_option.to_option (Uniform_array.unsafe_get t i) + +let unsafe_get_some_exn t i = Cheap_option.value_exn (Uniform_array.unsafe_get t i) + +let unsafe_is_some t i = Cheap_option.is_some (Uniform_array.unsafe_get t i) + +let unsafe_set t i x = Uniform_array.unsafe_set t i (Cheap_option.of_option x) + +let unsafe_set_some t i x = Uniform_array.unsafe_set t i (Cheap_option.some x) + +let unsafe_set_none t i = Uniform_array.unsafe_set t i (Cheap_option.none) + +let clear t = + for i = 0 to length t - 1 + do + unsafe_set_none t i + done + +include + Blit.Make1_generic + (struct + type nonrec 'a t = 'a t + let length = length + let create_like ~len _ = create ~len + let unsafe_blit = Uniform_array.unsafe_blit + end) + +let copy = Uniform_array.copy + +module 
For_testing = struct + module Unsafe_cheap_option = Cheap_option +end diff --git a/src/option_array.mli b/src/option_array.mli new file mode 100644 index 0000000..a39a5d0 --- /dev/null +++ b/src/option_array.mli @@ -0,0 +1,93 @@ +(** ['a Option_array.t] is a compact representation of ['a option array]: it avoids + allocating heap objects representing [Some x], usually representing them with [x] + instead. It uses a special representation for [None] that's guaranteed to never + collide with any representation of [Some x]. *) + +open! Import + +type 'a t [@@deriving_inline sexp] +include +sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +val empty : _ t + +(** Initially filled with all [None] *) +val create : len:int -> _ t + + +val init_some : int -> f:(int -> 'a ) -> 'a t +val init : int -> f:(int -> 'a option) -> 'a t + +val length : _ t -> int + +(** [get t i] returns the element number [i] of array [t], raising if [i] is outside the + range 0 to [length t - 1]. *) +val get : 'a t -> int -> 'a option + +(** Raises if the element number [i] is [None]. *) +val get_some_exn : 'a t -> int -> 'a + +(** [is_none t i = Option.is_none (get t i)] *) +val is_none : _ t -> int -> bool + +(** [is_some t i = Option.is_some (get t i)] *) +val is_some : _ t -> int -> bool + +(** These can cause arbitrary behavior when used for an out-of-bounds array access. *) + +val unsafe_get : 'a t -> int -> 'a option +val unsafe_get_some_exn : 'a t -> int -> 'a +val unsafe_is_some : _ t -> int -> bool + +(** [set t i x] modifies array [t] in place, replacing element number [i] with [x], + raising if [i] is outside the range 0 to [length t - 1]. *) +val set : 'a t -> int -> 'a option -> unit +val set_some : 'a t -> int -> 'a -> unit +val set_none : _ t -> int -> unit + +val swap : _ t -> int -> int -> unit + +(** Replaces all the elements of the array with [None]. 
*) +val clear : _ t -> unit + +(** Unsafe versions of [set*]. Can cause arbitrary behaviour when used for an + out-of-bounds array access. *) + +val unsafe_set : 'a t -> int -> 'a option -> unit +val unsafe_set_some : 'a t -> int -> 'a -> unit +val unsafe_set_none : _ t -> int -> unit + +include Blit.S1 with type 'a t := 'a t + +(** Makes a (shallow) copy of the array. *) +val copy : 'a t -> 'a t + +(**/**) + +module For_testing : sig + module Unsafe_cheap_option : sig + type 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + val none : _ t + val some : 'a -> 'a t + + val is_none : _ t -> bool + val is_some : _ t -> bool + + val value_exn : 'a t -> 'a + val value_unsafe : 'a t -> 'a + + val to_option : 'a t -> 'a Option.t + val of_option : 'a Option.t -> 'a t + end +end diff --git a/src/or_error.ml b/src/or_error.ml new file mode 100644 index 0000000..4ded9a3 --- /dev/null +++ b/src/or_error.ml @@ -0,0 +1,129 @@ +open! Import + +type 'a t = ('a, Error.t) Result.t [@@deriving_inline compare, hash, sexp] +let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = + fun _cmp__a -> + fun a__001_ -> + fun b__002_ -> Result.compare _cmp__a Error.compare a__001_ b__002_ +let hash_fold_t : + 'a . + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_a -> + fun hsv -> + fun arg -> Result.hash_fold_t _hash_fold_a Error.hash_fold_t hsv arg +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + let _tp_loc = "src/or_error.ml.t" in + fun _of_a -> fun t -> Result.t_of_sexp _of_a Error.t_of_sexp t +let sexp_of_t : + 'a . 
('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> fun v -> Result.sexp_of_t _of_a Error.sexp_of_t v +[@@@end] + +let invariant invariant_a t = + match t with + | Ok a -> invariant_a a + | Error error -> Error.invariant error +;; + +include (Result : Monad.S2 + with type ('a, 'b) t := ('a, 'b) Result.t + with module Let_syntax := Result.Let_syntax) + +include Applicative.Make (struct + type nonrec 'a t = 'a t + let return = return + let apply f x = + Result.combine f x + ~ok:(fun f x -> f x) + ~err:(fun e1 e2 -> Error.of_list [e1; e2]) + let map = `Custom map + end) + +module Let_syntax = struct + let return = return + include Monad_infix + module Let_syntax = struct + let return = return + let map = map + let bind = bind + let both = both (* from Applicative.Make *) + module Open_on_rhs = struct end + end +end + +let ok = Result.ok +let is_ok = Result.is_ok +let is_error = Result.is_error + +let ignore = ignore_m + +let try_with ?(backtrace = false) f = + try Ok (f ()) + with exn -> Error (Error.of_exn exn ?backtrace:(if backtrace then Some `Get else None)) +;; + +let try_with_join ?backtrace f = join (try_with ?backtrace f) + +let ok_exn = function + | Ok x -> x + | Error err -> Error.raise err +;; + +let of_exn ?backtrace exn = Error (Error.of_exn ?backtrace exn) + +let of_exn_result ?backtrace = function + | Ok _ as z -> z + | Error exn -> of_exn ?backtrace exn +;; + +let error ?strict message a sexp_of_a = + Error (Error.create ?strict message a sexp_of_a) +;; + +let error_s sexp = Error (Error.create_s sexp) + +let error_string message = Error (Error.of_string message) + +let errorf format = Printf.ksprintf error_string format + +let tag t ~tag = Result.map_error t ~f:(Error.tag ~tag) +let tag_arg t message a sexp_of_a = + Result.map_error t ~f:(fun e -> Error.tag_arg e message a sexp_of_a) +;; + +let unimplemented s = error "unimplemented" s sexp_of_string + +let combine_errors l = Result.map_error (Result.combine_errors l) 
~f:Error.of_list + +let combine_errors_unit l = Result.map (combine_errors l) ~f:(fun (_ : unit list) -> ()) + +let filter_ok_at_least_one l = + let ok, errs = List.partition_map l ~f:Result.ok_fst in + match ok with + | [] -> Error (Error.of_list errs) + | _ -> Ok ok +;; + +let find_ok l = + match List.find_map l ~f:Result.ok with + | Some x -> Ok x + | None -> + Error (Error.of_list (List.map l ~f:(function + | Ok _ -> assert false + | Error err -> err))) +;; + +let find_map_ok l ~f = + With_return.with_return (fun {return} -> + Error (Error.of_list (List.map l ~f:(fun elt -> + match f elt with + | (Ok _ as x) -> return x + | Error err -> err)))) +;; + +let map = Result.map +let iter = Result.iter +let iter_error = Result.iter_error diff --git a/src/or_error.mli b/src/or_error.mli new file mode 100644 index 0000000..390270b --- /dev/null +++ b/src/or_error.mli @@ -0,0 +1,126 @@ +(** Type for tracking errors in an [Error.t]. This is a specialization of the [Result] + type, where the [Error] constructor carries an [Error.t]. + + A common idiom is to wrap a function that is not implemented on all platforms, e.g., + + {[val do_something_linux_specific : (unit -> unit) Or_error.t]} +*) + +open! Import + +(** Serialization and comparison of an [Error] force the error's lazy message. 
*) +type 'a t = ('a, Error.t) Result.t [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +(** [Applicative] functions don't have quite the same semantics as + [Applicative.of_Monad(Or_error)] would give -- [apply (Error e1) (Error e2)] returns + the combination of [e1] and [e2], whereas it would only return [e1] if it were defined + using [bind]. *) +include Applicative.S with type 'a t := 'a t +include Invariant.S1 with type 'a t := 'a t +include Monad.S with type 'a t := 'a t + +val is_ok : _ t -> bool +val is_error : _ t -> bool + +val ignore : _ t -> unit t + +(** [try_with f] catches exceptions thrown by [f] and returns them in the [Result.t] as an + [Error.t]. [try_with_join] is like [try_with], except that [f] can throw exceptions + or return an [Error] directly, without ending up with a nested error; it is equivalent + to [Result.join (try_with f)]. *) +val try_with : ?backtrace:bool (** defaults to [false] *) -> (unit -> 'a ) -> 'a t +val try_with_join : ?backtrace:bool (** defaults to [false] *) -> (unit -> 'a t) -> 'a t + + +(** [ok t] returns [None] if [t] is an [Error], and otherwise returns the contents of the + [Ok] constructor. *) +val ok : 'ok t -> 'ok option + +(** [ok_exn t] throws an exception if [t] is an [Error], and otherwise returns the + contents of the [Ok] constructor. *) +val ok_exn : 'a t -> 'a + +(** [of_exn ?backtrace exn] is [Error (Error.of_exn ?backtrace exn)]. 
*) +val of_exn : ?backtrace:[ `Get | `This of string ] -> exn -> _ t + +(** [of_exn_result ?backtrace (Ok a) = Ok a] + + [of_exn_result ?backtrace (Error exn) = of_exn ?backtrace exn] *) +val of_exn_result : ?backtrace:[ `Get | `This of string ] -> ('a, exn) Result.t -> 'a t + +(** [error] is a wrapper around [Error.create]: + + {[ + error ?strict message a sexp_of_a + = Error (Error.create ?strict message a sexp_of_a) + ]} + + As with [Error.create], [sexp_of_a a] is lazily computed when the info is converted + to a sexp. So, if [a] is mutated in the time between the call to [create] and the + sexp conversion, those mutations will be reflected in the sexp. Use [~strict:()] to + force [sexp_of_a a] to be computed immediately. *) +val error + : ?strict : unit + -> string + -> 'a + -> ('a -> Sexp.t) + -> _ t + +val error_s : Sexp.t -> _ t + +(** [error_string message] is [Error (Error.of_string message)]. *) +val error_string : string -> _ t + +(** [errorf format arg1 arg2 ...] is [Error (Error.of_string (sprintf format arg1 arg2 ...))]. Note that it + calculates the string eagerly, so when performance matters you may want to use [error] + instead. *) +val errorf : ('a, unit, string, _ t) format4 -> 'a + +(** [tag t ~tag] is [Result.map_error t ~f:(Error.tag ~tag)]. + [tag_arg] is similar. *) +val tag : 'a t -> tag:string -> 'a t +val tag_arg : 'a t -> string -> 'b -> ('b -> Sexp.t) -> 'a t + +(** For marking a given value as unimplemented. Typically combined with conditional + compilation, where on some platforms the function is defined normally, and on some + platforms it is defined as unimplemented. The supplied string should be the name of + the function that is unimplemented. *) +val unimplemented : string -> _ t + +val map : 'a t -> f:('a -> 'b) -> 'b t +val iter : 'a t -> f:('a -> unit) -> unit +val iter_error : _ t -> f:(Error.t -> unit) -> unit + +(** [combine_errors ts] returns [Ok] if every element in [ts] is [Ok], else it returns + [Error] with all the errors in [ts].
More precisely: + + - [combine_errors [Ok a1; ...; Ok an] = Ok [a1; ...; an]] + - {[ combine_errors [...; Error e1; ...; Error en; ...] + = Error (Error.of_list [e1; ...; en]) ]} *) +val combine_errors : 'a t list -> 'a list t + +(** [combine_errors_unit ts] returns [Ok] if every element in [ts] is [Ok ()], else it + returns [Error] with all the errors in [ts], like [combine_errors]. *) +val combine_errors_unit : unit t list -> unit t + +(** [filter_ok_at_least_one ts] returns all values in [ts] that are [Ok] if there is at + least one, otherwise it returns the same error as [combine_errors ts]. *) +val filter_ok_at_least_one : 'a t list -> 'a list t + +(** [find_ok ts] returns the first value in [ts] that is [Ok], otherwise it returns the + same error as [combine_errors ts]. *) +val find_ok : 'a t list -> 'a t + +(** [find_map_ok l ~f] returns the first value in [l] for which [f] returns [Ok], + otherwise it returns the same error as [combine_errors (List.map l ~f)]. *) +val find_map_ok : 'a list -> f:('a -> 'b t) -> 'b t diff --git a/src/ordered_collection_common.ml b/src/ordered_collection_common.ml new file mode 100644 index 0000000..9e1fda3 --- /dev/null +++ b/src/ordered_collection_common.ml @@ -0,0 +1,46 @@ +open! Import + +let invalid_argf = Printf.invalid_argf + +let [@inline never] slow_check_pos_len_exn ~pos ~len ~total_length = + if pos < 0 + then invalid_argf "Negative position: %d" pos (); + if len < 0 + then invalid_argf "Negative length: %d" len (); + (* We use [pos > total_length - len] rather than [pos + len > total_length] to avoid the + possibility of overflow. 
*) + if pos > total_length - len + then invalid_argf "pos + len past end: %d + %d > %d" pos len total_length () +;; + +let check_pos_len_exn ~pos ~len ~total_length = + (* This is better than [slow_check_pos_len_exn] for two reasons: + + - much less inlined code + - only one conditional jump + + The reason it works is that checking [< 0] is testing the highest order bit, so + [a < 0 || b < 0] is the same as [a lor b < 0]. + + [pos + len] can overflow, so [pos > total_length - len] is not equivalent to + [total_length - len - pos < 0], we need to test for [pos + len] overflow as + well. *) + let stop = pos + len in + if pos lor len lor stop lor (total_length - stop) < 0 then + slow_check_pos_len_exn ~pos ~len ~total_length +;; + +let get_pos_len_exn ?(pos = 0) ?len () ~total_length = + let len = match len with Some i -> i | None -> total_length - pos in + check_pos_len_exn ~pos ~len ~total_length; + pos, len +;; + +let get_pos_len ?pos ?len () ~total_length = + try Result.Ok (get_pos_len_exn () ?pos ?len ~total_length) + with Invalid_argument s -> Or_error.error_string s +;; + +module Private = struct + let slow_check_pos_len_exn = slow_check_pos_len_exn +end diff --git a/src/ordered_collection_common.mli b/src/ordered_collection_common.mli new file mode 100644 index 0000000..2397835 --- /dev/null +++ b/src/ordered_collection_common.mli @@ -0,0 +1,39 @@ +(** Functions for ordered collections. *) + +open! Import + +(** [get_pos_len], [get_pos_len_exn], and [check_pos_len_exn] are intended to be used + by functions that take a sequence (array, string, bigstring, ...) and an optional + [pos] and [len] specifying a subrange of the sequence. Such functions should call + [get_pos_len] with the length of the sequence and the optional [pos] and [len], and it + will return the [pos] and [len] specifying the range, where the default [pos] is zero + and the default [len] is to go to the end of the sequence. 
+ + It should be the case that: + + {[ + pos >= 0 && len >= 0 && pos + len <= total_length + ]} + + Note that this allows [pos = total_length] and [len = 0], i.e., an empty subrange + at the end of the sequence. + + [get_pos_len] returns [(pos', len')] specifying a subrange where: + + {v + pos' = match pos with None -> 0 | Some i -> i + len' = match len with None -> total_length - pos' | Some i -> i + v} *) +val get_pos_len : ?pos:int -> ?len:int -> unit -> total_length:int -> (int * int) Or_error.t +val get_pos_len_exn : ?pos:int -> ?len:int -> unit -> total_length:int -> int * int + +(** [check_pos_len_exn ~pos ~len ~total_length] raises unless [pos >= 0 && len >= 0 && + pos + len <= total_length]. *) +val check_pos_len_exn : pos:int -> len:int -> total_length:int -> unit + +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + val slow_check_pos_len_exn : pos:int -> len:int -> total_length:int -> unit +end diff --git a/src/ordering.ml b/src/ordering.ml new file mode 100644 index 0000000..f812e99 --- /dev/null +++ b/src/ordering.ml @@ -0,0 +1,70 @@ +open! 
Import + +type t = Less | Equal | Greater [@@deriving_inline compare, hash, enumerate, sexp] +let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare +let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Less -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | Equal -> Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Greater -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) +let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x +let all : t list = [Less; Equal; Greater] +let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/ordering.ml.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("less"|"Less") -> Less + | Ppx_sexp_conv_lib.Sexp.Atom ("equal"|"Equal") -> Equal + | Ppx_sexp_conv_lib.Sexp.Atom ("greater"|"Greater") -> Greater + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("less"|"Less"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("equal"|"Equal"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("greater"|"Greater"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Less -> Ppx_sexp_conv_lib.Sexp.Atom "Less" + | Equal -> Ppx_sexp_conv_lib.Sexp.Atom "Equal" + | 
Greater -> Ppx_sexp_conv_lib.Sexp.Atom "Greater" +[@@@end] + +let equal a b = compare a b = 0 + +module Export = struct + type _ordering = t = + | Less + | Equal + | Greater +end + +let of_int n = + if n < 0 + then Less + else if n = 0 + then Equal + else Greater +;; + +let to_int = function + | Less -> -1 + | Equal -> 0 + | Greater -> 1 +;; diff --git a/src/ordering.mli b/src/ordering.mli new file mode 100644 index 0000000..8293e9d --- /dev/null +++ b/src/ordering.mli @@ -0,0 +1,74 @@ +(** [Ordering] is intended to make code that matches on the result of a comparison + more concise and easier to read. + + For example, instead of writing: + + {[ + let r = compare x y in + if r < 0 then + ... + else if r = 0 then + ... + else + ... + ]} + + you could simply write: + + {[ + match Ordering.of_int (compare x y) with + | Less -> ... + | Equal -> ... + | Greater -> ... + ]} + +*) + +open! Import + +type t = + | Less + | Equal + | Greater +[@@deriving_inline compare, enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Equal.S with type t := t + +(** [of_int n] is: + + {v + Less if n < 0 + Equal if n = 0 + Greater if n > 0 + v} *) +val of_int : int -> t + +(** [to_int t] is: + + {v + Less -> -1 + Equal -> 0 + Greater -> 1 + v} + + It can be useful when writing a comparison function to allow one to return + [Ordering.t] values and transform them to [int]s later. *) +val to_int : t -> int + +module Export : sig + type _ordering = t = + | Less + | Equal + | Greater +end diff --git a/src/poly0.ml b/src/poly0.ml new file mode 100644 index 0000000..7fa2a9a --- /dev/null +++ b/src/poly0.ml @@ -0,0 +1,18 @@ +(** Primitives for polymorphic compare. 
*) + +(*_ Polymorphic compiler primitives can't be aliases as this doesn't play well with + inlining. (If aliased without a type annotation, the compiler would implement them + using the generic code doing a C call, and it's this code that would be inlined.) As a + result we have to copy the [external ...] declaration here. *) +external ( < ) : 'a -> 'a -> bool = "%lessthan" +external ( <= ) : 'a -> 'a -> bool = "%lessequal" +external ( <> ) : 'a -> 'a -> bool = "%notequal" +external ( = ) : 'a -> 'a -> bool = "%equal" +external ( > ) : 'a -> 'a -> bool = "%greaterthan" +external ( >= ) : 'a -> 'a -> bool = "%greaterequal" +external ascending : 'a -> 'a -> int = "%compare" +external compare : 'a -> 'a -> int = "%compare" +external equal : 'a -> 'a -> bool = "%equal" +let descending x y = compare y x +let max = Caml.max +let min = Caml.min diff --git a/src/poly0.mli b/src/poly0.mli new file mode 100644 index 0000000..01bfc8e --- /dev/null +++ b/src/poly0.mli @@ -0,0 +1,22 @@ +(** A module containing the ad-hoc polymorphic comparison functions. Useful when + you want to use polymorphic compare in some small scope of a file within which + polymorphic compare has been hidden *) + +external compare : 'a -> 'a -> int = "%compare" + +(** [ascending] is identical to [compare]. [descending x y = ascending y x]. These are + intended to be mnemonic when used like [List.sort ~compare:ascending] and [List.sort + ~compare:descending], since they cause the list to be sorted in ascending or + descending order, respectively. 
*) +val ascending : 'a -> 'a -> int +val descending : 'a -> 'a -> int + +external ( < ) : 'a -> 'a -> bool = "%lessthan" +external ( <= ) : 'a -> 'a -> bool = "%lessequal" +external ( <> ) : 'a -> 'a -> bool = "%notequal" +external ( = ) : 'a -> 'a -> bool = "%equal" +external ( > ) : 'a -> 'a -> bool = "%greaterthan" +external ( >= ) : 'a -> 'a -> bool = "%greaterequal" +external equal : 'a -> 'a -> bool = "%equal" +val min : 'a -> 'a -> 'a +val max : 'a -> 'a -> 'a diff --git a/src/popcount.ml b/src/popcount.ml new file mode 100644 index 0000000..34e140d --- /dev/null +++ b/src/popcount.ml @@ -0,0 +1,39 @@ +open! Import + +(* C stub for int popcount to use the POPCNT instruction where possible *) +external int_popcount : int -> int = "Base_int_math_int_popcount" [@@noalloc] + +(* To maintain javascript compatibility and enable unboxing, we implement popcount in + OCaml rather than use C stubs. Implementation adapted from: + https://en.wikipedia.org/wiki/Hamming_weight#Efficient_implementation *) +let int64_popcount = + let open Caml.Int64 in + let ( + ) = add in + let ( - ) = sub in + let ( * ) = mul in + let ( lsr ) = shift_right_logical in + let ( land ) = logand in + let m1 = 0x5555555555555555L in (* 0b01010101... *) + let m2 = 0x3333333333333333L in (* 0b00110011... *) + let m4 = 0x0f0f0f0f0f0f0f0fL in (* 0b00001111... *) + let h01 = 0x0101010101010101L in (* 1 bit set per byte *) + (fun x -> + (* gather the bit count for every pair of bits *) + let x = x - ((x lsr 1) land m1) in + (* gather the bit count for every 4 bits *) + let x = (x land m2) + ((x lsr 2) land m2) in + (* gather the bit count for every byte *) + let x = (x + (x lsr 4)) land m4 in + (* sum the bit counts in the top byte and shift it down *) + to_int ((x * h01) lsr 56)) [@inline] + +let int32_popcount = + (* On 64-bit systems, this is faster than implementing using [int32] arithmetic. 
*) + let mask = 0xffff_ffffL in + (fun x -> int64_popcount (Caml.Int64.logand (Caml.Int64.of_int32 x) mask)) [@inline] + +let nativeint_popcount = + match Caml.Nativeint.size with + | 32 -> (fun x -> int32_popcount (Caml.Nativeint.to_int32 x)) [@inline] + | 64 -> (fun x -> int64_popcount (Caml.Int64.of_nativeint x)) [@inline] + | _ -> assert false diff --git a/src/popcount.mli b/src/popcount.mli new file mode 100644 index 0000000..8aa4fee --- /dev/null +++ b/src/popcount.mli @@ -0,0 +1,11 @@ +(** This module exposes popcount functions (which count the number of ones in a bitstring) + for the various integer types. + + Functions are exposed in their respective modules. *) + +open! Import + +val int_popcount : int -> int +val int32_popcount : int32 -> int +val int64_popcount : int64 -> int +val nativeint_popcount : nativeint -> int diff --git a/src/pow_overflow_bounds.ml b/src/pow_overflow_bounds.ml new file mode 100644 index 0000000..b3fd8e6 --- /dev/null +++ b/src/pow_overflow_bounds.ml @@ -0,0 +1,425 @@ +(* This file was autogenerated by ../generate/generate_pow_overflow_bounds.exe *) + +open! Import + +module Array = Array0 + +(* We have to use Int64.to_int_exn instead of int constants to make + sure that file can be preprocessed on 32-bit machines. 
*) + +let overflow_bound_max_int32_value : int32 = + 2147483647l + +let int32_positive_overflow_bounds : int32 array = + [| 2147483647l + ; 2147483647l + ; 46340l + ; 1290l + ; 215l + ; 73l + ; 35l + ; 21l + ; 14l + ; 10l + ; 8l + ; 7l + ; 5l + ; 5l + ; 4l + ; 4l + ; 3l + ; 3l + ; 3l + ; 3l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + |] + +let overflow_bound_max_int_value : int = + (-1) lsr 1 + +let int_positive_overflow_bounds : int array = + match Int_conversions.num_bits_int with + | 32 -> Array.map int32_positive_overflow_bounds ~f:Caml.Int32.to_int + | 63 -> + [| Caml.Int64.to_int 4611686018427387903L + ; Caml.Int64.to_int 4611686018427387903L + ; Caml.Int64.to_int 2147483647L + ; 1664510 + ; 46340 + ; 5404 + ; 1290 + ; 463 + ; 215 + ; 118 + ; 73 + ; 49 + ; 35 + ; 27 + ; 21 + ; 17 + ; 14 + ; 12 + ; 10 + ; 9 + ; 8 + ; 7 + ; 7 + ; 6 + ; 5 + ; 5 + ; 5 + ; 4 + ; 4 + ; 4 + ; 4 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 1 + ; 1 + |] + | 31 -> + [| 1073741823 + ; 1073741823 + ; 32767 + ; 1023 + ; 181 + ; 63 + ; 31 + ; 19 + ; 13 + ; 10 + ; 7 + ; 6 + ; 5 + ; 4 + ; 4 + ; 3 + ; 3 + ; 3 + ; 3 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + |] + | _ -> assert false + +let overflow_bound_max_int63_on_int64_value : int64 = + 4611686018427387903L + +let int63_on_int64_positive_overflow_bounds : int64 array = + [| 4611686018427387903L + ; 
4611686018427387903L + ; 2147483647L + ; 1664510L + ; 46340L + ; 5404L + ; 1290L + ; 463L + ; 215L + ; 118L + ; 73L + ; 49L + ; 35L + ; 27L + ; 21L + ; 17L + ; 14L + ; 12L + ; 10L + ; 9L + ; 8L + ; 7L + ; 7L + ; 6L + ; 5L + ; 5L + ; 5L + ; 4L + ; 4L + ; 4L + ; 4L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 1L + ; 1L + |] + +let overflow_bound_max_int64_value : int64 = + 9223372036854775807L + +let int64_positive_overflow_bounds : int64 array = + [| 9223372036854775807L + ; 9223372036854775807L + ; 3037000499L + ; 2097151L + ; 55108L + ; 6208L + ; 1448L + ; 511L + ; 234L + ; 127L + ; 78L + ; 52L + ; 38L + ; 28L + ; 22L + ; 18L + ; 15L + ; 13L + ; 11L + ; 9L + ; 8L + ; 7L + ; 7L + ; 6L + ; 6L + ; 5L + ; 5L + ; 5L + ; 4L + ; 4L + ; 4L + ; 4L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 1L + |] + +let int64_negative_overflow_bounds : int64 array = + [| -9223372036854775807L + ; -9223372036854775807L + ; -3037000499L + ; -2097151L + ; -55108L + ; -6208L + ; -1448L + ; -511L + ; -234L + ; -127L + ; -78L + ; -52L + ; -38L + ; -28L + ; -22L + ; -18L + ; -15L + ; -13L + ; -11L + ; -9L + ; -8L + ; -7L + ; -7L + ; -6L + ; -6L + ; -5L + ; -5L + ; -5L + ; -4L + ; -4L + ; -4L + ; -4L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -1L + |] diff --git a/src/ppx_compare_lib.ml b/src/ppx_compare_lib.ml new file mode 100644 index 0000000..6e6809c --- /dev/null +++ b/src/ppx_compare_lib.ml @@ -0,0 +1,113 @@ +open Import0 + +let 
phys_equal = phys_equal +external polymorphic_compare : 'a -> 'a -> int = "%compare" +external polymorphic_equal : 'a -> 'a -> bool = "%equal" +external ( && ) : bool -> bool -> bool = "%sequand" + +let compare_abstract ~type_name _ _ = + Printf.ksprintf failwith + "Compare called on the type %s, which is abstract in an implementation." + type_name + +let equal_abstract ~type_name _ _ = + Printf.ksprintf failwith + "Equal called on the type %s, which is abstract in an implementation." + type_name + +type 'a compare = 'a -> 'a -> int +type 'a equal = 'a -> 'a -> bool + +module Builtin = struct + let compare_bool : bool compare = Poly.compare + let compare_char : char compare = Poly.compare + let compare_float : float compare = Poly.compare + let compare_int : int compare = Poly.compare + let compare_int32 : int32 compare = Poly.compare + let compare_int64 : int64 compare = Poly.compare + let compare_nativeint : nativeint compare = Poly.compare + let compare_string : string compare = Poly.compare + let compare_unit : unit compare = Poly.compare + + let compare_array compare_elt a b = + if phys_equal a b then + 0 + else + let len_a = Array0.length a in + let len_b = Array0.length b in + let ret = compare len_a len_b in + if ret <> 0 then ret + else + let rec loop i = + if i = len_a then + 0 + else + let l = Array0.unsafe_get a i + and r = Array0.unsafe_get b i in + let res = compare_elt l r in + if res <> 0 then res + else loop (i + 1) + in + loop 0 + + let rec compare_list compare_elt a b = + match a, b with + | [] , [] -> 0 + | [] , _ -> -1 + | _ , [] -> 1 + | x::xs, y::ys -> + let res = compare_elt x y in + if res <> 0 then res + else compare_list compare_elt xs ys + + let compare_option compare_elt a b = + match a, b with + | None , None -> 0 + | None , Some _ -> -1 + | Some _, None -> 1 + | Some a, Some b -> compare_elt a b + + let compare_ref compare_elt a b = compare_elt !a !b + + let equal_bool : bool equal = Poly.equal + let equal_char : char equal = 
Poly.equal + let equal_int : int equal = Poly.equal + let equal_int32 : int32 equal = Poly.equal + let equal_int64 : int64 equal = Poly.equal + let equal_nativeint : nativeint equal = Poly.equal + let equal_string : string equal = Poly.equal + let equal_unit : unit equal = Poly.equal + + (* [Poly.equal] is IEEE compliant, which is not what we want here. *) + let equal_float x y = equal_int (compare_float x y) 0 + + let equal_array equal_elt a b = + phys_equal a b || + (let len_a = Array0.length a in + let len_b = Array0.length b in + equal len_a len_b && + (let rec loop i = + i = len_a || + (let l = Array0.unsafe_get a i + and r = Array0.unsafe_get b i in + equal_elt l r && loop (i + 1)) + in + loop 0)) + + let rec equal_list equal_elt a b = + match a, b with + | [] , [] -> true + | [] , _ + | _ , [] -> false + | x::xs, y::ys -> + equal_elt x y && equal_list equal_elt xs ys + + let equal_option equal_elt a b = + match a, b with + | None , None -> true + | None , Some _ + | Some _, None -> false + | Some a, Some b -> equal_elt a b + + let equal_ref equal_elt a b = equal_elt !a !b +end diff --git a/src/ppx_compare_lib.mli b/src/ppx_compare_lib.mli new file mode 100644 index 0000000..60aeaef --- /dev/null +++ b/src/ppx_compare_lib.mli @@ -0,0 +1,50 @@ +(** Runtime support for auto-generated comparators. Users are not intended to use this + module directly. *) + +val phys_equal : 'a -> 'a -> bool + +(*_ /!\ WARNING /!\ all these functions need to be declared "external" in order to get the + lazy behavior for ( && ) (relied upon by [@@deriving_inline equal][@@@end]) and the type-based + specialization for equal/compare. 
*) +external polymorphic_compare : 'a -> 'a -> int = "%compare" +external polymorphic_equal : 'a -> 'a -> bool = "%equal" +external ( && ) : bool -> bool -> bool = "%sequand" + +type 'a compare = 'a -> 'a -> int +type 'a equal = 'a -> 'a -> bool + +(** Raise when fully applied *) +val compare_abstract : type_name:string -> _ compare +val equal_abstract : type_name:string -> _ equal + +module Builtin : sig + val compare_bool : bool compare + val compare_char : char compare + val compare_float : float compare + val compare_int : int compare + val compare_int32 : int32 compare + val compare_int64 : int64 compare + val compare_nativeint : nativeint compare + val compare_string : string compare + val compare_unit : unit compare + + val compare_array : 'a compare -> 'a array compare + val compare_list : 'a compare -> 'a list compare + val compare_option : 'a compare -> 'a option compare + val compare_ref : 'a compare -> 'a ref compare + + val equal_bool : bool equal + val equal_char : char equal + val equal_float : float equal + val equal_int : int equal + val equal_int32 : int32 equal + val equal_int64 : int64 equal + val equal_nativeint : nativeint equal + val equal_string : string equal + val equal_unit : unit equal + + val equal_array : 'a equal -> 'a array equal + val equal_list : 'a equal -> 'a list equal + val equal_option : 'a equal -> 'a option equal + val equal_ref : 'a equal -> 'a ref equal +end diff --git a/src/ppx_enumerate_lib.ml b/src/ppx_enumerate_lib.ml new file mode 100644 index 0000000..bde251f --- /dev/null +++ b/src/ppx_enumerate_lib.ml @@ -0,0 +1 @@ +module List = List diff --git a/src/ppx_hash_lib.ml b/src/ppx_hash_lib.ml new file mode 100644 index 0000000..17a4578 --- /dev/null +++ b/src/ppx_hash_lib.ml @@ -0,0 +1,6 @@ +(** This module is for use by ppx_hash, and is thus not in the interface of Base. 
*) +module Std = struct + + (** @canonical Base.Hash *) + module Hash = Hash +end diff --git a/src/ppx_sexp_conv_lib.ml b/src/ppx_sexp_conv_lib.ml new file mode 100644 index 0000000..44973cf --- /dev/null +++ b/src/ppx_sexp_conv_lib.ml @@ -0,0 +1 @@ +include Sexplib diff --git a/src/pretty_printer.ml b/src/pretty_printer.ml new file mode 100644 index 0000000..5a6be9f --- /dev/null +++ b/src/pretty_printer.ml @@ -0,0 +1,29 @@ +open! Import + +let r = ref [ "Base.Sexp.pp_hum" ] + +let all () = !r + +let register p = r := p :: !r + +module type S = sig + type t + val pp : Formatter.t -> t -> unit +end + +module Register_pp (M : sig + include S + val module_name : string + end) = struct + include M + let () = register (M.module_name ^ ".pp") +end + +module Register (M : sig + type t + val module_name : string + val to_string : t -> string + end) = Register_pp (struct + include M + let pp formatter t = Caml.Format.pp_print_string formatter (M.to_string t) + end) diff --git a/src/pretty_printer.mli b/src/pretty_printer.mli new file mode 100644 index 0000000..a0c40d7 --- /dev/null +++ b/src/pretty_printer.mli @@ -0,0 +1,44 @@ +(** A list of pretty printers for various types, for use in toplevels. + + [Pretty_printer] has a [string list ref] with the names of [pp] functions matching the + interface: + + {[ + val pp : Format.formatter -> t -> unit + ]} + + The names are actually OCaml identifier names, e.g., "Base.Int.pp". Code for + building toplevels (this code is not in Base) evaluates the strings to yield the + pretty printers and register them with the OCaml runtime. *) + +open! Import + +(** [all ()] returns all pretty printers that have been [register]ed. *) +val all : unit -> string list + +(** Modules that provide a pretty printer will match [S]. *) +module type S = sig + type t + val pp : Formatter.t -> t -> unit +end + +(** [Register] builds a [pp] function from a [to_string] function, and adds the + [module_name ^ ".pp"] to the list of pretty printers. 
The idea is to statically + guarantee that one has the desired [pp] function at the same point where the [name] is + added. *) +module Register (M : sig + type t + val module_name : string + val to_string : t -> string + end) : S with type t := M.t + +(** [Register_pp] is like [Register], but allows a custom [pp] function rather than using + [to_string]. *) +module Register_pp (M : sig + include S + val module_name : string + end) : S with type t := M.t + +(** [register name] adds [name] to the list of pretty printers. Use the [Register] + functor if possible. *) +val register : string -> unit diff --git a/src/printf.ml b/src/printf.ml new file mode 100644 index 0000000..abbe5d2 --- /dev/null +++ b/src/printf.ml @@ -0,0 +1,8 @@ +open! Import0 + +include Caml.Printf + +(** failwith and invalid_arg accepting printf's format. *) + +let failwithf fmt = ksprintf (fun s () -> failwith s) fmt +let invalid_argf fmt = ksprintf (fun s () -> invalid_arg s) fmt diff --git a/src/printf.mli b/src/printf.mli new file mode 100644 index 0000000..a0fbabc --- /dev/null +++ b/src/printf.mli @@ -0,0 +1,132 @@ +(** Functions for formatted output. + + [fprintf] and related functions format their arguments according to the given format + string. The format string is a character string which contains two types of objects: + plain characters, which are simply copied to the output channel, and conversion + specifications, each of which causes conversion and printing of arguments. + + Conversion specifications have the following form: + + {[% [flags] [width] [.precision] type]} + + In short, a conversion specification consists in the [%] character, followed by + optional modifiers and a type which is made of one or two characters. + + The types and their meanings are: + + - [d], [i]: convert an integer argument to signed decimal. + - [u], [n], [l], [L], or [N]: convert an integer argument to unsigned + decimal. 
Warning: [n], [l], [L], and [N] are used for [scanf], and should not be used + for [printf]. + - [x]: convert an integer argument to unsigned hexadecimal, using lowercase letters. + - [X]: convert an integer argument to unsigned hexadecimal, using uppercase letters. + - [o]: convert an integer argument to unsigned octal. + - [s]: insert a string argument. + - [S]: convert a string argument to OCaml syntax (double quotes, escapes). + - [c]: insert a character argument. + - [C]: convert a character argument to OCaml syntax (single quotes, escapes). + - [f]: convert a floating-point argument to decimal notation, in the style [dddd.ddd]. + - [F]: convert a floating-point argument to OCaml syntax ([dddd.] or [dddd.ddd] or + [d.ddd e+-dd]). + - [e] or [E]: convert a floating-point argument to decimal notation, in the style + [d.ddd e+-dd] (mantissa and exponent). + - [g] or [G]: convert a floating-point argument to decimal notation, in style [f] or + [e], [E] (whichever is more compact). Moreover, any trailing zeros are removed from + the fractional part of the result and the decimal-point character is removed if there + is no fractional part remaining. + - [h] or [H]: convert a floating-point argument to hexadecimal notation, in the style + [0xh.hhhh e+-dd] (hexadecimal mantissa, exponent in decimal and denotes a power of 2). + - [B]: convert a boolean argument to the string true or false + - [b]: convert a boolean argument (deprecated; do not use in new programs). + - [ld], [li], [lu], [lx], [lX], [lo]: convert an int32 argument to the format + specified by the second letter (decimal, hexadecimal, etc). + - [nd], [ni], [nu], [nx], [nX], [no]: convert a nativeint argument to the format + specified by the second letter. + - [Ld], [Li], [Lu], [Lx], [LX], [Lo]: convert an int64 argument to the format + specified by the second letter. + - [a]: user-defined printer. 
Take two arguments and apply the first one to outchan + (the current output channel) and to the second argument. The first argument must + therefore have type [out_channel -> 'b -> unit] and the second ['b]. The output + produced by the function is inserted in the output of [fprintf] at the current point. + - [t]: same as [%a], but take only one argument (with type [out_channel -> unit]) and + apply it to [outchan]. + - [{ fmt %}]: convert a format string argument to its type digest. The argument must + have the same type as the internal format string [fmt]. + - [( fmt %)]: format string substitution. Take a format string argument and substitute + it to the internal format string fmt to print following arguments. The argument must + have the same type as the internal format string fmt. + - [!]: take no argument and flush the output. + - [%]: take no argument and output one [%] character. + - [@]: take no argument and output one [@] character. + - [,]: take no argument and output nothing: a no-op delimiter for conversion + specifications. + + The optional [flags] are: + + - [-]: left-justify the output (default is right justification). + - [0]: for numerical conversions, pad with zeroes instead of spaces. + - [+]: for signed numerical conversions, prefix number with a [+] sign if positive. + - space: for signed numerical conversions, prefix number with a space if positive. + - [#]: request an alternate formatting style for the hexadecimal and octal integer + types ([x], [X], [o], [lx], [lX], [lo], [Lx], [LX], [Lo]). + + The optional [width] is an integer indicating the minimal width of the result. For + instance, [%6d] prints an integer, prefixing it with spaces to fill at least 6 + characters. + + The optional [precision] is a dot [.] followed by an integer indicating how many + digits follow the decimal point in the [%f], [%e], and [%E] conversions. For instance, + [%.4f] prints a [float] with 4 fractional digits. 
+ + The integer in a [width] or [precision] can also be specified as [*], in which case an + extra integer argument is taken to specify the corresponding [width] or + [precision]. This integer argument precedes immediately the argument to print. For + instance, [%.*f] prints a float with as many fractional digits as the value of the + argument given before the float. +*) + +open! Import0 + +(** Same as [fprintf], but does not print anything. Useful for ignoring some material when + conditionally printing. *) +val ifprintf : 'a -> ('r, 'a, 'c, unit) format4 -> 'r + +(** Same as [fprintf], but instead of printing on an output channel, returns a string. *) +val sprintf : ('r, unit, string) format -> 'r + +(** Same as [fprintf], but instead of printing on an output channel, appends the formatted + arguments to the given extensible buffer. *) +val bprintf : Caml.Buffer.t -> ('r, Caml.Buffer.t, unit) format -> 'r + +(** Same as [sprintf], but instead of returning the string, passes it to the first + argument. *) +val ksprintf : (string -> 'a) -> ('r, unit, string, 'a) format4 -> 'r + +(** Same as [bprintf], but instead of returning immediately, passes the buffer, after + printing, to its first argument. *) +val kbprintf : (Caml.Buffer.t -> 'a) -> Caml.Buffer.t -> ('r, Caml.Buffer.t, unit, 'a) format4 -> 'r + +(** {6 Formatting error and exit functions} + + These functions have a polymorphic return type, since they do not return. Naively, + this doesn't mix well with variadic functions: if you define, say, + + {[ + let f fmt = ksprintf (fun s -> failwith s) fmt + ]} + + then you find that [f "%d" : int -> 'a], as you'd expect, and [f "%d" 7 : 'a]. The + problem with this is that ['a] unifies with (say) [int -> 'b], so [f "%d" 7 4] is not + a type error -- the [4] is simply ignored. + + To mitigate this problem, these functions all take a final unit parameter. These + rarely arise as formatting positional parameters (they can do with e.g. 
"%a", but not + in a useful way) so they serve as an effective signpost for + "end of formatting arguments". *) + + +(** Raises [Failure]. *) +val failwithf : ('r, unit, string, unit -> _) format4 -> 'r + +(** Raises [Invalid_argument]. *) +val invalid_argf : ('r, unit, string, unit -> _) format4 -> 'r diff --git a/src/queue.ml b/src/queue.ml new file mode 100644 index 0000000..65617c9 --- /dev/null +++ b/src/queue.ml @@ -0,0 +1,503 @@ +open! Import + + +(* [t] stores the [t.length] queue elements at consecutive increasing indices of [t.elts], + mod the capacity of [t], which is [Option_array.length t.elts]. The capacity is + required to be a power of two (user-requested capacities are rounded up to the nearest + power), so that mod can quickly be computed using [land t.mask], where [t.mask = + capacity t - 1]. So, queue element [i] is at [t.elts.( (t.front + i) land t.mask )]. + + [num_mutations] is used to detect modification during iteration. *) +type 'a t = + { mutable num_mutations : int + ; mutable front : int + ; mutable mask : int + ; mutable length : int + ; mutable elts : 'a Option_array.t + } +[@@deriving_inline sexp_of] +let sexp_of_t : + 'a . 
('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | { num_mutations = v_num_mutations; front = v_front; mask = v_mask; + length = v_length; elts = v_elts } -> + let bnds = [] in + let bnds = + let arg = Option_array.sexp_of_t _of_a v_elts in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "elts"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_length in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "length"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_mask in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "mask"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_front in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "front"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_num_mutations in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "num_mutations"; arg]) + :: bnds in + Ppx_sexp_conv_lib.Sexp.List bnds +[@@@end] + +module type S = Queue_intf.S + +let inc_num_mutations t = t.num_mutations <- t.num_mutations + 1 + +let capacity t = t.mask + 1 + +let elts_index t i = (t.front + i) land t.mask + +let unsafe_get t i = Option_array.unsafe_get_some_exn t.elts (elts_index t i) +let unsafe_is_set t i = Option_array.unsafe_is_some t.elts (elts_index t i) +let unsafe_set t i a = Option_array.unsafe_set_some t.elts (elts_index t i) a +let unsafe_unset t i = Option_array.unsafe_set_none t.elts (elts_index t i) + +let check_index_exn t i = + if i < 0 || i >= t.length + then Error.raise_s + (Sexp.message "Queue index out of bounds" + [ "index" , i |> Int.sexp_of_t + ; "length", t.length |> Int.sexp_of_t ]) +;; + +let get t i = check_index_exn t i; unsafe_get t i +let set t i a = + check_index_exn t i; + inc_num_mutations t; + unsafe_set t i a; +;; + +let is_empty t = t.length = 0 + +let length { length; _ } = length + +let ensure_no_mutation t num_mutations = + if t.num_mutations <> num_mutations + then 
Error.raise_s + (Sexp.message "mutation of queue during iteration" + [ "", t |> sexp_of_t (fun _ -> Sexp.Atom "_")]) +;; + +let compare = + let rec unsafe_compare_from compare_elt pos ~t1 ~t2 ~len1 ~len2 ~mut1 ~mut2 = + match pos = len1, pos = len2 with + | true , true -> 0 + | true , false -> -1 + | false, true -> 1 + | false, false -> + let x = compare_elt (unsafe_get t1 pos) (unsafe_get t2 pos) in + ensure_no_mutation t1 mut1; + ensure_no_mutation t2 mut2; + match x with + | 0 -> unsafe_compare_from compare_elt (pos + 1) ~t1 ~t2 ~len1 ~len2 ~mut1 ~mut2 + | n -> n + in + fun compare_elt t1 t2 -> + if phys_equal t1 t2 then + 0 + else + unsafe_compare_from compare_elt 0 ~t1 ~t2 + ~len1:t1.length + ~len2:t2.length + ~mut1:t1.num_mutations + ~mut2:t2.num_mutations +;; + +let equal = + let rec unsafe_equal_from equal_elt pos ~t1 ~t2 ~mut1 ~mut2 ~len = + pos = len + || + (let b = equal_elt (unsafe_get t1 pos) (unsafe_get t2 pos) in + ensure_no_mutation t1 mut1; + ensure_no_mutation t2 mut2; + b && unsafe_equal_from equal_elt (pos + 1) ~t1 ~t2 ~mut1 ~mut2 ~len) + in + fun equal_elt t1 t2 -> + phys_equal t1 t2 + || + (let len1 = t1.length in + let len2 = t2.length in + len1 = len2 + && + unsafe_equal_from equal_elt 0 ~t1 ~t2 + ~len:len1 + ~mut1:t1.num_mutations + ~mut2:t2.num_mutations) +;; + +let invariant invariant_a t = + let + { num_mutations + ; mask = _ + ; elts + ; front + ; length } = t + in + assert (front >= 0); + assert (front < capacity t); + let capacity = capacity t in + assert (capacity = Option_array.length elts); + assert (capacity >= 1); + assert (Int.is_pow2 capacity); + assert (length >= 0); + assert (length <= capacity); + for i = 0 to capacity- 1 do + if i < t.length + then (invariant_a (unsafe_get t i); ensure_no_mutation t num_mutations) + else assert (not (unsafe_is_set t i)) + done +;; + +let create (type a) ?capacity () : a t = + let capacity = + match capacity with + | None -> 1 + | Some capacity -> + if capacity < 0 + then Error.raise_s + 
(Sexp.message "cannot have queue with negative capacity" + [ "capacity", capacity |> Int.sexp_of_t ]) + else if capacity = 0 + then 1 + else Int.ceil_pow2 capacity + in + { num_mutations = 0 + ; front = 0 + ; mask = capacity - 1 + ; length = 0 + ; elts = Option_array.create ~len:capacity + } +;; + +let blit_to_array ~src dst = + assert (src.length <= Option_array.length dst); + let front_len = Int.min src.length (capacity src - src.front) in + let rest_len = src.length - front_len in + Option_array.blit ~len:front_len ~src:src.elts ~src_pos:src.front ~dst ~dst_pos:0; + Option_array.blit ~len:rest_len ~src:src.elts ~src_pos:0 ~dst ~dst_pos:front_len; +;; + +let set_capacity t desired_capacity = + (* We allow arguments less than 1 to [set_capacity], but translate them to 1 to simplify + the code that relies on the array length being a power of 2. *) + inc_num_mutations t; + let new_capacity = Int.ceil_pow2 (max 1 (max desired_capacity t.length)) in + if new_capacity <> capacity t then begin + let dst = Option_array.create ~len:new_capacity in + blit_to_array ~src:t dst; + t.front <- 0; + t.mask <- new_capacity - 1; + t.elts <- dst; + end; +;; + +let enqueue t a = + inc_num_mutations t; + if t.length = capacity t then set_capacity t (2 * t.length); + unsafe_set t t.length a; + t.length <- t.length + 1; +;; + +let dequeue_nonempty t = + inc_num_mutations t; + let elts = t.elts in + let front = t.front in + let res = Option_array.get_some_exn elts front in + Option_array.set_none elts front; + t.front <- elts_index t 1; + t.length <- t.length - 1; + res +;; + +let dequeue_exn t = + if is_empty t + then raise Caml.Queue.Empty + else dequeue_nonempty t +;; + +let dequeue t = + if is_empty t + then None + else Some (dequeue_nonempty t) +;; + +let front_nonempty t = Option_array.unsafe_get_some_exn t.elts t.front +let last_nonempty t = unsafe_get t (t.length - 1) + +let peek t = + if is_empty t + then None + else Some (front_nonempty t) +;; + +let peek_exn t = + if is_empty 
t + then raise Caml.Queue.Empty + else front_nonempty t +;; + +let last t = + if is_empty t + then None + else Some (last_nonempty t) +;; + +let last_exn t = + if is_empty t + then raise Caml.Queue.Empty + else last_nonempty t +;; + +let clear t = + inc_num_mutations t; + if t.length > 0 then begin + for i = 0 to t.length - 1 do + unsafe_unset t i; + done; + t.length <- 0; + t.front <- 0; + end; +;; + +let blit_transfer ~src ~dst ?len () = + inc_num_mutations src; + inc_num_mutations dst; + let len = + match len with + | None -> src.length + | Some len -> + if len < 0 + then Error.raise_s + (Sexp.message "Queue.blit_transfer: negative length" + [ "length", len |> Int.sexp_of_t ]); + min len src.length + in + if len > 0 then begin + set_capacity dst (max (capacity dst) (dst.length + len)); + let dst_start = dst.front + dst.length in + for i = 0 to len - 1 do + (* This is significantly faster than simply [enqueue dst (dequeue_nonempty src)] *) + let src_i = (src.front + i) land src.mask in + let dst_i = (dst_start + i) land dst.mask in + Option_array.unsafe_set_some dst.elts dst_i + (Option_array.unsafe_get_some_exn src.elts src_i); + Option_array.unsafe_set_none src.elts src_i; + done; + dst.length <- dst.length + len; + src.front <- (src.front + len) land src.mask; + src.length <- src.length - len; + end; +;; + +let enqueue_all t l = + (* Traversing the list up front to compute its length is probably (but not definitely) + better than doubling the underlying array size several times for large queues. 
*) + set_capacity t (Int.max (capacity t) (t.length + List.length l)); + List.iter l ~f:(fun x -> enqueue t x) +;; + +let fold t ~init ~f = + if t.length = 0 + then init + else begin + let num_mutations = t.num_mutations in + let r = ref init in + for i = 0 to t.length - 1 do + r := f !r (unsafe_get t i); + ensure_no_mutation t num_mutations; + done; + !r + end; +;; + +let foldi t ~init ~f = + let i = ref 0 in + fold t ~init ~f:(fun acc a -> + let acc = f !i acc a in + i := !i + 1; + acc) +;; + + +(* [iter] is implemented directly because implementing it in terms of [fold] is + slower. *) +let iter t ~f = + let num_mutations = t.num_mutations in + for i = 0 to t.length - 1 do + f (unsafe_get t i); + ensure_no_mutation t num_mutations; + done; +;; + +let iteri t ~f = + let num_mutations = t.num_mutations in + for i = 0 to t.length - 1 do + f i (unsafe_get t i); + ensure_no_mutation t num_mutations; + done; +;; + +module C = + Indexed_container.Make (struct + type nonrec 'a t = 'a t + let fold = fold + let iter = `Custom iter + let length = `Custom length + let foldi = `Custom foldi + let iteri = `Custom iteri + end) + +let count = C.count +let exists = C.exists +let find = C.find +let find_map = C.find_map +let fold_result = C.fold_result +let fold_until = C.fold_until +let for_all = C.for_all +let max_elt = C.max_elt +let mem = C.mem +let min_elt = C.min_elt +let sum = C.sum +let to_list = C.to_list + +let counti = C.counti +let existsi = C.existsi +let find_mapi = C.find_mapi +let findi = C.findi +let for_alli = C.for_alli + + +(* For [concat_map], [filter_map], and [filter], we don't create [t_result] with [t]'s + capacity because we have no idea how many elements [t_result] will ultimately hold. 
*) +let concat_map t ~f = + let t_result = create () in + iter t ~f:(fun a -> List.iter (f a) ~f:(fun b -> enqueue t_result b)); + t_result +;; + +let concat_mapi t ~f = + let t_result = create () in + iteri t ~f:(fun i a -> List.iter (f i a) ~f:(fun b -> enqueue t_result b)); + t_result +;; + +let filter_map t ~f = + let t_result = create () in + iter t ~f:(fun a -> + match f a with + | None -> () + | Some b -> enqueue t_result b); + t_result +;; + +let filter_mapi t ~f = + let t_result = create () in + iteri t ~f:(fun i a -> + match f i a with + | None -> () + | Some b -> enqueue t_result b); + t_result +;; + +let filter t ~f = + let t_result = create () in + iter t ~f:(fun a -> if f a then enqueue t_result a); + t_result +;; + +let filteri t ~f = + let t_result = create () in + iteri t ~f:(fun i a -> if f i a then enqueue t_result a); + t_result +;; + +let filter_inplace t ~f = + let t2 = filter t ~f in + clear t; + blit_transfer ~src:t2 ~dst:t (); +;; + +let filteri_inplace t ~f = + let t2 = filteri t ~f in + clear t; + blit_transfer ~src:t2 ~dst:t (); +;; + +let copy src = + let dst = create ~capacity:src.length () in + blit_to_array ~src dst.elts; + dst.length <- src.length; + dst +;; + +let of_list l = + (* Traversing the list up front to compute its length is probably (but not definitely) + better than doubling the underlying array size several times for large queues. *) + let t = create ~capacity:(List.length l) () in + List.iter l ~f:(fun x -> enqueue t x); + t +;; + +(* The queue [t] returned by [create] will have [t.length = 0], [t.front = 0], and + [capacity t = Int.ceil_pow2 len]. So, we only have to set [t.length] to [len] after + the blit to maintain all the invariants: [t.length] is equal to the number of elements + in the queue, [t.front] is the array index of the first element in the queue, and + [capacity t = Option_array.length t.elts]. 
*) +let init len ~f = + if len < 0 + then Error.raise_s + (Sexp.message "Queue.init: negative length" + [ "length", len |> Int.sexp_of_t ]); + let t = create ~capacity:len () in + assert (Option_array.length t.elts >= len); + for i = 0 to len - 1 do + Option_array.unsafe_set_some t.elts i (f i); + done; + t.length <- len; + t +;; + +let of_array a = init (Array.length a) ~f:(Array.unsafe_get a) + +let to_array t = Array.init t.length ~f:(fun i -> unsafe_get t i) + +let map ta ~f = + let num_mutations = ta.num_mutations in + let tb = create ~capacity:ta.length () in + tb.length <- ta.length; + for i = 0 to ta.length - 1 do + let b = f (unsafe_get ta i) in + ensure_no_mutation ta num_mutations; + Option_array.unsafe_set_some tb.elts i b; + done; + tb +;; + +let mapi t ~f = + let i = ref 0 in + map t ~f:(fun a -> + let result = f !i a in + i := !i + 1; + result) +;; + +let singleton x = + let t = create () in + enqueue t x; + t +;; + +let sexp_of_t sexp_of_a t = to_list t |> List.sexp_of_t sexp_of_a +let t_of_sexp a_of_sexp sexp = List.t_of_sexp a_of_sexp sexp |> of_list diff --git a/src/queue.mli b/src/queue.mli new file mode 100644 index 0000000..4e095d9 --- /dev/null +++ b/src/queue.mli @@ -0,0 +1 @@ +include Queue_intf.Queue (** @inline *) diff --git a/src/queue_intf.ml b/src/queue_intf.ml new file mode 100644 index 0000000..a403a47 --- /dev/null +++ b/src/queue_intf.ml @@ -0,0 +1,135 @@ +open! Import + +(** An interface for queues that follows Base's conventions, as opposed to OCaml's + standard [Queue] module. *) +module type S = sig + type 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Indexed_container.S1 with type 'a t := 'a t + + (** [singleton a] returns a queue with one element. 
*) + val singleton : 'a -> 'a t + + (** [of_list list] returns a queue [t] with the elements of [list] in the same order as + the elements of [list] (i.e. the first element of [t] is the first element of the + list). *) + val of_list : 'a list -> 'a t + val of_array : 'a array -> 'a t + + (** [init n ~f] is equivalent to [of_list (List.init n ~f)]. *) + val init : int -> f:(int -> 'a) -> 'a t + + (** [enqueue t a] adds [a] to the end of [t]. *) + val enqueue : 'a t -> 'a -> unit + + (** [enqueue_all t list] adds all elements in [list] to [t] in order of [list]. *) + val enqueue_all : 'a t -> 'a list -> unit + + (** [dequeue t] removes and returns the front element of [t], if any. *) + val dequeue : 'a t -> 'a option + val dequeue_exn : 'a t -> 'a + + (** [peek t] returns but does not remove the front element of [t], if any. *) + val peek : 'a t -> 'a option + val peek_exn : 'a t -> 'a + + (** [clear t] discards all elements from [t]. *) + val clear : _ t -> unit + + (** [copy t] returns a copy of [t]. *) + val copy : 'a t -> 'a t + + val map : 'a t -> f:( 'a -> 'b) -> 'b t + val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + + (** Creates a new queue with elements equal to [List.concat_map ~f (to_list t)]. *) + val concat_map : 'a t -> f:( 'a -> 'b list) -> 'b t + val concat_mapi : 'a t -> f:(int -> 'a -> 'b list) -> 'b t + + (** [filter_map] creates a new queue with elements equal to [List.filter_map ~f (to_list + t)]. *) + val filter_map : 'a t -> f:( 'a -> 'b option) -> 'b t + val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + + (** [filter] is like [filter_map], except with [List.filter]. *) + val filter : 'a t -> f:( 'a -> bool) -> 'a t + val filteri : 'a t -> f:(int -> 'a -> bool) -> 'a t + + (** [filter_inplace t ~f] removes all elements of [t] that don't satisfy [f]. If [f] + raises, [t] is unchanged. This is in place in that it modifies [t]; however, it uses + space linear in the final length of [t]. 
*) + val filter_inplace : 'a t -> f:( 'a -> bool) -> unit + val filteri_inplace : 'a t -> f:(int -> 'a -> bool) -> unit + +end + +module type Queue = sig + (** A queue implemented with an array. + + The implementation will grow the array as necessary. The array will + never automatically be shrunk, but the size can be interrogated and set + with [capacity] and [set_capacity]. + + Iteration functions ([iter], [fold], [map], [concat_map], [filter], + [filter_map], [filter_inplace], and some functions from [Container.S1]) + will raise if the queue is modified during iteration. + + Also see {!Linked_queue}, which has different performance characteristics. *) + + module type S = S + + type 'a t [@@deriving_inline compare] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + end[@@ocaml.doc "@inline"] + [@@@end] + + include S with type 'a t := 'a t + + include Equal. S1 with type 'a t := 'a t + include Invariant.S1 with type 'a t := 'a t + + (** Create an empty queue. *) + val create + : ?capacity : int (** default is [1]. *) + -> unit + -> _ t + + (** [last t] returns the most recently enqueued element in [t], if any. *) + val last : 'a t -> 'a option + val last_exn : 'a t -> 'a + + (** Transfers up to [len] elements from the front of [src] to the end of [dst], removing + them from [src]. It is an error if [len < 0]. + + Aside from a call to [set_capacity dst] if needed, runs in O([len]) time *) + val blit_transfer + : src : 'a t + -> dst : 'a t + -> ?len : int (** default is [length src] *) + -> unit + -> unit + + (** [get t i] returns the [i]'th element in [t], where the 0'th element is at the front of + [t] and the [length t - 1] element is at the back. *) + val get : 'a t -> int -> 'a + val set : 'a t -> int -> 'a -> unit + + (** Returns the current length of the backing array. *) + val capacity : _ t -> int + + (** [set_capacity t c] sets the capacity of [t]'s backing array to at least [max c (length + t)]. 
If [t]'s capacity changes, then this involves allocating a new backing array and + copying the queue elements over. [set_capacity] may decrease the capacity of [t], if + [c < capacity t]. *) + val set_capacity : _ t -> int -> unit + +end diff --git a/src/random.ml b/src/random.ml new file mode 100644 index 0000000..b7b579d --- /dev/null +++ b/src/random.ml @@ -0,0 +1,242 @@ +open! Import +open Caml.Random + +module Array = Array0 +module Int = Int0 + +(* Unfortunately, because the standard library does not expose + [Caml.Random.State.default], we have to construct our own. We then build the + [Caml.Random.int], [Caml.Random.bool] functions and friends using that default state in + exactly the same way as the standard library. + + One other trickiness is that we need access to the unexposed [Caml.Random.State.assign] + function, which accesses the unexposed state representation. So, we copy the + [State.repr] type definition and [assign] function to here from the standard library, + and use [Obj.magic] to get access to the underlying implementation. *) + +(* Regression tests ought to be deterministic because that way anyone who breaks the test + knows that it's their code that broke the test. If tests are nondeterministic, a test + failure may instead happen because the test runner got unlucky and uncovered an + existing bug in the code supposedly being "protected" by the test in question. 
*) +let forbid_nondeterminism_in_tests ~allow_in_tests = + if am_testing then + match allow_in_tests with + | Some true -> () + | None | Some false -> + failwith "\ +initializing Random with a nondeterministic seed is forbidden in inline tests" +;; + +external random_seed: unit -> int array = "caml_sys_random_seed";; +let random_seed ?allow_in_tests () = + forbid_nondeterminism_in_tests ~allow_in_tests; + random_seed () +;; + +module State = struct + include State + + let make_self_init ?allow_in_tests () = + forbid_nondeterminism_in_tests ~allow_in_tests; + make_self_init () + ;; + + type repr = { st : int array; mutable idx : int } + + let assign t1 t2 = + let t1 = (Caml.Obj.magic t1 : repr) in + let t2 = (Caml.Obj.magic t2 : repr) in + Array.blit ~src:t2.st ~src_pos:0 ~dst:t1.st ~dst_pos:0 + ~len:(Array.length t1.st); + t1.idx <- t2.idx; + ;; + + let full_init t seed = assign t (make seed) + + let default = + (* We define Core's default random state as a copy of OCaml's default random state. + This means that programs that use Core.Random will see the same sequence of random + bits as if they had used Caml.Random. However, because [get_state] returns a + copy, Core.Random and OCaml.Random are not using the same state. If a program used + both, each of them would go through the same sequence of random bits. To avoid + that, we reset OCaml's random state to a different seed, giving it a different + sequence. *) + let t = Caml.Random.get_state () in + Caml.Random.init 137; + t + ;; + + let int_on_64bits t bound = + if bound <= 0x3FFFFFFF (* (1 lsl 30) - 1 *) + then int t bound + else Caml.Int64.to_int (int64 t (Caml.Int64.of_int bound)) + ;; + + let int_on_32bits t bound = + (* Not always true with the JavaScript backend. 
*) + if bound <= 0x3FFFFFFF (* (1 lsl 30) - 1 *) + then int t bound + else Caml.Int32.to_int (int32 t (Caml.Int32.of_int bound)) + ;; + + let int = + match Word_size.word_size with + | W64 -> int_on_64bits + | W32 -> int_on_32bits + ;; + + let full_range_int64 = + let open Caml.Int64 in + let bits state = of_int (bits state) in + fun state -> + logxor (bits state) + (logxor + (shift_left (bits state) 30) + (shift_left (bits state) 60)) + ;; + + let full_range_int32 = + let open Caml.Int32 in + let bits state = of_int (bits state) in + fun state -> + logxor (bits state) + (shift_left (bits state) 30) + ;; + + let full_range_int_on_64bits state = + Caml.Int64.to_int (full_range_int64 state) + ;; + + let full_range_int_on_32bits state = + Caml.Int32.to_int (full_range_int32 state) + ;; + + let full_range_int = + match Word_size.word_size with + | W64 -> full_range_int_on_64bits + | W32 -> full_range_int_on_32bits + ;; + + let full_range_nativeint_on_64bits state = + Caml.Int64.to_nativeint (full_range_int64 state) + ;; + + let full_range_nativeint_on_32bits state = + Caml.Nativeint.of_int32 (full_range_int32 state) + ;; + + let full_range_nativeint = + match Word_size.word_size with + | W64 -> full_range_nativeint_on_64bits + | W32 -> full_range_nativeint_on_32bits + ;; + + let [@inline never] raise_crossed_bounds name lower_bound upper_bound string_of_bound = + Printf.failwithf "Random.%s: crossed bounds [%s > %s]" + name (string_of_bound lower_bound) (string_of_bound upper_bound) () + ;; + + let int_incl = + let rec in_range state lo hi = + let int = full_range_int state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "int" lo hi Int.to_string; + let diff = hi - lo in + if diff = Int.max_value + then lo + ((full_range_int state) land Int.max_value) + else if diff >= 0 + then lo + int state (Int.succ diff) + else in_range state lo hi + ;; + + let int32_incl = + let open 
Int32_replace_polymorphic_compare in + let rec in_range state lo hi = + let int = full_range_int32 state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + let open Caml.Int32 in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "int32" lo hi to_string; + let diff = sub hi lo in + if diff = max_int + then add lo (logand (full_range_int32 state) max_int) + else if diff >= 0l + then add lo (int32 state (succ diff)) + else in_range state lo hi + ;; + + let nativeint_incl = + let open Nativeint_replace_polymorphic_compare in + let rec in_range state lo hi = + let int = full_range_nativeint state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + let open Caml.Nativeint in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "nativeint" lo hi to_string; + let diff = sub hi lo in + if diff = max_int + then add lo (logand (full_range_nativeint state) max_int) + else if diff >= 0n + then add lo (nativeint state (succ diff)) + else in_range state lo hi + ;; + + let int64_incl = + let open Int64_replace_polymorphic_compare in + let rec in_range state lo hi = + let int = full_range_int64 state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + let open Caml.Int64 in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "int64" lo hi to_string; + let diff = sub hi lo in + if diff = max_int + then add lo (logand (full_range_int64 state) max_int) + else if diff >= 0L + then add lo (int64 state (succ diff)) + else in_range state lo hi + ;; + + let float_range state lo hi = + let open Float_replace_polymorphic_compare in + if lo > hi then raise_crossed_bounds "float" lo hi Caml.string_of_float; + lo +. float state (hi -. 
lo) + ;; +end + +let default = State.default + +let bits () = State.bits default + +let int x = State.int default x +let int32 x = State.int32 default x +let nativeint x = State.nativeint default x +let int64 x = State.int64 default x +let float x = State.float default x + +let int_incl x y = State.int_incl default x y +let int32_incl x y = State.int32_incl default x y +let nativeint_incl x y = State.nativeint_incl default x y +let int64_incl x y = State.int64_incl default x y +let float_range x y = State.float_range default x y + +let bool () = State.bool default + +let full_init seed = State.full_init default seed +let init seed = full_init [| seed |] +let self_init ?allow_in_tests () = full_init (random_seed ?allow_in_tests ()) + +let set_state s = State.assign default s diff --git a/src/random.mli b/src/random.mli new file mode 100644 index 0000000..5c3e3b8 --- /dev/null +++ b/src/random.mli @@ -0,0 +1,129 @@ +(** Pseudo-random number generation. + + This is a wrapper of the standard library's [Random] library, though it does not share + state with that library. +*) + +(*_ + (***********************************************************************) + (* *) + (* Objective Caml *) + (* *) + (* Damien Doligez, projet Para, INRIA Rocquencourt *) + (* *) + (* Copyright 1996 Institut National de Recherche en Informatique et *) + (* en Automatique. All rights reserved. This file is distributed *) + (* under the terms of the Apache 2.0 license. See ../THIRD-PARTY.txt *) + (* for details. *) + (* *) + (***********************************************************************) *) + +open! Import + +(** {6 Basic functions} *) + +(** Note that all of these "basic" functions mutate a global random state. *) + +(** Initialize the generator, using the argument as a seed. The same seed will always + yield the same sequence of numbers. *) +val init : int -> unit + +(** Same as {!Random.init} but takes more data as seed. 
*) +val full_init : int array -> unit + +(** Initialize the generator with a more-or-less random seed chosen in a system-dependent + way. By default, [self_init] is disallowed in inline tests, as it's often used for no + good reason and it just creates nondeterministic failures for everyone. Passing + [~allow_in_tests:true] removes this restriction in case you legitimately want + nondeterministic values, like in [Filename.temp_dir]. *) +val self_init : ?allow_in_tests:bool -> unit -> unit + +(** Return 30 random bits in a nonnegative integer. @before 3.12.0 used a different + algorithm (affects all the following functions) *) +val bits : unit -> int + +(** [Random.int bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val int : int -> int + +(** [Random.int32 bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val int32 : int32 -> int32 + +(** [Random.nativeint bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val nativeint : nativeint -> nativeint + +(** [Random.int64 bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val int64 : int64 -> int64 + +(** [Random.float bound] returns a random floating-point number between 0 (inclusive) and + [bound] (exclusive). If [bound] is negative, the result is negative or zero. If + [bound] is 0, the result is 0. *) +val float : float -> float + +(** Produces a random value between the given inclusive bounds. Raises if bounds are + given in decreasing order. *) +val int_incl : int -> int -> int +val int32_incl : int32 -> int32 -> int32 +val nativeint_incl : nativeint -> nativeint -> nativeint +val int64_incl : int64 -> int64 -> int64 + +(** Produces a value between the given bounds (inclusive and exclusive, respectively). 
+ Raises if bounds are given in decreasing order. *) +val float_range : float -> float -> float + +(** [Random.bool ()] returns [true] or [false] with probability 0.5 each. *) +val bool : unit -> bool + +(** {6 Advanced functions} *) + +(** The functions from module [State] manipulate the current state of the random generator + explicitly. This allows using one or several deterministic PRNGs, even in a + multi-threaded program, without interference from other parts of the program. + + Note that [Random.get_state] from the standard library is not exposed, because it + misleadingly makes a copy of random state, which is not typically the desired outcome + for accessing the shared state. + + Obtaining multiple generators with good independence properties is nontrivial; see + the [Splittable_random] library for that. *) +module State : sig + type t + + (** This gives access to the default random state, allowing user code to share (and + thereby mutate) the random state used by the main functions in [Random]. *) + val default : t + + (** Creates a new state and initializes it with the given seed. *) + val make : int array -> t + + (** Creates a new state and initializes it with a system-dependent low-entropy seed. *) + val make_self_init : ?allow_in_tests:bool -> unit -> t + + val copy : t -> t + + (** These functions are the same as the basic functions, except that they use (and + update) the given PRNG state instead of the default one. 
*) + + val bits : t -> int + + val int : t -> int -> int + val int32 : t -> int32 -> int32 + val nativeint : t -> nativeint -> nativeint + val int64 : t -> int64 -> int64 + val float : t -> float -> float + + val int_incl : t -> int -> int -> int + val int32_incl : t -> int32 -> int32 -> int32 + val nativeint_incl : t -> nativeint -> nativeint -> nativeint + val int64_incl : t -> int64 -> int64 -> int64 + + val float_range : t -> float -> float -> float + + val bool : t -> bool +end + +(** Sets the state of the generator used by the basic functions. *) +val set_state : State.t -> unit diff --git a/src/ref.ml b/src/ref.ml new file mode 100644 index 0000000..15bfa7a --- /dev/null +++ b/src/ref.ml @@ -0,0 +1,47 @@ +open! Import + +(* In the definition of [t], we do not have [[@@deriving_inline compare, sexp][@@@end]] because + in general, syntax extensions tend to use the implementation when available rather than + using the alias. Here that would lead to use the record representation [ { mutable + contents : 'a } ] which would result in different (and unwanted) behavior. *) +type 'a t = 'a ref = { mutable contents : 'a } + +include (struct + type 'a t = 'a ref [@@deriving_inline compare, equal, sexp] + let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = compare_ref + let equal : 'a . ('a -> 'a -> bool) -> 'a t -> 'a t -> bool = equal_ref + let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + ref_of_sexp + let sexp_of_t : + 'a . 
('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_ref + [@@@end] +end : sig + type 'a t = 'a ref [@@deriving_inline compare, equal, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + end with type 'a t := 'a t) + +external create : 'a -> 'a t = "%makemutable" +external ( ! ) : 'a t -> 'a = "%field0" +external ( := ) : 'a t -> 'a -> unit = "%setfield0" + +let swap t1 t2 = + let tmp = !t1 in + t1 := !t2; + t2 := tmp + +let replace t f = t := f !t + +let set_temporarily t a ~f = + let restore_to = !t in + t := a; + Exn.protect ~f ~finally:(fun () -> t := restore_to); +;; diff --git a/src/ref.mli b/src/ref.mli new file mode 100644 index 0000000..7024c80 --- /dev/null +++ b/src/ref.mli @@ -0,0 +1,30 @@ +(** Module for the type [ref], mutable indirection cells [r] containing a value of type + ['a], accessed with [!r] and set by [r := a]. *) + +open! Import + +type 'a t = 'a Caml.ref = { mutable contents : 'a } +[@@deriving_inline compare, equal, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +(*_ defined as externals to avoid breaking the inliner *) +external create : 'a -> 'a t = "%makemutable" +external ( ! ) : 'a t -> 'a = "%field0" +external ( := ) : 'a t -> 'a -> unit = "%setfield0" + +(** [swap t1 t2] swaps the values in [t1] and [t2]. 
*) +val swap : 'a t -> 'a t -> unit + +(** [replace t f] is [t := f !t] *) +val replace : 'a t -> ('a -> 'a) -> unit + +(** [set_temporarily t a ~f] sets [t] to [a], calls [f ()], and then restores [t] to its + value prior to [set_temporarily] being called, whether [f] returns or raises. *) +val set_temporarily : 'a t -> 'a -> f:(unit -> 'b) -> 'b diff --git a/src/result.ml b/src/result.ml new file mode 100644 index 0000000..3e35ef6 --- /dev/null +++ b/src/result.ml @@ -0,0 +1,188 @@ +open! Import + +type ('a, 'b) t = ('a, 'b) Caml.result = + | Ok of 'a + | Error of 'b +[@@deriving_inline sexp, compare, hash] +let t_of_sexp : type a b. + (Ppx_sexp_conv_lib.Sexp.t -> a) -> + (Ppx_sexp_conv_lib.Sexp.t -> b) -> Ppx_sexp_conv_lib.Sexp.t -> (a, b) t + = + let _tp_loc = "src/result.ml.t" in + fun _of_a -> + fun _of_b -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("ok"|"Ok" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Ok v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("error"|"Error" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_b v0 in Error v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("ok"|"Ok") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("error"|"Error") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : type a b. 
+ (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (b -> Ppx_sexp_conv_lib.Sexp.t) -> (a, b) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_b -> + function + | Ok v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Ok"; v0] + | Error v0 -> + let v0 = _of_b v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Error"; v0] +let compare : + 'a 'b . + ('a -> 'a -> int) -> ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + = + fun _cmp__a -> + fun _cmp__b -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (Ok _a__003_, Ok _b__004_) -> _cmp__a _a__003_ _b__004_ + | (Ok _, _) -> (-1) + | (_, Ok _) -> 1 + | (Error _a__005_, Error _b__006_) -> _cmp__b _a__005_ _b__006_) +let hash_fold_t : type a b. + (Ppx_hash_lib.Std.Hash.state -> a -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> b -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> (a, b) t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_a -> + fun _hash_fold_b -> + fun hsv -> + fun arg -> + match arg with + | Ok _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in _hash_fold_a hsv _a0 + | Error _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in _hash_fold_b hsv _a0 +[@@@end] + +include Monad.Make2 (struct + type nonrec ('a, 'b) t = ('a,'b) t + + let bind x ~f = match x with + | Error _ as x -> x + | Ok x -> f x + + let map x ~f = match x with + | Error _ as x -> x + | Ok x -> Ok (f x) + + let map = `Custom map + + let return x = Ok x + end) + +let ignore = ignore_m + +let fail x = Error x;; +let failf format = Printf.ksprintf fail format + +let map_error t ~f = match t with + | Ok _ as x -> x + | Error x -> Error (f x) + +let is_ok = function + | Ok _ -> true + | Error _ -> false + +let is_error = function + | Ok _ -> false + | Error _ -> true + +let ok = function + | Ok x -> Some x + | 
Error _ -> None + +let error = function + | Ok _ -> None + | Error x -> Some x + +let of_option opt ~error = + match opt with + | Some x -> Ok x + | None -> Error error + +let iter v ~f = match v with + | Ok x -> f x + | Error _ -> () + +let iter_error v ~f = match v with + | Ok _ -> () + | Error x -> f x + +let ok_fst = function + | Ok x -> `Fst x + | Error x -> `Snd x + +let ok_if_true bool ~error = + if bool + then Ok () + else Error error + +let try_with f = + try Ok (f ()) + with exn -> Error exn + +let ok_unit = Ok () + +let ok_exn = function + | Ok x -> x + | Error exn -> raise exn + +let ok_or_failwith = function + | Ok x -> x + | Error str -> failwith str + +module Export = struct + type ('ok, 'err) _result = + ('ok, 'err) t = + | Ok of 'ok + | Error of 'err + + let is_error = is_error + let is_ok = is_ok +end + +let combine t1 t2 ~ok ~err = + match t1, t2 with + | Ok _, Error e | Error e, Ok _ -> Error e + | Ok ok1 , Ok ok2 -> Ok (ok ok1 ok2 ) + | Error err1, Error err2 -> Error (err err1 err2) +;; + +let combine_errors l = + let ok, errs = List1.partition_map l ~f:ok_fst in + match errs with + | [] -> Ok ok + | _ :: _ -> Error errs +;; + +let combine_errors_unit l = map (combine_errors l) ~f:(fun (_ : unit list) -> ()) diff --git a/src/result.mli b/src/result.mli new file mode 100644 index 0000000..4e3501c --- /dev/null +++ b/src/result.mli @@ -0,0 +1,108 @@ +(** [Result] is often used to handle error messages. *) + +open! Import + +(** ['ok] is a function's expected return type, and ['err] is often an error message + string. + + {[ + let ric_of_ticker = function + | "IBM" -> Ok "IBM.N" + | "MSFT" -> Ok "MSFT.OQ" + | "AA" -> Ok "AA.N" + | "CSCO" -> Ok "CSCO.OQ" + | _ as ticker -> Error (sprintf "can't find ric of %s" ticker) + ]} + + The return type of ric_of_ticker could be [string option], but [(string, string) + Result.t] gives more control over the error message. 
*) +type ('ok, 'err) t = ('ok, 'err) Caml.result = + | Ok of 'ok + | Error of 'err +[@@deriving_inline sexp, compare, hash] +include +sig + [@@@ocaml.warning "-32"] + include + Ppx_sexp_conv_lib.Sexpable.S2 with type ('ok,'err) t := ('ok, 'err) t + val compare : + ('ok -> 'ok -> int) -> + ('err -> 'err -> int) -> ('ok, 'err) t -> ('ok, 'err) t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'ok -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> 'err -> Ppx_hash_lib.Std.Hash.state) + -> + Ppx_hash_lib.Std.Hash.state -> + ('ok, 'err) t -> Ppx_hash_lib.Std.Hash.state +end[@@ocaml.doc "@inline"] +[@@@end] + +include Monad.S2 with type ('a,'err) t := ('a,'err) t + +val ignore : (_, 'err) t -> (unit, 'err) t + +val fail : 'err -> (_, 'err) t + +(** e.g., [failf "Couldn't find bloogle %s" (Bloogle.to_string b)]. *) +val failf : ('a, unit, string, (_, string) t) format4 -> 'a + +val is_ok : (_, _) t -> bool +val is_error : (_, _) t -> bool + +val ok : ('ok, _ ) t -> 'ok option +val ok_exn : ('ok, exn ) t -> 'ok +val ok_or_failwith : ('ok, string) t -> 'ok + +val error : (_ , 'err) t -> 'err option + +val of_option : 'ok option -> error:'err -> ('ok, 'err) t + +val iter : ('ok, _ ) t -> f:('ok -> unit) -> unit +val iter_error : (_ , 'err) t -> f:('err -> unit) -> unit + +val map : ('ok, 'err) t -> f:('ok -> 'c) -> ('c , 'err) t +val map_error : ('ok, 'err) t -> f:('err -> 'c) -> ('ok, 'c ) t + +(** Returns [Ok] if both are [Ok] and [Error] otherwise. *) +val combine + : ('ok1, 'err) t + -> ('ok2, 'err) t + -> ok: ('ok1 -> 'ok2 -> 'ok3) + -> err:('err -> 'err -> 'err) + -> ('ok3, 'err) t + +(** [combine_errors ts] returns [Ok] if every element in [ts] is [Ok], else it returns + [Error] with all the errors in [ts]. + + This is similar to [all] from [Monad.S2], with the difference that [all] only returns + the first error. 
*) +val combine_errors : ('ok, 'err) t list -> ('ok list, 'err list) t + +(** [combine_errors_unit] returns [Ok] if every element in [ts] is [Ok ()], else it + returns [Error] with all the errors in [ts], like [combine_errors]. *) +val combine_errors_unit : (unit, 'err) t list -> (unit, 'err list) t + +(** [ok_fst] is useful with [List.partition_map]. Continuing the above example: + {[ + let rics, errors = List.partition_map ~f:Result.ok_fst + (List.map ~f:ric_of_ticker ["AA"; "F"; "CSCO"; "AAPL"]) ]} *) +val ok_fst : ('ok, 'err) t -> [ `Fst of 'ok | `Snd of 'err ] + +(** [ok_if_true] returns [Ok ()] if [bool] is true, and [Error error] if it is false. *) +val ok_if_true : bool -> error : 'err -> (unit, 'err) t + +val try_with : (unit -> 'a) -> ('a, exn) t + +(** [ok_unit = Ok ()], used to avoid allocation as a performance hack. *) +val ok_unit : (unit, _) t + +module Export : sig + type ('ok, 'err) _result + = ('ok, 'err) t + = Ok of 'ok + | Error of 'err + + val is_ok : (_, _) t -> bool + val is_error : (_, _) t -> bool +end + diff --git a/src/runtime.js b/src/runtime.js new file mode 100644 index 0000000..dfab440 --- /dev/null +++ b/src/runtime.js @@ -0,0 +1,118 @@ +//Provides: Base_int_math_int_popcount const +function Base_int_math_int_popcount(v) { + v = v - ((v >>> 1) & 0x55555555); + v = (v & 0x33333333) + ((v >>> 2) & 0x33333333); + return ((v + (v >>> 4) & 0xF0F0F0F) * 0x1010101) >>> 24; +} + +//Provides: Base_clear_caml_backtrace_pos const +function Base_clear_caml_backtrace_pos(x) { + return 0; +} + +//Provides: Base_int_math_int32_clz const +function Base_int_math_int32_clz(x) { + var n = 32; + var y; + y = x >>16; if (y != 0) { n = n -16; x = y; } + y = x >> 8; if (y != 0) { n = n - 8; x = y; } + y = x >> 4; if (y != 0) { n = n - 4; x = y; } + y = x >> 2; if (y != 0) { n = n - 2; x = y; } + y = x >> 1; if (y != 0) return n - 2; + return n - x; +} + +//Provides: Base_int_math_int_clz const +//Requires: Base_int_math_int32_clz +function 
Base_int_math_int_clz(x) { return Base_int_math_int32_clz(x); } + +//Provides: Base_int_math_nativeint_clz const +//Requires: Base_int_math_int32_clz +function Base_int_math_nativeint_clz(x) { return Base_int_math_int32_clz(x); } + +//Provides: Base_int_math_int64_clz const +//Requires: caml_int64_shift_right_unsigned, caml_int64_is_zero, caml_int64_to_int32 +function Base_int_math_int64_clz(x) { + var n = 64; + var y; + y = caml_int64_shift_right_unsigned(x, 32); + if (!caml_int64_is_zero(y)) { n = n -32; x = y; } + y = caml_int64_shift_right_unsigned(x, 16); + if (!caml_int64_is_zero(y)) { n = n -16; x = y; } + y = caml_int64_shift_right_unsigned(x, 8); + if (!caml_int64_is_zero(y)) { n = n - 8; x = y; } + y = caml_int64_shift_right_unsigned(x, 4); + if (!caml_int64_is_zero(y)) { n = n - 4; x = y; } + y = caml_int64_shift_right_unsigned(x, 2); + if (!caml_int64_is_zero(y)) { n = n - 2; x = y; } + y = caml_int64_shift_right_unsigned(x, 1); + if (!caml_int64_is_zero(y)) return n - 2; + return n - caml_int64_to_int32(x); +} + +//Provides: Base_int_math_int_pow_stub const +function Base_int_math_int_pow_stub(base, exponent) { + var one = 1; + var mul = [one, base, one, one]; + var res = one; + while (!exponent==0) { + mul[1] = (mul[1] * mul[3]) | 0; + mul[2] = (mul[1] * mul[1]) | 0; + mul[3] = (mul[2] * mul[1]) | 0; + res = (res * mul[exponent & 3]) | 0; + exponent = exponent >> 2; + } + return res; +} + +//Provides: Base_int_math_int64_pow_stub const +//Requires: caml_int64_mul, caml_int64_is_zero, caml_int64_shift_right_unsigned +function Base_int_math_int64_pow_stub(base, exponent) { + var one = [255,1,0,0]; + var mul = [one, base, one, one]; + var res = one; + while (!caml_int64_is_zero(exponent)) { + mul[1] = caml_int64_mul(mul[1], mul[3]); + mul[2] = caml_int64_mul(mul[1], mul[1]); + mul[3] = caml_int64_mul(mul[2], mul[1]); + res = caml_int64_mul(res, mul[exponent[1] & 3]); + exponent = caml_int64_shift_right_unsigned(exponent, 2); + } + return res; +} + 
+//Provides: Base_internalhash_fold_int64 +//Requires: caml_hash_mix_int64 +var Base_internalhash_fold_int64 = caml_hash_mix_int64; +//Provides: Base_internalhash_fold_int +//Requires: caml_hash_mix_int +var Base_internalhash_fold_int = caml_hash_mix_int; +//Provides: Base_internalhash_fold_float +//Requires: caml_hash_mix_float +var Base_internalhash_fold_float = caml_hash_mix_float; +//Provides: Base_internalhash_fold_string +//Requires: caml_hash_mix_string +var Base_internalhash_fold_string = caml_hash_mix_string; +//Provides: Base_internalhash_get_hash_value +//Requires: caml_hash_mix_final +function Base_internalhash_get_hash_value(seed) { + var h = caml_hash_mix_final(seed); + return h & 0x3FFFFFFF; +} + +//Provides: Base_hash_string mutable +//Requires: caml_hash +function Base_hash_string(s) { + return caml_hash(1,1,0,s) +} +//Provides: Base_hash_double const +//Requires: caml_hash +function Base_hash_double(d) { + return caml_hash(1,1,0,d); +} + +//Provides: Base_am_testing const +//Weakdef +function Base_am_testing(x) { + return 0; +} diff --git a/src/select-bytes-set-primitives/select.ml b/src/select-bytes-set-primitives/select.ml new file mode 100644 index 0000000..20cfcbe --- /dev/null +++ b/src/select-bytes-set-primitives/select.ml @@ -0,0 +1,20 @@ +let () = + let ver, output = + try + match Sys.argv with + | [|_; "-ocaml-version"; v; "-o"; fn|] -> + (Scanf.sscanf v "%d.%d" (fun major minor -> (major, minor)), + fn) + | _ -> raise Exit + with _ -> + failwith "bad command line arguments" + in + let prefix = + if ver >= (4, 04) then "bytes" else "string" + in + let oc = open_out output in + Printf.fprintf oc {| +external set : %s -> int -> char -> unit = "%%%s_safe_set" +external unsafe_set : %s -> int -> char -> unit = "%%%s_unsafe_set" +|} prefix prefix prefix prefix; + close_out oc diff --git a/src/select-int63-backend/select.ml b/src/select-int63-backend/select.ml new file mode 100644 index 0000000..535eca9 --- /dev/null +++ 
b/src/select-int63-backend/select.ml @@ -0,0 +1,29 @@ +let () = + let portable_int63, arch_sixtyfour, output = + try + match Sys.argv with + | [|_; "-portable-int63"; x; "-arch-sixtyfour"; y; "-o"; fn|] -> + let x = + match x with + | "true" | "!false" -> true + | "false" | "!true" -> false + | _ -> failwith "invalid value for -portable-int63" + in + (x, + bool_of_string y, + fn) + | _ -> raise Exit + with _ -> + failwith "bad command line arguments" + in + let backend = + if portable_int63 then + "Dynamic" + else if arch_sixtyfour then + "Native" + else + "Emulated" + in + let oc = open_out output in + Printf.fprintf oc "include Int63_backends.%s" backend; + close_out oc diff --git a/src/sequence.ml b/src/sequence.ml new file mode 100644 index 0000000..48badf4 --- /dev/null +++ b/src/sequence.ml @@ -0,0 +1,1097 @@ +open! Import +open Container_intf.Export + +module Array = Array0 +module List = List0 + +module Step = struct + (* 'a is an item in the sequence, 's is the state that will produce the remainder of + the sequence *) + type ('a,'s) t = + | Done + | Skip of 's + | Yield of 'a * 's + [@@deriving_inline sexp_of] + let sexp_of_t : type a s. 
+ (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (s -> Ppx_sexp_conv_lib.Sexp.t) -> (a, s) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_s -> + function + | Done -> Ppx_sexp_conv_lib.Sexp.Atom "Done" + | Skip v0 -> + let v0 = _of_s v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Skip"; v0] + | Yield (v0, v1) -> + let v0 = _of_a v0 + and v1 = _of_s v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Yield"; v0; v1] + [@@@end] +end + +open Step + +(* 'a is an item in the sequence, 's is the state that will produce the remainder of the + sequence *) +type +_ t = + | Sequence : 's * ('s -> ('a,'s) Step.t) -> 'a t + +type 'a sequence = 'a t + +module Expert = struct + let next_step (Sequence (s, f)) = + match f s with + | Done -> Done + | Skip s -> Skip (Sequence (s, f)) + | Yield (a, s) -> Yield (a, Sequence (s, f)) + ;; + + let delayed_fold_step s ~init ~f ~finish = + let rec loop s next finish f acc = + match next s with + | Done -> finish acc + | Skip s -> f acc None ~k:(loop s next finish f) + | Yield (a, s) -> f acc (Some a) ~k:(loop s next finish f) + in + match s with + | Sequence(s, next) -> loop s next finish f init + ;; +end + +let unfold_step ~init ~f = + Sequence (init,f) + +let unfold ~init ~f = + unfold_step ~init + ~f:(fun s -> + match f s with + | None -> Step.Done + | Some(a,s) -> Step.Yield(a,s)) + +let unfold_with s ~init ~f = + match s with + | Sequence(s, next) -> + Sequence((init, s) , + (fun (seed, s) -> + match next s with + | Done -> Done + | Skip s -> Skip (seed, s) + | Yield(a,s) -> + match f seed a with + | Done -> Done + | Skip seed -> Skip (seed, s) + | Yield(a,seed) -> Yield(a,(seed,s)))) + +let unfold_with_and_finish s ~init ~running_step ~inner_finished ~finishing_step = + match s with + | Sequence (s, next) -> + Sequence (`Inner_running (init, s), (fun state -> + match state with + | `Inner_running (state, inner_state) -> begin + match next inner_state with + | Done -> + Skip (`Inner_finished 
(inner_finished state)) + | Skip inner_state -> + Skip (`Inner_running (state, inner_state)) + | Yield (x, inner_state) -> + match running_step state x with + | Done -> Done + | Skip state -> + Skip (`Inner_running (state, inner_state)) + | Yield (y, state) -> + Yield (y, `Inner_running (state, inner_state)) + end + | `Inner_finished state -> begin + match finishing_step state with + | Done -> Done + | Skip state -> + Skip (`Inner_finished state) + | Yield (y, state) -> + Yield (y, `Inner_finished state) + end)) + +let of_list l = + unfold_step ~init:l + ~f:(function + | [] -> Done + | x::l -> Yield(x,l)) + + +let fold t ~init ~f = + let rec loop seed v next f = + match next seed with + | Done -> v + | Skip s -> loop s v next f + | Yield(a,s) -> loop s (f v a) next f + in + match t with + | Sequence(seed, next) -> loop seed init next f + +let to_list_rev t = + fold t ~init:[] ~f:(fun l x -> x::l) + + +let to_list (Sequence(s,next)) = + let safe_to_list t = + List.rev (to_list_rev t) + in + let rec to_list s next i = + if i = 0 then safe_to_list (Sequence(s,next)) + else + match next s with + | Done -> [] + | Skip s -> to_list s next i + | Yield(a,s) -> a::(to_list s next (i-1)) + in + to_list s next 500 + +let sexp_of_t sexp_of_a t = sexp_of_list sexp_of_a (to_list t) + +let range ?(stride=1) ?(start=`inclusive) ?(stop=`exclusive) start_v stop_v = + let step = + match stop with + | `inclusive when stride >= 0 -> + fun i -> if i > stop_v then Done else Yield(i, i + stride) + | `inclusive -> + fun i -> if i < stop_v then Done else Yield(i, i + stride) + | `exclusive when stride >= 0 -> + fun i -> if i >= stop_v then Done else Yield(i,i + stride) + | `exclusive -> + fun i -> if i <= stop_v then Done else Yield(i,i + stride) + in + let init = + match start with + | `inclusive -> start_v + | `exclusive -> start_v + stride + in + unfold_step ~init ~f:step + +let of_lazy t_lazy = + unfold_step ~init:t_lazy ~f:(fun t_lazy -> + let Sequence (s, next) = Lazy.force t_lazy in 
+ match next s with + | Done -> Done + | Skip s -> Skip (let v = Sequence (s, next) in lazy v) + | Yield (x, s) -> Yield (x, let v = Sequence (s, next) in lazy v)) + +let map t ~f = + match t with + | Sequence(seed, next) -> + Sequence(seed, + fun seed -> + match next seed with + | Done -> Done + | Skip s -> Skip s + | Yield(a,s) -> Yield(f a,s)) + + +let mapi t ~f = + match t with + | Sequence(s, next) -> + Sequence((0,s), + fun (i,s) -> + match next s with + | Done -> Done + | Skip s -> Skip (i,s) + | Yield(a,s) -> Yield(f i a,(i+1,s))) + +let folding_map t ~init ~f = + unfold_with t ~init ~f:(fun acc x -> + let acc, x = f acc x in + Yield (x, acc)) + +let folding_mapi t ~init ~f = + unfold_with t ~init:(0, init) ~f:(fun (i, acc) x -> + let acc, x = f i acc x in + Yield (x, (i+1, acc))) + +let filter t ~f = + match t with + | Sequence(seed, next) -> + Sequence(seed, + fun seed -> + match next seed with + | Done -> Done + | Skip s -> Skip s + | Yield(a,s) when f a -> Yield(a,s) + | Yield (_,s) -> Skip s) + +let filteri t ~f = + map ~f:snd ( + filter (mapi t ~f:(fun i s -> (i,s))) + ~f:(fun (i,s) -> f i s)) + +let length t = + let rec loop i s next = + match next s with + | Done -> i + | Skip s -> loop i s next + | Yield(_,s) -> loop (i+1) s next + in + match t with + | Sequence (seed, next) -> loop 0 seed next + +let to_list_rev_with_length t = + fold t ~init:([],0) ~f:(fun (l,i) x -> (x::l,i+1)) + +let to_array t = + let (l,len) = to_list_rev_with_length t in + match l with + | [] -> [||] + | x::l -> + let a = Array.create ~len x in + let rec loop i l = + match l with + | [] -> assert (i = -1) + | x::l -> a.(i) <- x; loop (i-1) l + in + loop (len - 2) l; + a + +let find t ~f = + let rec loop s next f = + match next s with + | Done -> None + | Yield(a,_) when f a -> Some a + | Yield(_,s) | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + +let find_map t ~f = + let rec loop s next f = + match next s with + | Done -> None + 
| Yield(a,s) -> + (match f a with + | None -> loop s next f + | some_b -> some_b) + | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + + +let find_mapi t ~f = + let rec loop s next f i = + match next s with + | Done -> None + | Yield(a,s) -> + (match f i a with + | None -> loop s next f (i+1) + | some_b -> some_b) + | Skip s -> loop s next f i + in + match t with + | Sequence (seed, next) -> loop seed next f 0 + +let for_all t ~f = + let rec loop s next f = + match next s with + | Done -> true + | Yield(a,_) when not (f a) -> false + | Yield (_,s) | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + +let for_alli t ~f = + let rec loop s next f i = + match next s with + | Done -> true + | Yield(a,_) when not (f i a) -> false + | Yield (_,s) -> loop s next f (i+1) + | Skip s -> loop s next f i + in + match t with + | Sequence (seed, next) -> loop seed next f 0 + +let exists t ~f = + let rec loop s next f = + match next s with + | Done -> false + | Yield(a,_) when f a -> true + | Yield(_,s) | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + +let existsi t ~f = + let rec loop s next f i = + match next s with + | Done -> false + | Yield(a,_) when f i a -> true + | Yield(_,s) -> loop s next f (i+1) + | Skip s -> loop s next f i + in + match t with + | Sequence (seed, next) -> loop seed next f 0 + +let iter t ~f = + let rec loop seed next f = + match next seed with + | Done -> () + | Skip s -> loop s next f + | Yield(a,s) -> + begin + f a; + loop s next f + end + in + match t with + | Sequence(seed, next) -> loop seed next f + +let is_empty t = + let rec loop s next = + match next s with + | Done -> true + | Skip s -> loop s next + | Yield _ -> false + in + match t with + | Sequence(seed, next) -> loop seed next + +let mem t a ~equal = + let rec loop s next a = + match next s with + | Done -> false + | Yield(b,_) when equal a b -> true + | 
Yield(_,s) | Skip s -> loop s next a + in + match t with + | Sequence(seed, next) -> loop seed next a + +let empty = + Sequence((), fun () -> Done) + +let bind t ~f = + unfold_step + ~f:(function + | Sequence(seed,next), rest -> + match next seed with + | Done -> + begin + match rest with + | Sequence(seed, next) -> + match next seed with + | Done -> Done + | Skip s -> Skip (empty, Sequence(s, next)) + | Yield(a, s) -> Skip(f a, Sequence(s, next)) + end + | Skip s -> Skip (Sequence(s,next), rest) + | Yield(a,s) -> Yield(a, (Sequence(s,next) , rest))) + ~init:(empty,t) + +let return x = + unfold_step ~init:(Some x) + ~f:(function + | None -> Done + | Some x -> Yield(x,None)) + +include Monad.Make(struct + type nonrec 'a t = 'a t + let map = `Custom map + let bind = bind + let return = return + end) + +let nth s n = + if n < 0 then None + else + let rec loop i s next = + match next s with + | Done -> None + | Skip s -> loop i s next + | Yield(a,s) -> if phys_equal i 0 then Some a else loop (i-1) s next + in + match s with + | Sequence(s,next) -> + loop n s next + +let nth_exn s n = + if n < 0 then raise (Invalid_argument "Sequence.nth") + else + match nth s n with + | None -> failwith "Sequence.nth" + | Some x -> x + +module Merge_with_duplicates_element = struct + type ('a, 'b) t = + | Left of 'a + | Right of 'b + | Both of 'a * 'b + [@@deriving_inline compare, hash, sexp] + let compare : + 'a 'b . 
+ ('a -> 'a -> int) -> ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + = + fun _cmp__a -> + fun _cmp__b -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (Left _a__003_, Left _b__004_) -> _cmp__a _a__003_ _b__004_ + | (Left _, _) -> (-1) + | (_, Left _) -> 1 + | (Right _a__005_, Right _b__006_) -> _cmp__b _a__005_ _b__006_ + | (Right _, _) -> (-1) + | (_, Right _) -> 1 + | (Both (_a__007_, _a__009_), Both (_b__008_, _b__010_)) -> + (match _cmp__a _a__007_ _b__008_ with + | 0 -> _cmp__b _a__009_ _b__010_ + | n -> n)) + let hash_fold_t : type a b. + (Ppx_hash_lib.Std.Hash.state -> a -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> b -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> (a, b) t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_a -> + fun _hash_fold_b -> + fun hsv -> + fun arg -> + match arg with + | Left _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in _hash_fold_a hsv _a0 + | Right _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in _hash_fold_b hsv _a0 + | Both (_a0, _a1) -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 2 in + let hsv = let hsv = hsv in _hash_fold_a hsv _a0 in + _hash_fold_b hsv _a1 + let t_of_sexp : type a b. 
+ (Ppx_sexp_conv_lib.Sexp.t -> a) -> + (Ppx_sexp_conv_lib.Sexp.t -> b) -> Ppx_sexp_conv_lib.Sexp.t -> (a, b) t + = + let _tp_loc = "src/sequence.ml.Merge_with_duplicates_element.t" in + fun _of_a -> + fun _of_b -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("left"|"Left" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Left v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("right"|"Right" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_b v0 in Right v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("both"|"Both" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::v1::[] -> + let v0 = _of_a v0 + and v1 = _of_b v1 in Both (v0, v1) + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("left"|"Left") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("right"|"Right") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("both"|"Both") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp + let sexp_of_t : type a b. 
+ (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (b -> Ppx_sexp_conv_lib.Sexp.t) -> (a, b) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_b -> + function + | Left v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Left"; v0] + | Right v0 -> + let v0 = _of_b v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Right"; v0] + | Both (v0, v1) -> + let v0 = _of_a v0 + and v1 = _of_b v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Both"; v0; v1] + [@@@end] +end + +let merge_with_duplicates (Sequence (s1, next1)) (Sequence (s2, next2)) ~compare = + let unshadowed_compare = compare in + let open Merge_with_duplicates_element in + let next = function + | Skip s1, s2 -> Skip (next1 s1, s2) + | s1, Skip s2 -> Skip (s1, next2 s2) + | (Yield (a, s1') as s1), (Yield (b, s2') as s2) -> + let comparison = unshadowed_compare a b in + if comparison < 0 + then Yield (Left a, (Skip s1', s2)) + else if comparison = 0 + then Yield (Both (a, b), (Skip s1', Skip s2')) + else Yield (Right b, (s1, Skip s2')) + | Done, Done -> Done + | Yield (a, s1), Done -> Yield (Left a, (Skip s1, Done)) + | Done, Yield (b, s2) -> Yield (Right b, (Done, Skip s2)) + in + Sequence((Skip s1, Skip s2), next) + +let merge s1 s2 ~compare = + map (merge_with_duplicates s1 s2 ~compare) + ~f:(function Left x | Right x | Both (x, _) -> x) + +let hd s = + let rec loop s next = + match next s with + | Done -> None + | Skip s -> loop s next + | Yield(a,_) -> Some a + in + match s with + | Sequence (s,next) -> loop s next + +let hd_exn s = + match hd s with + | None -> failwith "hd_exn" + | Some a -> a + +let tl s = + let rec loop s next = + match next s with + | Done -> None + | Skip s -> loop s next + | Yield(_,a) -> Some a + in + match s with + | Sequence (s,next) -> + match loop s next with + | None -> None + | Some s -> Some (Sequence(s,next)) + +let tl_eagerly_exn s = + match tl s with + | None -> failwith "Sequence.tl_exn" + | Some s -> 
s + +let lift_identity next s = + match next s with + | Done -> Done + | Skip s -> Skip (`Identity s) + | Yield(a,s) -> Yield(a, `Identity s) + +let next s = + let rec loop s next = + match next s with + | Done -> None + | Skip s -> loop s next + | Yield(a,s) -> Some (a, Sequence(s, next)) + in + match s with + | Sequence(s, next) -> loop s next + +let filter_opt s = + match s with + | Sequence(s, next) -> + Sequence(s, + fun s -> + match next s with + | Done -> Done + | Skip s -> Skip s + | Yield(None, s) -> Skip s + | Yield(Some a, s) -> Yield(a, s)) + +let filter_map s ~f = + filter_opt (map s ~f) + +let filter_mapi s ~f = + filter_map (mapi s ~f:(fun i s -> (i,s))) + ~f:(fun (i, s) -> f i s) + +let split_n s n = + let rec loop s i accum next = + if i <= 0 then + (List.rev accum, Sequence(s,next)) + else + match next s with + | Done -> (List.rev accum, empty) + | Skip s -> loop s i accum next + | Yield(a,s) -> loop s (i-1) (a::accum) next + in + match s with + | Sequence(s, next) -> loop s n [] next + +let chunks_exn t n = + if n <= 0 + then raise (Invalid_argument "Sequence.chunks_exn") + else + unfold_step ~init:t ~f:(fun t -> + match split_n t n with + | [], _empty -> Done + | _::_ as xs, t -> Yield (xs, t)) + +let findi s ~f = + find (mapi s ~f:(fun i s -> (i,s))) + ~f:(fun (i,s) -> f i s) + +let find_exn s ~f = + match find s ~f with + | None -> failwith "Sequence.find_exn" + | Some x -> x + +let append s1 s2 = + match s1, s2 with + | Sequence(s1, next1), Sequence(s2, next2) -> + Sequence(`First_list s1, + function + | `First_list s1 -> + begin + match next1 s1 with + | Done -> Skip (`Second_list s2) + | Skip s1 -> Skip (`First_list s1) + | Yield(a,s1) -> Yield(a, `First_list s1) + end + | `Second_list s2 -> + begin + match next2 s2 with + | Done -> Done + | Skip s2 -> Skip (`Second_list s2) + | Yield(a,s2) -> Yield(a, `Second_list s2) + end) + +let concat_map s ~f = bind s ~f + +let concat s = concat_map s ~f:Fn.id + +let concat_mapi s ~f = + concat_map 
(mapi s ~f:(fun i s -> (i,s))) + ~f:(fun (i,s) -> f i s) + +let zip (Sequence (s1, next1)) (Sequence (s2, next2)) = + let next = function + | Yield (a, s1), Yield (b, s2) -> Yield ((a, b), (Skip s1, Skip s2)) + | Done, _ + | _, Done -> Done + | Skip s1, s2 -> Skip (next1 s1, s2) + | s1, Skip s2 -> Skip (s1, next2 s2) + in + Sequence ((Skip s1, Skip s2), next) + +let zip_full (Sequence(s1,next1)) (Sequence(s2,next2)) = + let next = function + | Yield (a, s1), Yield (b, s2) -> Yield (`Both (a, b), (Skip s1, Skip s2)) + | Done, Done -> Done + | Skip s1, s2 -> Skip (next1 s1, s2) + | s1, Skip s2 -> Skip (s1, next2 s2) + | Done, Yield(b, s2) -> Yield((`Right b), (Done, next2 s2)) + | Yield(a, s1), Done -> Yield((`Left a), (next1 s1, Done)) + in + Sequence ((Skip s1, Skip s2), next) + +let bounded_length (Sequence(seed,next)) ~at_most = + let rec loop i seed next = + if i > at_most then `Greater + else + match next seed with + | Done -> `Is i + | Skip seed -> loop i seed next + | Yield(_, seed) -> loop (i+1) seed next + in + loop 0 seed next + +let length_is_bounded_by ?(min=(-1)) ?max t = + let length_is_at_least (Sequence(s,next)) = + let rec loop s acc = + if acc >= min then true else + match next s with + | Done -> false + | Skip s -> loop s acc + | Yield(_,s) -> loop s (acc + 1) + in loop s 0 + in + match max with + | None -> length_is_at_least t + | Some max -> + begin + match bounded_length t ~at_most:max with + | `Is len when len >= min -> true + | _ -> false + end + +let iteri s ~f = + iter (mapi s ~f:(fun i s -> (i, s))) + ~f:(fun (i, s) -> f i s) + +let foldi s ~init ~f = + fold ~init (mapi s ~f:(fun i s -> (i,s))) + ~f:(fun acc (i, s) -> f i acc s) + +let reduce s ~f = + match next s with + | None -> None + | Some(a, s) -> Some (fold s ~init:a ~f) + +let reduce_exn s ~f = + match reduce s ~f with + | None -> failwith "Sequence.reduce_exn" + | Some res -> res + +let group (Sequence (s, next)) ~break = + unfold_step ~init:(Some ([], s)) ~f:(function + | None -> 
Done + | Some (acc, s) -> + match acc, next s with + | _, Skip s -> Skip (Some (acc, s)) + | [], Done -> Done + | acc, Done -> Yield (List.rev acc, None) + | [], Yield (cur, s) -> Skip (Some ([cur], s)) + | prev :: _ as acc, Yield (cur, s) -> + if break prev cur + then Yield (List.rev acc, Some ([cur], s)) + else Skip (Some (cur :: acc, s))) +;; + +let find_consecutive_duplicate (Sequence(s, next)) ~equal = + let rec loop last_elt s = + match next s with + | Done -> None + | Skip s -> loop last_elt s + | Yield(a,s) -> + match last_elt with + | Some b when equal a b -> Some (b, a) + | None | Some _ -> loop (Some a) s + in + loop None s + +let remove_consecutive_duplicates s ~equal = + unfold_with s ~init:None + ~f:(fun prev a -> + match prev with + | Some b when equal a b -> Skip(Some a) + | None | Some _ -> Yield(a, Some a)) + +let count s ~f = + length (filter s ~f) + +let counti t ~f = + length (filteri t ~f) + +let sum m t ~f = Container.sum ~fold m t ~f +let min_elt t ~compare = Container.min_elt ~fold t ~compare +let max_elt t ~compare = Container.max_elt ~fold t ~compare + +let init n ~f = + unfold_step ~init:0 + ~f:(fun i -> + if i >= n then Done + else Yield(f i, i + 1)) + +let sub s ~pos ~len = + if pos < 0 || len < 0 then failwith "Sequence.sub"; + match s with + | Sequence(s, next) -> + Sequence((0,s), + (fun (i, s) -> + if i - pos >= len then Done + else + match next s with + | Done -> Done + | Skip s -> Skip (i, s) + | Yield(a, s) when i >= pos -> Yield (a,(i + 1, s)) + | Yield(_, s) -> Skip(i + 1, s))) + +let take s len = + if len < 0 then failwith "Sequence.take"; + match s with + | Sequence(s, next) -> + Sequence((0,s), + (fun (i, s) -> + if i >= len then Done + else + match next s with + | Done -> Done + | Skip s -> Skip (i, s) + | Yield(a, s) -> Yield (a,(i + 1, s)))) + +let drop s len = + if len < 0 then failwith "Sequence.drop"; + match s with + | Sequence(s, next) -> + Sequence((0,s), + (fun (i, s) -> + match next s with + | Done -> Done + | 
Skip s -> Skip (i, s) + | Yield(a, s) when i >= len -> Yield (a,(i + 1, s)) + | Yield(_, s) -> Skip (i+1, s))) + +let take_while s ~f = + match s with + | Sequence(s, next) -> + Sequence(s, + fun s -> + match next s with + | Done -> Done + | Skip s -> Skip s + | Yield (a, s) when f a -> Yield(a,s) + | Yield (_,_) -> Done) + +let drop_while s ~f = + match s with + | Sequence(s, next) -> + Sequence(`Dropping s, + function + |`Dropping s -> + begin + match next s with + | Done -> Done + | Skip s -> Skip (`Dropping s) + | Yield(a, s) when f a -> Skip (`Dropping s) + | Yield(a, s) -> Yield(a, `Identity s) + end + | `Identity s -> lift_identity next s) + +let shift_right s x = + match s with + | Sequence(seed, next) -> + Sequence(`Consing (seed, x), + function + | `Consing (seed, x) -> Yield(x, `Identity seed) + | `Identity s -> lift_identity next s) + +let shift_right_with_list s l = + append (of_list l) s + +let shift_left = drop + +module Infix = struct + let (@) = append +end + +let intersperse s ~sep = + match s with + | Sequence(s, next) -> + Sequence(`Init s, + function + | `Init s -> + begin + match next s with + | Done -> Done + | Skip s -> Skip (`Init s) + | Yield(a, s) -> Yield(a, `Running s) + end + | `Running s -> + begin + match next s with + | Done -> Done + | Skip s -> Skip (`Running s) + | Yield(a, s) -> Yield(sep, `Putting(a,s)) + end + | `Putting(a,s) -> Yield(a,`Running s)) + +let repeat x = + unfold_step ~init:x ~f:(fun x -> Yield(x, x)) + +let cycle_list_exn xs = + if List.is_empty xs then raise (Invalid_argument "Sequence.cycle_list_exn"); + let s = of_list xs in + concat_map ~f:(fun () -> s) (repeat ()) + +let cartesian_product sa sb = + concat_map sa + ~f:(fun a -> zip (repeat a) sb) + +let singleton x = return x + +let delayed_fold s ~init ~f ~finish = + Expert.delayed_fold_step s ~init ~finish ~f:(fun acc option ~k -> + match option with + | None -> k acc + | Some a -> f acc a ~k) + +let fold_m ~bind ~return t ~init ~f = + 
Expert.delayed_fold_step t + ~init + ~f:(fun acc option ~k -> + match option with + | None -> bind (return acc) ~f:k + | Some a -> bind (f acc a) ~f:k) + ~finish:return +;; + +let iter_m ~bind ~return t ~f = + Expert.delayed_fold_step t + ~init:() + ~f:(fun () option ~k -> + match option with + | None -> bind (return ()) ~f:k + | Some a -> bind (f a) ~f:k) + ~finish:return +;; + +let fold_until s ~init ~f ~finish = + let rec loop s next f = + fun acc -> + match next s with + | Done -> finish acc + | Skip s -> loop s next f acc + | Yield(a, s) -> match (f acc a : ('a, 'b) Continue_or_stop.t) with + | Stop x -> x + | Continue acc -> loop s next f acc + in + match s with + | Sequence(s, next) -> loop s next f init + +let fold_result s ~init ~f = + let rec loop s next f = + fun acc -> + match next s with + | Done -> Result.return acc + | Skip s -> loop s next f acc + | Yield(a, s) -> match (f acc a : (_, _) Result.t) with + | Error _ as e -> e + | Ok acc -> loop s next f acc + in + match s with + | Sequence(s, next) -> loop s next f init + +let force_eagerly t = of_list (to_list t) + +let memoize (type a) (Sequence (s, next)) = + let module M = struct + type t = T of (a, t) Step.t Lazy.t + end in + let rec memoize s = M.T (lazy (find_step s)) + and find_step s = + match next s with + | Done -> Done + | Skip s -> find_step s + | Yield (a, s) -> Yield (a, memoize s) + in + Sequence (memoize s, (fun (M.T l) -> Lazy.force l)) + +let drop_eagerly s len = + let rec loop i ~len s next = + if i >= len then Sequence(s, next) + else + match next s with + | Done -> empty + | Skip s -> loop i ~len s next + | Yield(_,s) -> loop (i+1) ~len s next + in + match s with + | Sequence(s, next) -> loop 0 ~len s next + +let drop_while_option (Sequence (s, next)) ~f = + let rec loop s = + match next s with + | Done -> None + | Skip s -> loop s + | Yield (x, s) -> if f x then loop s else Some (x, Sequence (s, next)) + in + loop s + +let compare compare_a t1 t2 = + With_return.with_return (fun 
r -> + iter (zip_full t1 t2) ~f:(function + | `Left _ -> r.return 1 + | `Right _ -> r.return (-1) + | `Both (v1, v2) -> + let c = compare_a v1 v2 in + if c <> 0 + then r.return c); + 0); +;; + +let round_robin list = + let next (todo_stack, done_stack) = + match todo_stack with + | Sequence (s, f) :: todo_stack -> + begin + match f s with + | Yield (x, s) -> Yield (x, (todo_stack, Sequence (s, f) :: done_stack)) + | Skip s -> Skip (Sequence (s, f) :: todo_stack, done_stack) + | Done -> Skip (todo_stack, done_stack) + end + | [] -> + if List.is_empty done_stack + then Done + else Skip (List.rev done_stack, []) + in + let state = list, [] in + Sequence (state, next) + +let interleave (Sequence (s1, f1)) = + let next (todo_stack, done_stack, s1) = + match todo_stack with + | Sequence (s2, f2) :: todo_stack -> + begin + match f2 s2 with + | Yield (x, s2) -> Yield (x, (todo_stack, Sequence (s2, f2) :: done_stack, s1)) + | Skip s2 -> Skip (todo_stack, Sequence (s2, f2) :: done_stack, s1) + | Done -> Skip (todo_stack, done_stack, s1) + end + | [] -> + begin + match f1 s1, done_stack with + | Yield (t, s1), _ -> Skip (List.rev (t :: done_stack), [], s1) + | Skip s1 , _ -> Skip (List.rev done_stack , [], s1) + | Done , _::_ -> Skip (List.rev done_stack , [], s1) + | Done , [] -> Done + end + in + let state = [], [], s1 in + Sequence (state, next) + +let interleaved_cartesian_product s1 s2 = + map s1 ~f:(fun x1 -> + map s2 ~f:(fun x2 -> + (x1, x2))) + |> interleave + +module Generator = struct + + type 'elt steps = Wrap of ('elt, unit -> 'elt steps) Step.t + + let unwrap (Wrap step) = step + + module T = struct + type ('a, 'elt) t = ('a -> 'elt steps) -> 'elt steps + let return x = (); fun k -> k x + let bind m ~f = (); fun k -> m (fun a -> let m' = f a in m' k) + let map m ~f = (); fun k -> m (fun a -> k (f a)) + let map = `Custom map + end + include T + include Monad.Make2 (T) + + let yield e = (); fun k -> Wrap (Yield (e, k)) + + let to_steps t = t (fun () -> Wrap Done) + 
+ let of_sequence sequence = + delayed_fold sequence + ~init:() + ~f:(fun () x ~k f -> Wrap (Yield (x, fun () -> k () f))) + ~finish:return + + let run t = + let init () = to_steps t in + let f thunk = unwrap (thunk ()) in + unfold_step ~init ~f + +end diff --git a/src/sequence.mli b/src/sequence.mli new file mode 100644 index 0000000..5f18e29 --- /dev/null +++ b/src/sequence.mli @@ -0,0 +1,487 @@ +(** A sequence of elements that can be produced one at a time, on demand, normally with no + sharing. + + The elements are computed on demand, possibly repeating work if they are demanded + multiple times. A sequence can be built by unfolding from some initial state, which + will in practice often be other containers. + + Most functions constructing a sequence will not immediately compute any elements of + the sequence. These functions will always return in O(1), but traversing the + resulting sequence may be more expensive. The most they will do immediately is + generate a new internal state and a new step function. + + Functions that transform existing sequences sometimes have to reconstruct some suffix + of the input sequence, even if it is unmodified. For example, calling [drop 1] will + return a sequence with a slightly larger state and whose elements all cost slightly + more to traverse. Because this is sometimes undesirable (for example, applying [drop + 1] n times will cost O(n) per element traversed in the result), there are also more + eager versions of many functions (whose names are suffixed with [_eagerly]) that do + more work up front. A function has the [_eagerly] suffix iff it matches both of these + conditions: + + - It might consume an element from an input [t] before returning. + + - It only returns a [t] (not paired with something else, not wrapped in an [option], + etc.). If it returns anything other than a [t] and it has at least one [t] input, + it's probably demanding elements from the input [t] anyway. 
+ + Only [*_exn] functions can raise exceptions, except if the function underlying the + sequence (the [f] passed to [unfold]) raises, in which case the exception will + cascade. *) + +open! Import + +type +'a t [@@deriving_inline compare, sexp_of] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] +type 'a sequence = 'a t + +include Indexed_container.S1 with type 'a t := 'a t +include Monad.S with type 'a t := 'a t + +(** [empty] is a sequence with no elements. *) +val empty : _ t + +(** [next] returns the next element of a sequence and the next tail if the sequence is not + finished. *) +val next : 'a t -> ('a * 'a t) option + +(** A [Step] describes the next step of the sequence construction. [Done] indicates the + sequence is finished. [Skip] indicates the sequence continues with another state + without producing the next element yet. [Yield] outputs an element and introduces a + new state. + + Modifying ['s] doesn't violate any {e internal} invariants, but it may violate some + undocumented expectations. For example, one might expect that producing an element + from the same point in the sequence would always give the same value, but if the state + can mutate, that is not so. *) +module Step : sig + type ('a, 's) t = + | Done + | Skip of 's + | Yield of 'a * 's + [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('s -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 's) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +(** [unfold_step ~init ~f] constructs a sequence by giving an initial state [init] and a + function [f] explaining how to continue the next step from a given state. 
*) +val unfold_step : init:'s -> f:('s -> ('a, 's) Step.t) -> 'a t + +(** [unfold ~init ~f] is a simplified version of [unfold_step] that does not allow + [Skip]. *) +val unfold : init:'s -> f:('s -> ('a * 's) option) -> 'a t + +(** [unfold_with t ~init ~f] folds a state through the sequence [t] to create a new + sequence. *) +val unfold_with : 'a t -> init:'s -> f:('s -> 'a -> ('b, 's) Step.t) -> 'b t + +(** [unfold_with_and_finish t ~init ~running_step ~inner_finished ~finishing_step] folds a + state through [t] to create a new sequence (like [unfold_with t ~init + ~f:running_step]), and then continues the new sequence by unfolding the final state + (like [unfold_step ~init:(inner_finished final_state) ~f:finishing_step]). *) +val unfold_with_and_finish + : 'a t + -> init : 's_a + -> running_step : ('s_a -> 'a -> ('b, 's_a) Step.t) + -> inner_finished : ('s_a -> 's_b) + -> finishing_step : ('s_b -> ('b, 's_b) Step.t) + -> 'b t + +(** Returns the nth element. *) +val nth : 'a t -> int -> 'a option +val nth_exn : 'a t -> int -> 'a + +(** [folding_map] is a version of [map] that threads an accumulator through calls to + [f]. *) +val folding_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'c t +val folding_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'c t + +val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +val filteri : 'a t -> f: (int -> 'a -> bool) -> 'a t + +val filter : 'a t -> f: ('a -> bool) -> 'a t + +(** [merge t1 t2 ~compare] merges two sorted sequences [t1] and [t2], returning a sorted + sequence, all according to [compare]. If two elements are equal, the one from [t1] is + preferred. The behavior is undefined if the inputs aren't sorted. 
*) +val merge : 'a t -> 'a t -> compare:('a -> 'a -> int) -> 'a t + +module Merge_with_duplicates_element : sig + type ('a, 'b) t = + | Left of 'a + | Right of 'b + | Both of 'a * 'b + [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> 'b -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> + ('a, 'b) t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +(** [merge_with_duplicates t1 t2 ~compare] is like [merge], except that for each + element it indicates which input(s) the element comes from, using + [Merge_with_duplicates_element]. *) +val merge_with_duplicates + : 'a t + -> 'b t + -> compare:('a -> 'b -> int) + -> ('a, 'b) Merge_with_duplicates_element.t t + +val hd : 'a t -> 'a option +val hd_exn : 'a t -> 'a + +(** [tl t] and [tl_eagerly_exn t] immediately evaluate the first element of [t] and + return the unevaluated tail. *) +val tl : 'a t -> 'a t option +val tl_eagerly_exn : 'a t -> 'a t + +(** [find_exn t ~f] returns the first element of [t] that satisfies [f]. It raises if + there is no such element. *) +val find_exn : 'a t -> f:('a -> bool) -> 'a + +(** Like [for_all], but passes the index as an argument. *) +val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + +(** [append t1 t2] first produces the elements of [t1], then produces the elements of + [t2]. *) +val append : 'a t -> 'a t -> 'a t + +(** [concat tt] produces the elements of each inner sequence sequentially. If any inner + sequences are infinite, elements of subsequent inner sequences will not be reached. 
*) +val concat : 'a t t -> 'a t + +(** [concat_map t ~f] is [concat (map t ~f)].*) +val concat_map : 'a t -> f:('a -> 'b t) -> 'b t + +(** [concat_mapi t ~f] is like concat_map, but passes the index as an argument. *) +val concat_mapi : 'a t -> f:(int -> 'a -> 'b t) -> 'b t + +(** [interleave tt] produces each element of the inner sequences of [tt] eventually, even + if any or all of the inner sequences are infinite. The elements of each inner + sequence are produced in order with respect to that inner sequence. The manner of + interleaving among the separate inner sequences is deterministic but unspecified. *) +val interleave : 'a t t -> 'a t + +(** [round_robin list] is like [interleave (of_list list)], except that the manner of + interleaving among the inner sequences is guaranteed to be round-robin. The input + sequences may be of different lengths; an empty sequence is dropped from subsequent + rounds of interleaving. *) +val round_robin : 'a t list -> 'a t + +(** Transforms a pair of sequences into a sequence of pairs. The length of the returned + sequence is the length of the shorter input. The remaining elements of the longer + input are discarded. + + WARNING: Unlike [List.zip], this will not error out if the two input sequences are of + different lengths, because [zip] may have already returned some elements by the time + this becomes apparent. *) +val zip : 'a t -> 'b t -> ('a * 'b) t + +(** [zip_full] is like [zip], but if one sequence ends before the other, then it keeps + producing elements from the other sequence until it has ended as well. *) +val zip_full: 'a t -> 'b t -> [ `Left of 'a | `Both of 'a * 'b | `Right of 'b ] t + +(** [reduce_exn f [a1; ...; an]] is [f (... (f (f a1 a2) a3) ...) an]. It fails on the + empty sequence. 
*) +val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a +val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option + +(** [group l ~break] returns a sequence of lists (i.e., groups) whose concatenation is + equal to the original sequence. Each group is broken where [break] returns true on a + pair of successive elements. + + Example: + + {[ + group ~break:(<>) (of_list ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i']) -> + + of_list [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] ]} *) +val group : 'a t -> break:('a -> 'a -> bool) -> 'a list t + +(** [find_consecutive_duplicate t ~equal] returns the first pair of consecutive elements + [(a1, a2)] in [t] such that [equal a1 a2]. They are returned in the same order as + they appear in [t]. *) +val find_consecutive_duplicate : 'a t -> equal:('a -> 'a -> bool) -> ('a * 'a) option + +(** The same sequence with consecutive duplicates removed. The relative order of the + other elements is unaffected. *) +val remove_consecutive_duplicates : 'a t -> equal:('a -> 'a -> bool) -> 'a t + +(** [range ?stride ?start ?stop start_i stop_i] is the sequence of integers from [start_i] + to [stop_i], stepping by [stride]. If [stride] < 0 then we need [start_i] > [stop_i] + for the result to be nonempty (or [start_i] >= [stop_i] in the case where both bounds + are inclusive). *) +val range + : ?stride:int (** default is [1] *) + -> ?start:[`inclusive|`exclusive] (** default is [`inclusive] *) + -> ?stop:[`inclusive|`exclusive] (** default is [`exclusive] *) + -> int + -> int + -> int t + +(** [init n ~f] is [[(f 0); (f 1); ...; (f (n-1))]]. It is an error if [n < 0]. *) +val init : int -> f:(int -> 'a) -> 'a t + +(** [filter_map t ~f] produces the mapped elements of [t] which are not [None]. *) +val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** [filter_mapi] is just like [filter_map], but it also passes in the index of each + element to [f]. 
*) +val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** [filter_opt t] produces the elements of [t] which are not [None]. [filter_opt t] = + [filter_map t ~f:ident]. *) +val filter_opt : 'a option t -> 'a t + +(** [sub t ~pos ~len] is the [len]-element subsequence of [t], starting at [pos]. If the + sequence is shorter than [pos + len], it returns [ t[pos] ... t[l-1] ], where [l] is + the length of the sequence. *) +val sub : 'a t -> pos:int -> len:int -> 'a t + +(** [take t n] produces the first [n] elements of [t]. *) +val take : 'a t -> int -> 'a t + +(** [drop t n] produces all elements of [t] except the first [n] elements. If there are + fewer than [n] elements in [t], there is no error; the resulting sequence simply + produces no elements. Usually you will probably want to use [drop_eagerly] because it + can be significantly cheaper. *) +val drop : 'a t -> int -> 'a t + +(** [drop_eagerly t n] immediately consumes the first [n] elements of [t] and returns the + unevaluated tail of [t]. *) +val drop_eagerly : 'a t -> int -> 'a t + +(** [take_while t ~f] produces the longest prefix of [t] for which [f] applied to each + element is [true]. *) +val take_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [drop_while t ~f] produces the suffix of [t] beginning with the first element of [t] + for which [f] is [false]. Usually you will probably want to use [drop_while_option] + because it can be significantly cheaper. *) +val drop_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [drop_while_option t ~f] immediately consumes the elements from [t] until the + predicate [f] fails and returns the first element that failed along with the + unevaluated tail of [t]. 
The first element is returned separately because the + alternatives would mean forcing the consumer to evaluate the first element again (if + the previous state of the sequence is returned) or take on extra cost for each element + (if the element is added to the final state of the sequence using [shift_right]). *) +val drop_while_option : 'a t -> f : ('a -> bool) -> ('a * 'a t) option + +(** [split_n t n] immediately consumes the first [n] elements of [t] and returns the + consumed prefix, as a list, along with the unevaluated tail of [t]. *) +val split_n : 'a t -> int -> 'a list * 'a t + +(** [chunks_exn t n] produces lists of elements of [t], up to [n] elements at a time. The + last list may contain fewer than [n] elements. No list contains zero elements. If [n] + is not positive, it raises. *) +val chunks_exn : 'a t -> int -> 'a list t + +(** [shift_right t a] produces [a] and then produces each element of [t]. *) +val shift_right : 'a t -> 'a -> 'a t + +(** [shift_right_with_list t l] produces the elements of [l], then produces the elements + of [t]. It is better to call [shift_right_with_list] with a list of size n than + [shift_right] n times; the former will require O(1) work per element produced and the + latter O(n) work per element produced. *) +val shift_right_with_list : 'a t -> 'a list -> 'a t + +(** [shift_left t n] is a synonym for [drop t n].*) +val shift_left : 'a t -> int -> 'a t + +module Infix : sig + val ( @ ) : 'a t -> 'a t -> 'a t +end + +(** Returns a sequence with all possible pairs. The stepper function of the second + sequence passed as argument may be applied to the same state multiple times, so be + careful using [cartesian_product] with expensive or side-effecting functions. If the + second sequence is infinite, some values in the first sequence may not be reached. 
*) +val cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** Returns a sequence that eventually reaches every possible pair of elements of the + inputs, even if either or both are infinite. The step function of both inputs may be + applied to the same state repeatedly, so be careful using + [interleaved_cartesian_product] with expensive or side-effecting functions. *) +val interleaved_cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** [intersperse xs ~sep] produces [sep] between adjacent elements of [xs], e.g., + [intersperse [1;2;3] ~sep:0 = [1;0;2;0;3]]. *) +val intersperse : 'a t -> sep:'a -> 'a t + +(** [cycle_list_exn xs] repeats the elements of [xs] forever. If [xs] is empty, it + raises. *) +val cycle_list_exn : 'a list -> 'a t + +(** [repeat a] repeats [a] forever. *) +val repeat : 'a -> 'a t + +(** [singleton a] produces [a] exactly once. *) +val singleton : 'a -> 'a t + +(** [delayed_fold] allows to do an on-demand fold, while maintaining a state. + + It is possible to exit early by not calling [k] in [f]. It is also possible to call + [k] multiple times. This results in the rest of the sequence being folded over + multiple times, independently. + + Note that [delayed_fold], when targeting JavaScript, can result in stack overflow as + JavaScript doesn't generally have tail call optimization. *) +val delayed_fold + : 'a t + -> init:'s + -> f:('s -> 'a -> k:('s -> 'r) -> 'r) (** [k] stands for "continuation" *) + -> finish:('s -> 'r) + -> 'r + +(** [fold_m] is a monad-friendly version of [fold]. Supply it with the monad's [return] + and [bind], and it will chain them through the computation. *) +val fold_m + : bind:('acc_m -> f:('acc -> 'acc_m) -> 'acc_m) + -> return:('acc -> 'acc_m) + -> 'elt t + -> init:'acc + -> f:('acc -> 'elt -> 'acc_m) + -> 'acc_m + +(** [iter_m] is a monad-friendly version of [iter]. Supply it with the monad's [return] + and [bind], and it will chain them through the computation. 
*) +val iter_m + : bind:('unit_m -> f:(unit -> 'unit_m) -> 'unit_m) + -> return:(unit -> 'unit_m) + -> 'elt t + -> f:('elt -> 'unit_m) + -> 'unit_m + +(** [to_list_rev t] returns a list of the elements of [t], in reverse order. It is faster + than [to_list]. *) +val to_list_rev : 'a t -> 'a list + +val of_list : 'a list -> 'a t + +(** [of_lazy t_lazy] produces a sequence that forces [t_lazy] the first time it needs to + compute an element. *) +val of_lazy : 'a t Lazy.t -> 'a t + +(** [memoize t] produces each element of [t], but also memoizes them so that if you + consume the same element multiple times it is only computed once. It's a non-eager + version of [force_eagerly]. *) +val memoize : 'a t -> 'a t + +(** [force_eagerly t] precomputes the sequence. It is behaviorally equivalent to [of_list + (to_list t)], but may at some point have a more efficient implementation. It's an + eager version of [memoize]. *) +val force_eagerly : 'a t -> 'a t + +(** [bounded_length ~at_most t] returns [`Is len] if [len = length t <= at_most], and + otherwise returns [`Greater]. Walks through only as much of the sequence as + necessary. Always returns [`Greater] if [at_most < 0]. *) +val bounded_length : _ t -> at_most:int -> [ `Is of int | `Greater ] + +(** [length_is_bounded_by ~min ~max t] returns true if [min <= length t] and [length t <= + max]. When [min] or [max] are not provided, the check for that bound is omitted. Walks + through only as much of the sequence as necessary. *) +val length_is_bounded_by: ?min:int -> ?max:int -> _ t -> bool + +(** [Generator] is a monadic interface to generate sequences in a direct style, similar to + Python's generators. 
+ + Here are some examples: + + {[ + open Generator + + let rec traverse_list = function + | [] -> return () + | x :: xs -> yield x >>= fun () -> traverse_list xs + + let traverse_option = function + | None -> return () + | Some x -> yield x + + let traverse_array arr = + let n = Array.length arr in + let rec loop i = + if i >= n then return () else yield arr.(i) >>= fun () -> loop (i + 1) + in + loop 0 + + let rec traverse_bst = function + | Node.Empty -> return () + | Node.Branch (left, value, right) -> + traverse_bst left >>= fun () -> + yield value >>= fun () -> + traverse_bst right + + let sequence_of_list x = Generator.run (traverse_list x) + let sequence_of_option x = Generator.run (traverse_option x) + let sequence_of_array x = Generator.run (traverse_array x) + let sequence_of_bst x = Generator.run (traverse_bst x) + ]} *) + +module Generator : sig + include Monad.S2 + val yield : 'elt -> (unit, 'elt) t + val of_sequence : 'elt sequence -> (unit, 'elt) t + val run : (unit, 'elt) t -> 'elt sequence +end + +(** The functions in [Expert] expose internal structure which is normally meant to be + hidden. For example, at least when [f] is purely functional, it is not intended for + client code to distinguish between + + {[ + List.filter xs ~f + |> Sequence.of_list + ]} + + and + + {[ + Sequence.of_list xs + |> Sequence.filter ~f + ]} + + But sometimes for operational reasons it still makes sense to distinguish them. For + example, being able to handle [Skip]s explicitly allows breaking up some + computationally expensive sequences into smaller chunks of work. *) +module Expert : sig + (** [next_step] returns the next step in a sequence's construction. It is like [next], + but it also allows observing [Skip] steps. *) + val next_step : 'a t -> ('a, 'a t) Step.t + + (** [delayed_fold_step] is like [delayed_fold], but [f] takes an option where [None] + represents a [Skip] step. 
*) + val delayed_fold_step + : 'a t + -> init:'s + -> f:('s -> 'a option -> k:('s -> 'r) -> 'r) (** [k] stands for "continuation" *) + -> finish:('s -> 'r) + -> 'r +end diff --git a/src/set.ml b/src/set.ml new file mode 100644 index 0000000..13af40f --- /dev/null +++ b/src/set.ml @@ -0,0 +1,1314 @@ +(***********************************************************************) +(* *) +(* Objective Caml *) +(* *) +(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *) +(* *) +(* Copyright 1996 Institut National de Recherche en Informatique et *) +(* en Automatique. All rights reserved. This file is distributed *) +(* under the terms of the Apache 2.0 license. See ../THIRD-PARTY.txt *) +(* for details. *) +(* *) +(***********************************************************************) + +(* Sets over ordered types *) + +open! Import + +include Set_intf + +let with_return = With_return.with_return + + +module Tree0 = struct + type 'a t = + | Empty + (* (Leaf x) is the same as (Node (Empty, x, Empty, 1, 1)) but uses less space. *) + | Leaf of 'a + (* first int is height, second is sub-tree size *) + | Node of 'a t * 'a * 'a t * int * int + + type 'a tree = 'a t + + (* Sets are represented by balanced binary trees (the heights of the children differ by + at most 2. 
*) + let height = function + | Empty -> 0 + | Leaf _ -> 1 + | Node(_, _, _, h, _) -> h + ;; + + let length = function + | Empty -> 0 + | Leaf _ -> 1 + | Node(_, _, _, _, s) -> s + ;; + + let invariants = + let in_range lower upper compare_elt v = + (match lower with + | None -> true + | Some lower -> compare_elt lower v < 0 + ) + && (match upper with + | None -> true + | Some upper -> compare_elt v upper < 0 + ) + in + let rec loop lower upper compare_elt t = + match t with + | Empty -> true + | Leaf v -> in_range lower upper compare_elt v + | Node (l, v, r, h, n) -> + let hl = height l and hr = height r in + abs (hl - hr) <= 2 + && h = (max hl hr) + 1 + && n = length l + length r + 1 + && in_range lower upper compare_elt v + && loop lower (Some v) compare_elt l + && loop (Some v) upper compare_elt r + in + fun t ~compare_elt -> loop None None compare_elt t + ;; + + let is_empty = function Empty -> true | Leaf _ | Node _ -> false + + (* Creates a new node with left son l, value v and right son r. + We must have all elements of l < v < all elements of r. + l and r must be balanced and | height l - height r | <= 2. + Inline expansion of height for better speed. *) + + let create l v r = + let hl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + let hr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + let h = if hl >= hr then hl + 1 else hr + 1 in + if h = 1 + then Leaf v + else begin + let sl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + let sr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + Node (l, v, r, h, sl + sr + 1) + end + + (* We must call [f] with increasing indexes, because the bin_prot reader in + Core_kernel.Set needs it. 
*) + let of_increasing_iterator_unchecked ~len ~f = + let rec loop n ~f i = + match n with + | 0 -> Empty + | 1 -> + let k = f i in + Leaf k + | 2 -> + let kl = f i in + let k = f (i + 1) in + create (Leaf kl) k (Empty) + | 3 -> + let kl = f i in + let k = f (i + 1) in + let kr = f (i + 2) in + create (Leaf kl) k (Leaf kr) + | n -> + let left_length = n lsr 1 in + let right_length = n - left_length - 1 in + let left = loop left_length ~f i in + let k = f (i + left_length) in + let right = loop right_length ~f (i + left_length + 1) in + create left k right + in + loop len ~f 0 + + let of_sorted_array_unchecked array ~compare_elt = + let array_length = Array.length array in + let next = + (* We don't check if the array is sorted or keys are duplicated, because that + checking is slower than the whole [of_sorted_array] function *) + if array_length < 2 || compare_elt array.(0) array.(1) < 0 + then (fun i -> array.(i)) + else (fun i -> array.(array_length - 1 - i)) + in + of_increasing_iterator_unchecked ~len:array_length ~f:next + ;; + + let of_sorted_array array ~compare_elt = + match array with + | [||] | [|_|] -> Result.Ok (of_sorted_array_unchecked array ~compare_elt) + | _ -> + with_return (fun r -> + let increasing = + match compare_elt array.(0) array.(1) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> i < 0 + in + for i = 1 to Array.length array - 2 do + match compare_elt array.(i) array.(i+1) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> + if Poly.(<>) (i < 0) increasing then + r.return (Or_error.error_string "of_sorted_array: elements are not ordered") + done; + Result.Ok (of_sorted_array_unchecked array ~compare_elt) + ) + + (* Same as create, but performs one step of rebalancing if necessary. + Assumes l and r balanced and | height l - height r | <= 3. + Inline expansion of create for better speed in the most frequent case + where no rebalancing is required. 
*) + + let bal l v r = + let hl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + let hr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + if hl > hr + 2 then begin (* left side is too tall *) + match l with + | Empty -> assert false + | Leaf _ -> assert false (* because h(l)>h(r)+2 and h(leaf)=1 *) + | Node (ll, lv, lr, _, _) -> + if height ll >= height lr then (* single rotation: lv becomes the root *) + create ll lv (create lr v r) + else begin (* double rotation: the root of lr becomes the root *) + match lr with + | Empty -> assert false + | Leaf lrv -> + assert (is_empty ll); + create (create ll lv Empty) lrv (create Empty v r) + | Node(lrl, lrv, lrr, _, _)-> + create (create ll lv lrl) lrv (create lrr v r) + end + end else if hr > hl + 2 then begin (* right side is too tall: mirror image of the branch above *) + match r with + Empty -> assert false + | Leaf rv -> create (create l v Empty) rv Empty (* NOTE(review): looks unreachable, since hr > hl + 2 cannot hold when r is a Leaf of height 1; the mirrored case above asserts false instead -- confirm *) + | Node(rl, rv, rr, _, _) -> + if height rr >= height rl then + create (create l v rl) rv rr + else begin + match rl with + Empty -> assert false + | Leaf rlv -> + assert (is_empty rr); + create (create l v Empty) rlv (create Empty rv rr) + | Node(rll, rlv, rlr, _, _) -> + create (create l v rll) rlv (create rlr rv rr) + end + end else begin (* heights are within 2 of each other: same construction as create, written out inline *) + let h = if hl >= hr then hl + 1 else hr + 1 in + let sl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + let sr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + if h = 1 + then Leaf v + else Node (l, v, r, h, sl + sr + 1) + end + + (* Insertion of one element *) + + exception Same (* raised by the helper below when x is already present, so that add can return t itself *) + + (* [add t x ~compare_elt] returns t unchanged (the very same value) when an element equal to x is already present, otherwise a rebalanced tree that also contains x. *) let add t x ~compare_elt = + let rec aux = function + | Empty -> Leaf x + | Leaf v -> + let c = compare_elt x v in + if c = 0 then + raise Same + else if c < 0 then + bal (Leaf x) v Empty + else + bal Empty v (Leaf x) + | Node(l, v, r, _, _) -> + let c = compare_elt x v in + if c = 0 then + raise Same + else if c < 0 then + bal (aux l) v r + else + bal l v (aux r) + in + try aux t with Same -> t + ;; + + (* Same as create and bal, but no assumptions are made on the relative heights of l and + r. 
*) + let rec join l v r ~compare_elt = + match (l, r) with + | (Empty, _) -> add r v ~compare_elt + | (_, Empty) -> add l v ~compare_elt + | (Leaf lv, _) -> add (add r v ~compare_elt) lv ~compare_elt (* a Leaf side is folded in with two plain insertions *) + | (_, Leaf rv) -> add (add l v ~compare_elt) rv ~compare_elt + | (Node (ll, lv, lr, lh, _), Node (rl, rv, rr, rh, _)) -> (* descend into the taller side until the heights are within 2, then create *) + if lh > rh + 2 then bal ll lv (join lr v r ~compare_elt) else + if rh > lh + 2 then bal (join l v rl ~compare_elt) rv rr else + create l v r + ;; + + (* Smallest and greatest element of a set *) + let rec min_elt = function + | Empty -> None + | Leaf v + | Node(Empty, v, _, _, _) -> Some v + | Node(l, _, _, _, _) -> min_elt l + ;; + + exception Set_min_elt_exn_of_empty_set [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Set_min_elt_exn_of_empty_set]) + (function + | Set_min_elt_exn_of_empty_set -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/set.ml.Tree0.Set_min_elt_exn_of_empty_set" + | _ -> assert false) + [@@@end] + exception Set_max_elt_exn_of_empty_set [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Set_max_elt_exn_of_empty_set]) + (function + | Set_max_elt_exn_of_empty_set -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/set.ml.Tree0.Set_max_elt_exn_of_empty_set" + | _ -> assert false) + [@@@end] + + (* Like [min_elt], but raises Set_min_elt_exn_of_empty_set when the tree is Empty. *) let min_elt_exn t = + match min_elt t with + | None -> raise Set_min_elt_exn_of_empty_set + | Some v -> v + ;; + + (* In-order fold (left subtree, value, right subtree) that returns as soon as f yields Stop; [finish] is applied only when f answered Continue for every element. *) let fold_until t ~init ~f ~finish = + let rec fold_until_helper ~f t acc = + match t with + | Empty -> Continue_or_stop.Continue acc + | Leaf value -> f acc value + | Node(left, value, right, _, _) -> + match fold_until_helper ~f left acc with + | Stop _a as x -> x + | Continue acc -> + match f acc value with + | Stop _a as x -> x + | Continue a -> fold_until_helper ~f right a + in + match fold_until_helper ~f t init with + | Continue x -> finish x + | Stop x -> x + ;; + + let rec max_elt = function + | Empty -> None + | Leaf v + | 
Node(_, v, Empty, _, _) -> Some v + | Node(_, _, r, _, _) -> max_elt r + ;; + + (* Like [max_elt], but raises Set_max_elt_exn_of_empty_set when the tree is Empty. *) let max_elt_exn t = + match max_elt t with + | None -> raise Set_max_elt_exn_of_empty_set + | Some v -> v + ;; + + (* Remove the smallest element of the given set; raises Invalid_argument on Empty *) + + let rec remove_min_elt = function + | Empty -> invalid_arg "Set.remove_min_elt" + | Leaf _ -> Empty + | Node(Empty, _, r, _, _) -> r + | Node(l, v, r, _, _) -> bal (remove_min_elt l) v r + ;; + + (* Merge two trees t1 and t2 into one. All elements of t1 must precede the elements of t2. + Assume | height t1 - height t2 | <= 2. The smallest element of t2 becomes the new root. *) + let merge t1 t2 = + match (t1, t2) with + | (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> bal t1 (min_elt_exn t2) (remove_min_elt t2) + ;; + + (* Merge two trees t1 and t2 into one. All elements of t1 must precede the elements of t2. + No assumption on the heights of t1 and t2. *) + let concat t1 t2 ~compare_elt = + match (t1, t2) with + | Empty, t | t, Empty -> t + | (_, _) -> join t1 (min_elt_exn t2) (remove_min_elt t2) ~compare_elt + ;; + + (* [split t x ~compare_elt] returns (elements below x, Some v if t holds an element v equal to x and None otherwise, elements above x). *) let split t x ~compare_elt = + let rec split t = + match t with + | Empty -> (Empty, None, Empty) + | Leaf v -> + let c = compare_elt x v in + if c = 0 then (Empty, Some v, Empty) + else if c < 0 then (Empty, None, Leaf v) + else (Leaf v, None, Empty) + | Node (l, v, r, _, _) -> + let c = compare_elt x v in + if c = 0 then (l, Some v, r) + else if c < 0 then + let (ll, maybe_elt, rl) = split l in + (ll, maybe_elt, join rl v r ~compare_elt) + else + let (lr, maybe_elt, rr) = split r in + (join l v lr ~compare_elt, maybe_elt, rr) + in + split t + ;; + + (* Implementation of the set operations *) + + let empty = Empty + + (* Membership test: walks down a single path, going left when x compares below the node value and right otherwise. *) let rec mem t x ~compare_elt = + match t with + | Empty -> false + | Leaf v -> + let c = compare_elt x v in + c = 0 + | Node(l, v, r, _, _) -> + let c = compare_elt x v in + c = 0 || mem (if c < 0 then l else r) x ~compare_elt + ;; + + let singleton x = Leaf x + + let remove t x ~compare_elt = + let rec aux t = + match t with + | Empty -> raise 
Same + | Leaf v -> if compare_elt x v = 0 then Empty else raise Same + | Node(l, v, r, _, _) -> + let c = compare_elt x v in + if c = 0 then + merge l r + else if c < 0 then + bal (aux l) v r + else + bal l v (aux r) + in + try aux t with Same -> t (* Same means x was not found: hand back t itself *) + ;; + + (* Removes the element at 0-based in-order position i. An index that matches no element leaves t unchanged; the compare_elt argument is ignored. *) let remove_index t i ~compare_elt:_ = + let rec aux t i = + match t with + | Empty -> raise Same + | Leaf _ -> if i = 0 then Empty else raise Same + | Node (l, v, r, _, _) -> + let l_size = length l in + let c = Poly.compare i l_size in + if c = 0 then + merge l r + else if c < 0 then + bal (aux l i) v r + else + bal l v (aux r (i - l_size - 1)) (* skip the left subtree and the node value itself *) + in + try aux t i with Same -> t + ;; + + (* Set union: the root value of the taller tree is used to split the other tree, the halves are united recursively and glued back with join. *) let union s1 s2 ~compare_elt = + let rec union s1 s2 = + match s1, s2 with + | Empty, t | t, Empty -> t + | Leaf v1, _ -> union (Node(Empty, v1, Empty, 1, 1)) s2 (* a Leaf is rewritten as a one-element Node so that the Node/Node case below handles it *) + | _, Leaf v2 -> union s1 (Node(Empty, v2, Empty, 1, 1)) + | (Node(l1, v1, r1, h1, _), Node(l2, v2, r2, h2, _)) -> + if h1 >= h2 + then + if h2 = 1 + then add s1 v2 ~compare_elt + else begin + let (l2, _, r2) = split s2 v1 ~compare_elt in + join (union l1 l2) v1 (union r1 r2) ~compare_elt + end + else + if h1 = 1 + then add s2 v1 ~compare_elt + else begin + let (l1, _, r1) = split s1 v2 ~compare_elt in + join (union l1 l2) v2 (union r1 r2) ~compare_elt + end + in + union s1 s2 + ;; + + (* Union of a list of containers, each converted with to_tree, folded left to right starting from the empty tree. *) let union_list ~comparator ~to_tree xs = + let compare_elt = comparator.Comparator.compare in + List.fold xs ~init:empty ~f:(fun ac x -> union ac (to_tree x) ~compare_elt) + ;; + + (* Set intersection. *) let inter s1 s2 ~compare_elt = + let rec inter s1 s2 = + match s1, s2 with + | Empty, _ | _, Empty -> Empty + | ((Leaf elt as singleton), other_set) + | (other_set, (Leaf elt as singleton)) -> + if mem other_set elt ~compare_elt then singleton else Empty + | (Node (l1, v1, r1, _, _), t2) -> + match split t2 v1 ~compare_elt with + | (l2, None, r2) -> concat (inter l1 l2) (inter r1 r2) ~compare_elt + | (l2, Some v1, r2) -> join (inter l1 l2) v1 (inter r1 r2) ~compare_elt (* v1 is rebound here to the equal element that split found in t2 *) + in + inter s1 s2 + ;; + + let diff s1 
s2 ~compare_elt = + let rec diff s1 s2 = + match s1, s2 with + | (Empty, _) -> Empty + | (t1, Empty) -> t1 + | (Leaf v1, t2) -> diff (Node(Empty, v1, Empty, 1, 1)) t2 + | (Node(l1, v1, r1, _, _), t2) -> + match split t2 v1 ~compare_elt with + | (l2, None, r2) -> + join (diff l1 l2) v1 (diff r1 r2) ~compare_elt + | (l2, Some _, r2) -> + concat (diff l1 l2) (diff r1 r2) ~compare_elt + in + diff s1 s2 + ;; + + module Enum = struct + type increasing + type decreasing + type ('a, 'direction) t = End | More of 'a * 'a tree * ('a, 'direction) t + + let rec cons s (e : (_, increasing) t) : (_, increasing) t = + match s with + | Empty -> e + | Leaf v -> (More (v, Empty, e)) + | Node (l, v, r, _, _) -> cons l (More (v, r, e)) + ;; + + let rec cons_right s (e : (_, decreasing) t) : (_, decreasing) t = + match s with + | Empty -> e + | Leaf v -> More (v, Empty, e) + | Node (l, v, r, _, _) -> cons_right r (More (v, l, e)) + ;; + + let of_set s : (_, increasing) t = cons s End + + let of_set_right s : (_, decreasing) t = cons_right s End + + let starting_at_increasing t key compare : (_, increasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf v -> loop (Node (Empty, v, Empty, 1, 1)) e + | Node(_, v, r, _, _) when compare v key < 0 -> loop r e + | Node(l, v, r, _, _) -> loop l (More(v, r, e)) + in + loop t End + ;; + + let starting_at_decreasing t key compare : (_, decreasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf v -> loop (Node (Empty, v, Empty, 1, 1)) e + | Node(l, v, _, _, _) when compare v key > 0 -> loop l e + | Node(l, v, r, _, _) -> loop r (More(v, l, e)) + in + loop t End + ;; + + let compare compare_elt e1 e2 = + let rec loop e1 e2 = + match e1, e2 with + | End, End -> 0 + | End, _ -> -1 + | _, End -> 1 + | More (v1, r1, e1), More (v2, r2, e2) -> + let c = compare_elt v1 v2 in + if c <> 0 + then c + else loop (cons r1 e1) (cons r2 e2) + in + loop e1 e2 + ;; + + let rec iter ~f = function + | End -> () + | More (a, tree, 
enum) -> + f a; + iter (cons tree enum) ~f + ;; + + let iter2 compare_elt t1 t2 ~f = + let rec loop t1 t2 = + match t1, t2 with + | End, End -> () + | End, _ -> iter t2 ~f:(fun a -> f (`Right a)) + | _, End -> iter t1 ~f:(fun a -> f (`Left a)) + | More (a1, tree1, enum1), More (a2, tree2, enum2) -> + let compare_result = compare_elt a1 a2 in + if compare_result = 0 then begin + f (`Both (a1, a2)); + loop (cons tree1 enum1) (cons tree2 enum2) + end else if compare_result < 0 then begin + f (`Left a1); + loop (cons tree1 enum1) t2 + end else begin + f (`Right a2); + loop t1 (cons tree2 enum2) + end + in + loop t1 t2 + + let symmetric_diff t1 t2 ~compare_elt = + let step state : ((_,_) Either.t, _) Sequence.Step.t = + match state with + | End, End -> + Done + | End, More (elt, tree, enum) -> + Yield (Second elt, (End, cons tree enum)) + | More (elt, tree, enum), End -> + Yield (First elt, (cons tree enum, End)) + | (More (a1, tree1, enum1) as left), (More (a2, tree2, enum2) as right) -> + let compare_result = compare_elt a1 a2 in + if compare_result = 0 then begin + let next_state = + if phys_equal tree1 tree2 + then (enum1, enum2) + else (cons tree1 enum1, cons tree2 enum2) + in + Skip next_state + end else if compare_result < 0 then begin + Yield (First a1, (cons tree1 enum1, right)) + end else begin + Yield (Second a2, (left, cons tree2 enum2)) + end + in + Sequence.unfold_step ~init:(of_set t1, of_set t2) ~f:step + ;; + + end + + let to_sequence_increasing comparator ~from_elt t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More (k, t, e) -> Sequence.Step.Yield (k, Enum.cons t e) + in + let init = + match from_elt with + | None -> Enum.of_set t + | Some key -> Enum.starting_at_increasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence_decreasing comparator ~from_elt t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More (k, t, e) -> 
Sequence.Step.Yield (k, Enum.cons_right t e) + in + let init = + match from_elt with + | None -> Enum.of_set_right t + | Some key -> Enum.starting_at_decreasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence comparator ?(order = `Increasing) + ?greater_or_equal_to ?less_or_equal_to t = + let inclusive_bound side t bound = + let compare_elt = comparator.Comparator.compare in + let l, maybe, r = split t bound ~compare_elt in + let t = side (l, r) in + match maybe with + | None -> t + | Some elt -> add t elt ~compare_elt + in + match order with + | `Increasing -> + let t = Option.fold less_or_equal_to ~init:t ~f:(inclusive_bound fst) in + to_sequence_increasing comparator ~from_elt:greater_or_equal_to t + | `Decreasing -> + let t = Option.fold greater_or_equal_to ~init:t ~f:(inclusive_bound snd) in + to_sequence_decreasing comparator ~from_elt:less_or_equal_to t + ;; + + let merge_to_sequence comparator ?(order = `Increasing) + ?greater_or_equal_to ?less_or_equal_to t t' = + Sequence.merge_with_duplicates + (to_sequence comparator ~order ?greater_or_equal_to ?less_or_equal_to t) + (to_sequence comparator ~order ?greater_or_equal_to ?less_or_equal_to t') + ~compare:begin + match order with + | `Increasing -> comparator.compare + | `Decreasing -> Fn.flip comparator.compare + end + ;; + + let compare compare_elt s1 s2 = + Enum.compare compare_elt (Enum.of_set s1) (Enum.of_set s2) + ;; + + let iter2 s1 s2 ~compare_elt = + Enum.iter2 compare_elt (Enum.of_set s1) (Enum.of_set s2) + + let equal s1 s2 ~compare_elt = compare compare_elt s1 s2 = 0 + + let is_subset s1 ~of_:s2 ~compare_elt = + let rec is_subset s1 ~of_:s2 = + match s1, s2 with + | Empty, _ -> true + | _, Empty -> false + | Leaf v1, t2 -> mem t2 v1 ~compare_elt + | Node (l1, v1, r1, _, _), Leaf v2 -> + begin match l1, r1 with + | Empty, Empty -> + (* This case shouldn't occur in practice because we should have constructed + a Leaf rather than a Node with two 
Empty subtrees *) + compare_elt v1 v2 = 0 + | _, _ -> false + end + | Node (l1, v1, r1, _, _), (Node (l2, v2, r2, _, _) as t2) -> + let c = compare_elt v1 v2 in + if c = 0 + then is_subset l1 ~of_:l2 && is_subset r1 ~of_:r2 + (* Note that height and size don't matter here. *) + else if c < 0 then + is_subset (Node (l1, v1, Empty, 0, 0)) ~of_:l2 && is_subset r1 ~of_:t2 + else + is_subset (Node (Empty, v1, r1, 0, 0)) ~of_:r2 && is_subset l1 ~of_:t2 + in + is_subset s1 ~of_:s2 + ;; + + let iter t ~f = + let rec iter = function + | Empty -> () + | Leaf v -> f v + | Node(l, v, r, _, _) -> iter l; f v; iter r + in + iter t + ;; + + let symmetric_diff = Enum.symmetric_diff + + let rec fold s ~init:accu ~f = + match s with + | Empty -> accu + | Leaf v -> f accu v + | Node(l, v, r, _, _) -> fold ~f r ~init:(f (fold ~f l ~init:accu) v) + ;; + + let hash_fold_t_ignoring_structure hash_fold_elem state t = + fold t ~init:(hash_fold_int state (length t)) ~f:hash_fold_elem + ;; + + let count t ~f = Container.count ~fold t ~f + let sum m t ~f = Container.sum ~fold m t ~f + + let rec fold_right s ~init:accu ~f = + match s with + | Empty -> accu + | Leaf v -> f v accu + | Node(l, v, r, _, _) -> fold_right ~f l ~init:(f v (fold_right ~f r ~init:accu)) + ;; + + let rec for_all t ~f:p = match t with + | Empty -> true + | Leaf v -> p v + | Node(l, v, r, _, _) -> p v && for_all ~f:p l && for_all ~f:p r + ;; + + let rec exists t ~f:p = match t with + | Empty -> false + | Leaf v -> p v + | Node(l, v, r, _, _) -> p v || exists ~f:p l || exists ~f:p r + ;; + + let filter s ~f:p ~compare_elt = + let rec filt accu = function + | Empty -> accu + | Leaf v -> if p v then add accu v ~compare_elt else accu + | Node(l, v, r, _, _) -> + filt (filt (if p v then add accu v ~compare_elt else accu) l) r + in + filt Empty s + ;; + + let filter_map s ~f:p ~compare_elt = + let rec filt accu = function + | Empty -> accu + | Leaf v -> + (match p v with + | None -> accu + | Some v -> add accu v ~compare_elt) + 
| Node(l, v, r, _, _) -> + filt (filt (match p v with + | None -> accu + | Some v -> add accu v ~compare_elt) l) r + in + filt Empty s + ;; + + let partition_tf s ~f:p ~compare_elt = + let rec part ((t, f) as accu) = function + | Empty -> accu + | Leaf v -> if p v then (add t v ~compare_elt, f) else (t, add f v ~compare_elt) + | Node(l, v, r, _, _) -> + part (part ( + if p v + then (add t v ~compare_elt, f) + else (t, add f v ~compare_elt)) l) r + in + part (Empty, Empty) s + ;; + + let rec elements_aux accu = function + | Empty -> accu + | Leaf v -> v :: accu + | Node(l, v, r, _, _) -> elements_aux (v :: elements_aux accu r) l + ;; + + let elements s = elements_aux [] s + + let choose t = + match t with + | Empty -> None + | Leaf v -> Some v + | Node (_, v, _, _, _) -> Some v + ;; + + let choose_exn t = + match choose t with + | None -> + raise Caml.Not_found + | Some v -> v + ;; + + let of_list lst ~compare_elt = + List.fold lst ~init:empty ~f:(fun t x -> add t x ~compare_elt) + ;; + + let to_list s = elements s + + let of_array a ~compare_elt = + Array.fold a ~init:empty ~f:(fun t x -> add t x ~compare_elt) + ;; + + (* faster but equivalent to [Array.of_list (to_list t)] *) + let to_array = function + | Empty -> [||] + | Leaf v -> [| v |] + | Node (l, v, r, _, s) -> + let res = Array.create ~len:s v in + let pos_ref = ref 0 in + let rec loop = function + (* Invariant: on entry and on exit to [loop], !pos_ref is the next + available cell in the array. *) + | Empty -> () + | Leaf v -> + res.(!pos_ref) <- v; + incr pos_ref + | Node (l, v, r, _, _) -> + loop l; + res.(!pos_ref) <- v; + incr pos_ref; + loop r + in + loop l; + (* res.(!pos_ref) is already initialized (by Array.create ~len:above). 
*) + incr pos_ref; + loop r; + res + ;; + + let map t ~f ~compare_elt = fold t ~init:empty ~f:(fun t x -> add t (f x) ~compare_elt) + + let group_by set ~equiv ~compare_elt = + let rec loop set equiv_classes = + if is_empty set + then equiv_classes + else + let x = choose_exn set in + let equiv_x, not_equiv_x = + partition_tf set ~f:(fun elt -> phys_equal x elt || equiv x elt) ~compare_elt + in + loop not_equiv_x (equiv_x :: equiv_classes) + in + loop set [] + ;; + + let rec find t ~f = + match t with + | Empty -> None + | Leaf v -> if f v then Some v else None + | Node(l, v, r, _, _) -> + if f v then Some v + else + match find l ~f with + | None -> find r ~f + | Some _ as r -> r + ;; + + let rec find_map t ~f = + match t with + | Empty -> None + | Leaf v -> f v + | Node(l, v, r, _, _) -> + match f v with + | Some _ as r -> r + | None -> + match find_map l ~f with + | None -> find_map r ~f + | Some _ as r -> r + ;; + + let find_exn t ~f = + match find t ~f with + | None -> failwith "Set.find_exn failed to find a matching element" + | Some e -> e + ;; + + let rec nth t i = + match t with + | Empty -> None + | Leaf v -> if i = 0 then Some v else None + | Node (l, v, r, _, s) -> + if i >= s then None + else begin + let l_size = length l in + let c = Poly.compare i l_size in + if c < 0 then nth l i + else if c = 0 then Some v + else nth r (i - l_size - 1) + end + ;; + + let stable_dedup_list xs ~compare_elt = + let rec loop xs leftovers already_seen = + match xs with + | [] -> List.rev leftovers + | hd :: tl -> + if mem already_seen hd ~compare_elt + then loop tl leftovers already_seen + else loop tl (hd :: leftovers) (add already_seen hd ~compare_elt) + in + loop xs [] empty + ;; + + let t_of_sexp_direct a_of_sexp sexp ~compare_elt = + match sexp with + | Sexp.List lst -> + let elt_lst = List.map lst ~f:a_of_sexp in + let set = of_list elt_lst ~compare_elt in + if length set = List.length lst then + set + else + let compare (_, e) (_, e') = compare_elt e e' in + begin 
match List.find_a_dup (List.zip_exn lst elt_lst) ~compare with + | None -> assert false + | Some (el_sexp, _) -> + of_sexp_error "Set.t_of_sexp: duplicate element in set" el_sexp + end + | sexp -> of_sexp_error "Set.t_of_sexp: list needed" sexp + ;; + + (* Converts every element with sexp_of_a and wraps the results in a single Sexp.List. *) let sexp_of_t sexp_of_a t = + Sexp.List (fold_right t ~init:[] ~f:(fun el acc -> sexp_of_a el :: acc)) + ;; + + (* A tree paired with a human-readable name, used to build error messages. *) module Named = struct + type nonrec ('a, 'cmp) t = { + tree : 'a t; + name : string; + } + + (* Ok () when no element of subset.tree is missing from superset.tree; otherwise an error whose message names both sets and lists the offending elements under the invalid_elements tag. *) let is_subset (subset : _ t) ~of_:(superset : _ t) ~sexp_of_elt + ~compare_elt = + let invalid_elements = diff subset.tree superset.tree ~compare_elt in + if is_empty invalid_elements + then Ok () + else begin + let invalid_elements_sexp = sexp_of_t sexp_of_elt invalid_elements in + Or_error.error_s ( + Sexp.message (subset.name ^ " is not a subset of " ^ superset.name) + [ "invalid_elements", invalid_elements_sexp ]) + end + + (* Equality as mutual inclusion; the two [is_subset] results are combined with Or_error.combine_errors_unit. *) let equal s1 s2 ~sexp_of_elt ~compare_elt = + Or_error.combine_errors_unit + [ is_subset s1 ~of_:s2 ~sexp_of_elt ~compare_elt + ; is_subset s2 ~of_:s1 ~sexp_of_elt ~compare_elt + ] + end +end + +type ('a, 'comparator) t = + { (* [comparator] is the first field so that polymorphic equality fails on a set due + to the functional value in the comparator. + Note that this does not affect polymorphic [compare]: that still produces + nonsense. 
*) + comparator : ('a, 'comparator) Comparator.t; + tree : 'a Tree0.t; + } + +type ('a, 'comparator) tree = 'a Tree0.t + +let like { tree = _; comparator } tree = { tree; comparator } + +let compare_elt t = t.comparator.Comparator.compare + +module Accessors = struct + let comparator t = t.comparator + let invariants t = Tree0.invariants t.tree ~compare_elt:(compare_elt t) + let length t = Tree0.length t.tree + let is_empty t = Tree0.is_empty t.tree + let elements t = Tree0.elements t.tree + let min_elt t = Tree0.min_elt t.tree + let min_elt_exn t = Tree0.min_elt_exn t.tree + let max_elt t = Tree0.max_elt t.tree + let max_elt_exn t = Tree0.max_elt_exn t.tree + let choose t = Tree0.choose t.tree + let choose_exn t = Tree0.choose_exn t.tree + let to_list t = Tree0.to_list t.tree + let to_array t = Tree0.to_array t.tree + let fold t ~init ~f = Tree0.fold t.tree ~init ~f + let fold_until t ~init ~f = Tree0.fold_until t.tree ~init ~f + let fold_right t ~init ~f = Tree0.fold_right t.tree ~init ~f + let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t + + let iter t ~f = Tree0.iter t.tree ~f + let iter2 a b ~f = Tree0.iter2 a.tree b.tree ~f ~compare_elt:(compare_elt a) + let exists t ~f = Tree0.exists t.tree ~f + let for_all t ~f = Tree0.for_all t.tree ~f + let count t ~f = Tree0.count t.tree ~f + let sum m t ~f = Tree0.sum m t.tree ~f + let find t ~f = Tree0.find t.tree ~f + let find_exn t ~f = Tree0.find_exn t.tree ~f + let find_map t ~f = Tree0.find_map t.tree ~f + let mem t a = Tree0.mem t.tree a ~compare_elt:(compare_elt t) + let filter t ~f = like t (Tree0.filter t.tree ~f ~compare_elt:(compare_elt t)) + let add t a = like t (Tree0.add t.tree a ~compare_elt:(compare_elt t)) + let remove t a = like t (Tree0.remove t.tree a ~compare_elt:(compare_elt t)) + let union t1 t2 = like t1 (Tree0.union t1.tree t2.tree ~compare_elt:(compare_elt t1)) + let inter t1 t2 = like t1 (Tree0.inter t1.tree t2.tree ~compare_elt:(compare_elt t1)) + let diff t1 t2 = like t1 
(Tree0.diff t1.tree t2.tree ~compare_elt:(compare_elt t1)) + let symmetric_diff t1 t2 = + Tree0.symmetric_diff t1.tree t2.tree ~compare_elt:(compare_elt t1) + let compare_direct t1 t2 = Tree0.compare (compare_elt t1) t1.tree t2.tree + let equal t1 t2 = Tree0.equal t1.tree t2.tree ~compare_elt:(compare_elt t1) + let is_subset t ~of_ = Tree0.is_subset t.tree ~of_:of_.tree ~compare_elt:(compare_elt t) + + module Named = struct + type nonrec ('a, 'cmp) t = { + set : ('a, 'cmp) t; + name : string; + } + + let to_named_tree { set; name } = { + Tree0.Named. + tree = set.tree; + name; + } + + let is_subset (subset : (_, _) t) ~of_:(superset : (_, _) t) = + Tree0.Named.is_subset (to_named_tree subset) + ~of_:(to_named_tree superset) + ~compare_elt:(compare_elt subset.set) + ~sexp_of_elt:subset.set.comparator.sexp_of_t + + let equal t1 t2 = + Or_error.combine_errors_unit + [ is_subset t1 ~of_:t2 + ; is_subset t2 ~of_:t1 + ] + end + + let partition_tf t ~f = + let (tree_t, tree_f) = Tree0.partition_tf t.tree ~f ~compare_elt:(compare_elt t) in + like t tree_t, like t tree_f + ;; + let split t a = + let (tree1, b, tree2) = Tree0.split t.tree a ~compare_elt:(compare_elt t) in + like t tree1, b, like t tree2 + ;; + let group_by t ~equiv = + List.map (Tree0.group_by t.tree ~equiv ~compare_elt:(compare_elt t)) ~f:(like t) + ;; + let nth t i = Tree0.nth t.tree i + let remove_index t i = like t (Tree0.remove_index t.tree i ~compare_elt:(compare_elt t)) + let sexp_of_t sexp_of_a _ t = Tree0.sexp_of_t sexp_of_a t.tree + let to_sequence ?order ?greater_or_equal_to ?less_or_equal_to t = + Tree0.to_sequence t.comparator ?order ?greater_or_equal_to ?less_or_equal_to t.tree + let merge_to_sequence ?order ?greater_or_equal_to ?less_or_equal_to t t' = + Tree0.merge_to_sequence + t.comparator + ?order + ?greater_or_equal_to + ?less_or_equal_to + t.tree + t'.tree + let hash_fold_direct hash_fold_key state t = + Tree0.hash_fold_t_ignoring_structure hash_fold_key state t.tree +end + +include 
Accessors + +let compare _ _ t1 t2 = compare_direct t1 t2 + +module Tree = struct + type ('a, 'comparator) t = ('a, 'comparator) tree + + let ce comparator = comparator.Comparator.compare + + let t_of_sexp_direct ~comparator a_of_sexp sexp = + Tree0.t_of_sexp_direct ~compare_elt:(ce comparator) a_of_sexp sexp + + let empty_without_value_restriction = Tree0.empty + let empty ~comparator:_ = empty_without_value_restriction + let singleton ~comparator:_ e = Tree0.singleton e + + let length t = Tree0.length t + let invariants ~comparator t = Tree0.invariants t ~compare_elt:(ce comparator) + let is_empty t = Tree0.is_empty t + let elements t = Tree0.elements t + let min_elt t = Tree0.min_elt t + let min_elt_exn t = Tree0.min_elt_exn t + let max_elt t = Tree0.max_elt t + let max_elt_exn t = Tree0.max_elt_exn t + let choose t = Tree0.choose t + let choose_exn t = Tree0.choose_exn t + let to_list t = Tree0.to_list t + let to_array t = Tree0.to_array t + + let iter t ~f = Tree0.iter t ~f + let exists t ~f = Tree0.exists t ~f + let for_all t ~f = Tree0.for_all t ~f + let count t ~f = Tree0.count t ~f + let sum m t ~f = Tree0.sum m t ~f + let find t ~f = Tree0.find t ~f + let find_exn t ~f = Tree0.find_exn t ~f + let find_map t ~f = Tree0.find_map t ~f + + let fold t ~init ~f = Tree0.fold t ~init ~f + let fold_until t ~init ~f = Tree0.fold_until t ~init ~f + let fold_right t ~init ~f = Tree0.fold_right t ~init ~f + + let map ~comparator t ~f = Tree0.map t ~f ~compare_elt:(ce comparator) + let filter ~comparator t ~f = Tree0.filter t ~f ~compare_elt:(ce comparator) + let filter_map ~comparator t ~f = Tree0.filter_map t ~f ~compare_elt:(ce comparator) + let partition_tf ~comparator t ~f = Tree0.partition_tf t ~f ~compare_elt:(ce comparator) + + let iter2 ~comparator a b ~f = Tree0.iter2 a b ~f ~compare_elt:(ce comparator) + + let mem ~comparator t a = Tree0.mem t a ~compare_elt:(ce comparator) + let add ~comparator t a = Tree0.add t a ~compare_elt:(ce comparator) + let remove 
~comparator t a = Tree0.remove t a ~compare_elt:(ce comparator) + + let union ~comparator t1 t2 = Tree0.union t1 t2 ~compare_elt:(ce comparator) + let inter ~comparator t1 t2 = Tree0.inter t1 t2 ~compare_elt:(ce comparator) + let diff ~comparator t1 t2 = Tree0.diff t1 t2 ~compare_elt:(ce comparator) + let symmetric_diff ~comparator t1 t2 = + Tree0.symmetric_diff t1 t2 ~compare_elt:(ce comparator) + let compare_direct ~comparator t1 t2 = Tree0.compare (ce comparator) t1 t2 + let equal ~comparator t1 t2 = Tree0.equal t1 t2 ~compare_elt:(ce comparator) + + let is_subset ~comparator t ~of_ = Tree0.is_subset t ~of_ ~compare_elt:(ce comparator) + + let of_list ~comparator l = Tree0.of_list l ~compare_elt:(ce comparator) + let of_array ~comparator a = Tree0.of_array a ~compare_elt:(ce comparator) + let of_sorted_array_unchecked ~comparator a = + Tree0.of_sorted_array_unchecked a ~compare_elt:(ce comparator) + let of_increasing_iterator_unchecked ~comparator:_ ~len ~f = + Tree0.of_increasing_iterator_unchecked ~len ~f + let of_sorted_array ~comparator a = + Tree0.of_sorted_array a ~compare_elt:(ce comparator) + + let union_list ~comparator l = Tree0.union_list l ~to_tree:Fn.id ~comparator + let stable_dedup_list ~comparator xs = + Tree0.stable_dedup_list xs ~compare_elt:(ce comparator) + ;; + let group_by ~comparator t ~equiv = Tree0.group_by t ~equiv ~compare_elt:(ce comparator) + let split ~comparator t a = Tree0.split t a ~compare_elt:(ce comparator) + let nth t i = Tree0.nth t i + let remove_index ~comparator t i = Tree0.remove_index t i ~compare_elt:(ce comparator) + let sexp_of_t sexp_of_a _ t = Tree0.sexp_of_t sexp_of_a t + + let to_tree t = t + let of_tree ~comparator:_ t = t + + let to_sequence ~comparator ?order ?greater_or_equal_to ?less_or_equal_to t = + Tree0.to_sequence comparator ?order ?greater_or_equal_to ?less_or_equal_to t + + let merge_to_sequence ~comparator ?order ?greater_or_equal_to ?less_or_equal_to t t' = + Tree0.merge_to_sequence comparator 
?order ?greater_or_equal_to ?less_or_equal_to t t' + + let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t + + module Named = struct + include Tree0.Named + + let is_subset ~comparator t1 ~of_:t2 = + Tree0.Named.is_subset t1 ~of_:t2 ~compare_elt:(ce comparator) + ~sexp_of_elt:comparator.Comparator.sexp_of_t + + let equal ~comparator t1 t2 = + Tree0.Named.equal t1 t2 ~compare_elt:(ce comparator) + ~sexp_of_elt:comparator.Comparator.sexp_of_t + end +end + +module Using_comparator = struct + type nonrec ('elt, 'cmp) t = ('elt, 'cmp) t + + include Accessors + + let to_tree t = t.tree + + let of_tree ~comparator tree = { comparator; tree } + + let t_of_sexp_direct ~comparator a_of_sexp sexp = + of_tree ~comparator + (Tree0.t_of_sexp_direct ~compare_elt:comparator.compare a_of_sexp sexp) + + let empty ~comparator = { comparator; tree = Tree0.empty } + + module Empty_without_value_restriction(Elt : Comparator.S1) = struct + let empty = { comparator = Elt.comparator; tree = Tree0.empty } + end + + let singleton ~comparator e = { comparator; tree = Tree0.singleton e } + + let union_list ~comparator l = + of_tree ~comparator (Tree0.union_list ~comparator ~to_tree l) + ;; + + let of_sorted_array_unchecked ~comparator array = + let tree = Tree0.of_sorted_array_unchecked array ~compare_elt:comparator.Comparator.compare in + { comparator; tree } + ;; + + let of_increasing_iterator_unchecked ~comparator ~len ~f = + of_tree ~comparator (Tree0.of_increasing_iterator_unchecked ~len ~f) + ;; + + let of_sorted_array ~comparator array = + Or_error.Monad_infix.( + Tree0.of_sorted_array array ~compare_elt:comparator.Comparator.compare + >>| fun tree -> { comparator; tree }) + ;; + + let of_list ~comparator l = + { comparator; tree = Tree0.of_list l ~compare_elt:comparator.Comparator.compare } + ;; + + let of_array ~comparator a = + { comparator; tree = Tree0.of_array a ~compare_elt:comparator.Comparator.compare } + ;; + + let stable_dedup_list ~comparator xs = + 
Tree0.stable_dedup_list xs ~compare_elt:comparator.Comparator.compare; + ;; + + let map ~comparator t ~f = + { comparator; tree = Tree0.map t.tree ~f ~compare_elt:comparator.Comparator.compare } + ;; + + let filter_map ~comparator t ~f = + { comparator; + tree = Tree0.filter_map t.tree ~f ~compare_elt:comparator.Comparator.compare; + } + ;; + + module Tree = Tree +end + +type ('elt, 'cmp) comparator = + (module Comparator.S with type t = 'elt and type comparator_witness = 'cmp) + +let comparator_s (type k cmp) t : (k, cmp) comparator = + (module struct + type t = k + type comparator_witness = cmp + let comparator = t.comparator + end) + +let to_comparator (type elt cmp) ((module M) : (elt, cmp) comparator) = M.comparator + +let empty m = Using_comparator.empty ~comparator:(to_comparator m) +let singleton m a = Using_comparator.singleton ~comparator:(to_comparator m) a +let union_list m a = Using_comparator.union_list ~comparator:(to_comparator m) a +let of_sorted_array_unchecked m a = Using_comparator.of_sorted_array_unchecked ~comparator:(to_comparator m) a +let of_increasing_iterator_unchecked m ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~comparator:(to_comparator m) ~len ~f +let of_sorted_array m a = Using_comparator.of_sorted_array ~comparator:(to_comparator m) a +let of_list m a = Using_comparator.of_list ~comparator:(to_comparator m) a +let of_array m a = Using_comparator.of_array ~comparator:(to_comparator m) a +let stable_dedup_list m a = Using_comparator.stable_dedup_list ~comparator:(to_comparator m) a +let map m a ~f = Using_comparator.map ~comparator:(to_comparator m) a ~f +let filter_map m a ~f = Using_comparator.filter_map ~comparator:(to_comparator m) a ~f + +module M (Elt : sig type t type comparator_witness end) = struct + type nonrec t = (Elt.t, Elt.comparator_witness) t +end + +module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + 
end[@@ocaml.doc "@inline"] + [@@@end] end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t +end +module type Compare_m = sig end +module type Hash_fold_m = Hasher.S + +let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t = + sexp_of_t Elt.sexp_of_t (fun _ -> Sexp.Atom "_") t + +let m__t_of_sexp (type elt cmp) + (module Elt : M_of_sexp with type t = elt and type comparator_witness = cmp) + sexp = + Using_comparator.t_of_sexp_direct ~comparator:Elt.comparator Elt.t_of_sexp sexp + +let compare_m__t (module Elt : Compare_m) t1 t2 = + compare_direct t1 t2 + +let hash_fold_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) state = + hash_fold_direct Elt.hash_fold_t state + +let hash_m__t folder t = + let state = hash_fold_m__t folder (Hash.create ()) t in + Hash.get_hash_value state + +module Poly = struct + type comparator_witness = Comparator.Poly.comparator_witness + type nonrec ('elt, 'cmp) set = ('elt, comparator_witness) t + type nonrec 'elt t = ('elt, comparator_witness) t + type nonrec 'elt tree = ('elt, comparator_witness) tree + type nonrec 'elt named = ('elt, comparator_witness) Named.t + + include Accessors + + let comparator = Comparator.Poly.comparator + + include Using_comparator.Empty_without_value_restriction(Comparator.Poly) + + let singleton a = Using_comparator.singleton ~comparator a + let union_list a = Using_comparator.union_list ~comparator a + let of_sorted_array_unchecked a = Using_comparator.of_sorted_array_unchecked ~comparator a + let of_increasing_iterator_unchecked ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~comparator ~len ~f + let of_sorted_array a = Using_comparator.of_sorted_array ~comparator a + let of_list a = Using_comparator.of_list ~comparator a + let of_array a = Using_comparator.of_array ~comparator a + 
let stable_dedup_list a = Using_comparator.stable_dedup_list ~comparator a + let map a ~f = Using_comparator.map ~comparator a ~f + let filter_map a ~f = Using_comparator.filter_map ~comparator a ~f + + let of_tree tree = { comparator; tree } + let to_tree t = t.tree +end diff --git a/src/set.mli b/src/set.mli new file mode 100644 index 0000000..5b64bbd --- /dev/null +++ b/src/set.mli @@ -0,0 +1 @@ +include Set_intf.Set (** @inline *) diff --git a/src/set_intf.ml b/src/set_intf.ml new file mode 100644 index 0000000..db1605a --- /dev/null +++ b/src/set_intf.ml @@ -0,0 +1,1278 @@ +open! Import +open! T + +module type Elt_plain = sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +module Without_comparator = Map_intf.Without_comparator +module With_comparator = Map_intf.With_comparator +module With_first_class_module = Map_intf.With_first_class_module + +include Container_intf.Export + +module Merge_to_sequence_element = Sequence.Merge_with_duplicates_element + +module type Accessors_generic = sig + + include Container.Generic_phantom + + type ('a, 'cmp) tree + + (** The [options] type is used to make [Accessors_generic] flexible as to whether a + comparator is required to be passed to certain functions. 
*) + type ('a, 'cmp, 'z) options + + type 'cmp cmp + + val invariants + : ('a, 'cmp, + ('a, 'cmp) t -> bool + ) options + + (** override [Container]'s [mem] *) + val mem : ('a, 'cmp, ('a, 'cmp) t -> 'a elt -> bool) options + val add + : ('a, 'cmp, + ('a, 'cmp) t -> 'a elt -> ('a, 'cmp) t + ) options + val remove + : ('a, 'cmp, + ('a, 'cmp) t -> 'a elt -> ('a, 'cmp) t + ) options + val union + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + ) options + val inter + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + ) options + val diff + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + ) options + val symmetric_diff + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t + -> ('a elt, 'a elt) Either.t Sequence.t + ) options + val compare_direct + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> int + ) options + val equal + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> bool + ) options + val is_subset + : ('a, 'cmp, + ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + ) options + + type ('a, 'cmp) named + module Named : sig + val is_subset + : ('a, 'cmp, + ('a, 'cmp) named + -> of_:('a, 'cmp) named + -> unit Or_error.t + ) options + + val equal + : ('a, 'cmp, + ('a, 'cmp) named + -> ('a, 'cmp) named + -> unit Or_error.t + ) options + end + + val fold_until + : ('a, _) t + -> init:'b + -> f:('b -> 'a elt -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + val fold_right + : ('a, _) t + -> init:'b + -> f:('a elt -> 'b -> 'b) + -> 'b + val iter2 + : ('a, 'cmp, + ('a, 'cmp) t + -> ('a, 'cmp) t + -> f:([ `Left of 'a elt | `Right of 'a elt | `Both of 'a elt * 'a elt ] -> unit) + -> unit + ) options + val filter + : ('a, 'cmp, + ('a, 'cmp) t -> f:('a elt -> bool) -> ('a, 'cmp) t + ) options + val partition_tf + : ('a, 'cmp, + ('a, 'cmp) t + -> f:('a elt -> bool) + -> ('a, 'cmp) t * ('a, 'cmp) t + ) options + + val elements : ('a, _) t -> 'a elt list + + val min_elt : ('a, _) t -> 'a elt option + val min_elt_exn : ('a, _) t -> 'a 
elt + val max_elt : ('a, _) t -> 'a elt option + val max_elt_exn : ('a, _) t -> 'a elt + val choose : ('a, _) t -> 'a elt option + val choose_exn : ('a, _) t -> 'a elt + + val split + : ('a, 'cmp, + ('a, 'cmp) t + -> 'a elt + -> ('a, 'cmp) t * 'a elt option * ('a, 'cmp) t + ) options + + val group_by + : ('a, 'cmp, + ('a, 'cmp) t + -> equiv:('a elt -> 'a elt -> bool) + -> ('a, 'cmp) t list + ) options + + val find_exn : ('a, _) t -> f:('a elt -> bool) -> 'a elt + val nth : ('a, _) t -> int -> 'a elt option + val remove_index + : ('a, 'cmp, + ('a, 'cmp) t -> int -> ('a, 'cmp) t + ) options + + val to_tree : ('a, 'cmp) t -> ('a elt, 'cmp) tree + + val to_sequence + : ('a, 'cmp, + ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a elt + -> ?less_or_equal_to:'a elt + -> ('a, 'cmp) t + -> 'a elt Sequence.t + ) options + + val merge_to_sequence + : ('a, 'cmp, + ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a elt + -> ?less_or_equal_to:'a elt + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a elt, 'a elt) Merge_to_sequence_element.t Sequence.t + ) options +end + +module type Accessors0 = sig + include Container.S0 + type tree + type comparator_witness + val invariants : t -> bool + val mem : t -> elt -> bool + val add : t -> elt -> t + val remove : t -> elt -> t + val union : t -> t -> t + val inter : t -> t -> t + val diff : t -> t -> t + val symmetric_diff : t -> t -> (elt, elt) Either.t Sequence.t + val compare_direct : t -> t -> int + val equal : t -> t -> bool + val is_subset : t -> of_:t -> bool + + type named + module Named : sig + val is_subset : named -> of_:named -> unit Or_error.t + val equal : named -> named -> unit Or_error.t + end + + val fold_until + : t + -> init:'b + -> f:('b -> elt -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + val fold_right : t -> init:'b -> f:(elt -> 'b -> 'b) -> 'b + val iter2 + : t -> t -> f:([ `Left of elt | `Right of elt | `Both of elt * elt ] -> unit) -> unit + val filter : 
t -> f:(elt -> bool) -> t + val partition_tf : t -> f:(elt -> bool) -> t * t + val elements : t -> elt list + val min_elt : t -> elt option + val min_elt_exn : t -> elt + val max_elt : t -> elt option + val max_elt_exn : t -> elt + val choose : t -> elt option + val choose_exn : t -> elt + val split : t -> elt -> t * elt option * t + val group_by : t -> equiv:(elt -> elt -> bool) -> t list + val find_exn : t -> f:(elt -> bool) -> elt + val nth : t -> int -> elt option + val remove_index : t -> int -> t + val to_tree : t -> tree + val to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:elt + -> ?less_or_equal_to:elt + -> t + -> elt Sequence.t + val merge_to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:elt + -> ?less_or_equal_to:elt + -> t + -> t + -> (elt, elt) Merge_to_sequence_element.t Sequence.t +end + +module type Accessors1 = sig + include Container.S1 + type 'a tree + type comparator_witness + val invariants : _ t -> bool + val mem : 'a t -> 'a -> bool + val add : 'a t -> 'a -> 'a t + val remove : 'a t -> 'a -> 'a t + val union : 'a t -> 'a t -> 'a t + val inter : 'a t -> 'a t -> 'a t + val diff : 'a t -> 'a t -> 'a t + val symmetric_diff : 'a t -> 'a t -> ('a, 'a) Either.t Sequence.t + val compare_direct : 'a t -> 'a t -> int + val equal : 'a t -> 'a t -> bool + val is_subset : 'a t -> of_:'a t -> bool + + type 'a named + module Named : sig + val is_subset : 'a named -> of_:'a named -> unit Or_error.t + val equal : 'a named -> 'a named -> unit Or_error.t + end + + val fold_until + : 'a t + -> init:'b + -> f:('b -> 'a -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + val fold_right : 'a t -> init:'b -> f:('a -> 'b -> 'b) -> 'b + val iter2 + : 'a t -> 'a t -> f:([ `Left of 'a | `Right of 'a | `Both of 'a * 'a ] -> unit) -> unit + val filter : 'a t -> f:('a -> bool) -> 'a t + val partition_tf : 'a t -> f:('a -> bool) -> 'a t * 'a t + val elements : 'a t -> 'a list + val min_elt 
: 'a t -> 'a option + val min_elt_exn : 'a t -> 'a + val max_elt : 'a t -> 'a option + val max_elt_exn : 'a t -> 'a + val choose : 'a t -> 'a option + val choose_exn : 'a t -> 'a + val split : 'a t -> 'a -> 'a t * 'a option * 'a t + val group_by : 'a t -> equiv:('a -> 'a -> bool) -> 'a t list + val find_exn : 'a t -> f:('a -> bool) -> 'a + val nth : 'a t -> int -> 'a option + val remove_index : 'a t -> int -> 'a t + val to_tree : 'a t -> 'a tree + val to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> 'a t + -> 'a Sequence.t + val merge_to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> 'a t + -> 'a t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t +end + +module type Accessors2 = sig + include Container.S1_phantom_invariant + type ('a, 'cmp) tree + val invariants : (_, _) t -> bool + val mem : ('a, _) t -> 'a -> bool + val add : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val remove : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val union : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val inter : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val diff : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val symmetric_diff : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'a) Either.t Sequence.t + val compare_direct : ('a, 'cmp) t -> ('a, 'cmp) t -> int + val equal : ('a, 'cmp) t -> ('a, 'cmp) t -> bool + val is_subset : ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + + type ('a, 'cmp) named + module Named : sig + val is_subset : ('a, 'cmp) named -> of_:('a, 'cmp) named -> unit Or_error.t + val equal : ('a, 'cmp) named -> ('a, 'cmp) named -> unit Or_error.t + end + + val fold_until + : ('a, _) t + -> init:'b + -> f:('b -> 'a -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + + val fold_right : ('a, _) t -> init:'b -> f:('a -> 'b -> 'b) -> 'b + val iter2 + : ('a, 'cmp) t + -> ('a, 'cmp) t -> f:([ `Left of 'a | `Right of 'a | `Both of 'a * 'a ] -> 
unit) + -> unit + val filter : ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t + val partition_tf : ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t * ('a, 'cmp) t + val elements : ('a, _) t -> 'a list + val min_elt : ('a, _) t -> 'a option + val min_elt_exn : ('a, _) t -> 'a + val max_elt : ('a, _) t -> 'a option + val max_elt_exn : ('a, _) t -> 'a + val choose : ('a, _) t -> 'a option + val choose_exn : ('a, _) t -> 'a + val split : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t * 'a option * ('a, 'cmp) t + val group_by : ('a, 'cmp) t -> equiv:('a -> 'a -> bool) -> ('a, 'cmp) t list + val find_exn : ('a, _) t -> f:('a -> bool) -> 'a + val nth : ('a, _) t -> int -> 'a option + val remove_index : ('a, 'cmp) t -> int -> ('a, 'cmp) t + val to_tree : ('a, 'cmp) t -> ('a, 'cmp) tree + val to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> 'a Sequence.t + val merge_to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t +end + +module type Accessors2_with_comparator = sig + include Container.S1_phantom_invariant + type ('a, 'cmp) tree + val invariants : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> bool + val mem : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> 'a -> bool + val add + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val remove + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val union + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val inter + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val diff + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val symmetric_diff + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t + -> ('a, 'a) Either.t 
Sequence.t + val compare_direct + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> int + val equal + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> bool + val is_subset + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + + type ('a, 'cmp) named + module Named : sig + val is_subset + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) named + -> of_:('a, 'cmp) named + -> unit Or_error.t + val equal + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) named + -> ('a, 'cmp) named + -> unit Or_error.t + end + + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + val fold_right : ('a, _) t -> init:'accum -> f:('a -> 'accum -> 'accum) -> 'accum + val iter2 + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> f:([ `Left of 'a | `Right of 'a | `Both of 'a * 'a ] -> unit) + -> unit + val filter + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t + val partition_tf + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t * ('a, 'cmp) t + val elements : ('a, _) t -> 'a list + val min_elt : ('a, _) t -> 'a option + val min_elt_exn : ('a, _) t -> 'a + val max_elt : ('a, _) t -> 'a option + val max_elt_exn : ('a, _) t -> 'a + val choose : ('a, _) t -> 'a option + val choose_exn : ('a, _) t -> 'a + val split + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t -> 'a -> ('a, 'cmp) t * 'a option * ('a, 'cmp) t + val group_by + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t -> equiv:('a -> 'a -> bool) -> ('a, 'cmp) t list + val find_exn : ('a, _) t -> f:('a -> bool) -> 'a + val nth : ('a, _) t -> int -> 'a option + + val remove_index + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> int -> ('a, 'cmp) t + val to_tree : ('a, 'cmp) t -> ('a, 'cmp) tree + val to_sequence + : 
comparator:('a, 'cmp) Comparator.t + -> ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> 'a Sequence.t + val merge_to_sequence + : comparator:('a, 'cmp) Comparator.t + -> ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t +end + +(** Consistency checks (same as in [Container]). *) +module Check_accessors (T : T2) (Tree : T2) (Elt : T1) (Named : T2) (Cmp : T1) (Options : T3) + (M : Accessors_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b) t := ('a, 'b) T.t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with type 'a elt := 'a Elt.t + with type 'cmp cmp := 'cmp Cmp.t + with type ('a, 'b) named := ('a, 'b) Named.t) += struct end + +module Check_accessors0 (M : Accessors0) = + Check_accessors + (struct type ('a, 'b) t = M.t end) + (struct type ('a, 'b) t = M.tree end) + (struct type 'a t = M.elt end) + (struct type ('a, 'b) t = M.named end) + (struct type 'a t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_accessors1 (M : Accessors1) = + Check_accessors + (struct type ('a, 'b) t = 'a M.t end) + (struct type ('a, 'b) t = 'a M.tree end) + (struct type 'a t = 'a end) + (struct type ('a, 'b) t = 'a M.named end) + (struct type 'a t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_accessors2 (M : Accessors2) = + Check_accessors + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type ('a, 'b) t = ('a, 'b) M.named end) + (struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_accessors2_with_comparator (M : Accessors2_with_comparator) = + Check_accessors + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type ('a, 
'b) t = ('a, 'b) M.named end) + (struct type 'a t = 'a end) + (With_comparator) + (M) + +module type Creators_generic = sig + type ('a, 'cmp) t + type ('a, 'cmp) set + type ('a, 'cmp) tree + type 'a elt + type ('a, 'cmp, 'z) options + type 'cmp cmp + + val empty : ('a, 'cmp, ('a, 'cmp) t) options + val singleton : ('a, 'cmp, 'a elt -> ('a, 'cmp) t) options + val union_list + : ('a, 'cmp, + ('a, 'cmp) t list -> ('a, 'cmp) t + ) options + val of_list : ('a, 'cmp, 'a elt list -> ('a, 'cmp) t) options + val of_array : ('a, 'cmp, 'a elt array -> ('a, 'cmp) t) options + + val of_sorted_array : ('a, 'cmp, 'a elt array -> ('a, 'cmp) t Or_error.t) options + + val of_sorted_array_unchecked : ('a, 'cmp, 'a elt array -> ('a, 'cmp) t) options + + val of_increasing_iterator_unchecked + : ('a, 'cmp, len:int -> f:(int -> 'a elt) -> ('a, 'cmp) t) options + + val stable_dedup_list : ('a, _, 'a elt list -> 'a elt list) options + + (** The types of [map] and [filter_map] are subtle. The input set, [('a, _) set], + reflects the fact that these functions take a set of *any* type, with any + comparator, while the output set, [('b, 'cmp) t], reflects that the output set has + the particular ['cmp] of the creation function. 
The comparator can come in one of + three ways, depending on which set module is used + + - [Set.map] -- comparator comes as an argument + - [Set.Poly.map] -- comparator is polymorphic comparison + - [Foo.Set.map] -- comparator is [Foo.comparator] *) + val map + : ('b, 'cmp, ('a, _) set -> f:('a -> 'b elt ) -> ('b, 'cmp) t + ) options + val filter_map + : ('b, 'cmp, ('a, _) set -> f:('a -> 'b elt option) -> ('b, 'cmp) t + ) options + + val of_tree + : ('a, 'cmp, + ('a elt, 'cmp) tree -> ('a, 'cmp) t + ) options +end + +module type Creators0 = sig + type ('a, 'cmp) set + type t + type tree + type elt + type comparator_witness + val empty : t + val singleton : elt -> t + val union_list : t list -> t + val of_list : elt list -> t + val of_array : elt array -> t + val of_sorted_array : elt array -> t Or_error.t + val of_sorted_array_unchecked : elt array -> t + val of_increasing_iterator_unchecked : len:int -> f:(int -> elt) -> t + val stable_dedup_list : elt list -> elt list + val map : ('a, _) set -> f:('a -> elt ) -> t + val filter_map : ('a, _) set -> f:('a -> elt option) -> t + val of_tree : tree -> t +end + +module type Creators1 = sig + type ('a, 'cmp) set + type 'a t + type 'a tree + type comparator_witness + val empty : 'a t + val singleton : 'a -> 'a t + val union_list : 'a t list -> 'a t + val of_list : 'a list -> 'a t + val of_array : 'a array -> 'a t + val of_sorted_array : 'a array -> 'a t Or_error.t + val of_sorted_array_unchecked : 'a array -> 'a t + val of_increasing_iterator_unchecked : len:int -> f:(int -> 'a) -> 'a t + val stable_dedup_list : 'a list -> 'a list + val map : ('a, _) set -> f:('a -> 'b ) -> 'b t + val filter_map : ('a, _) set -> f:('a -> 'b option) -> 'b t + val of_tree : 'a tree -> 'a t +end + +module type Creators2 = sig + type ('a, 'cmp) set + type ('a, 'cmp) t + type ('a, 'cmp) tree + val empty : ('a, 'cmp) t + val singleton : 'a -> ('a, 'cmp) t + val union_list : ('a, 'cmp) t list -> ('a, 'cmp) t + val of_list : 'a list -> ('a, 
'cmp) t + val of_array : 'a array -> ('a, 'cmp) t + val of_sorted_array : 'a array -> ('a, 'cmp) t Or_error.t + val of_sorted_array_unchecked : 'a array -> ('a, 'cmp) t + val of_increasing_iterator_unchecked : len:int -> f:(int -> 'a) -> ('a, 'cmp) t + val stable_dedup_list : 'a list -> 'a list + val map : ('a, _) set -> f:('a -> 'b ) -> ('b, 'cmp) t + val filter_map : ('a, _) set -> f:('a -> 'b option) -> ('b, 'cmp) t + val of_tree : ('a, 'cmp) tree -> ('a, 'cmp) t +end + +module type Creators2_with_comparator = sig + type ('a, 'cmp) set + type ('a, 'cmp) t + type ('a, 'cmp) tree + val empty : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t + val singleton : comparator:('a, 'cmp) Comparator.t -> 'a -> ('a, 'cmp) t + val union_list : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t list + -> ('a, 'cmp) t + val of_list : comparator:('a, 'cmp) Comparator.t -> 'a list + -> ('a, 'cmp) t + val of_array : comparator:('a, 'cmp) Comparator.t -> 'a array + -> ('a, 'cmp) t + val of_sorted_array : comparator:('a, 'cmp) Comparator.t -> 'a array + -> ('a, 'cmp) t Or_error.t + val of_sorted_array_unchecked : comparator:('a, 'cmp) Comparator.t -> 'a array + -> ('a, 'cmp) t + val of_increasing_iterator_unchecked + : comparator:('a, 'cmp) Comparator.t -> len:int -> f:(int -> 'a) -> ('a, 'cmp) t + val stable_dedup_list : comparator:('a, 'cmp) Comparator.t -> 'a list -> 'a list + val map : comparator:('b, 'cmp) Comparator.t -> ('a, _) set + -> f:('a -> 'b ) -> ('b, 'cmp) t + val filter_map : comparator:('b, 'cmp) Comparator.t -> ('a, _) set + -> f:('a -> 'b option) -> ('b, 'cmp) t + val of_tree : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) tree -> ('a, 'cmp) t +end + +module Check_creators (T : T2) (Tree : T2) (Elt : T1) (Cmp : T1) (Options : T3) + (M : Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b) t := ('a, 'b) T.t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with type 'a elt := 'a Elt.t + with type 'cmp cmp := 'cmp 
Cmp.t) += struct end + +module Check_creators0 (M : Creators0) = + Check_creators + (struct type ('a, 'b) t = M.t end) + (struct type ('a, 'b) t = M.tree end) + (struct type 'a t = M.elt end) + (struct type 'cmp t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_creators1 (M : Creators1) = + Check_creators + (struct type ('a, 'b) t = 'a M.t end) + (struct type ('a, 'b) t = 'a M.tree end) + (struct type 'a t = 'a end) + (struct type 'cmp t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_creators2 (M : Creators2) = + Check_creators + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type 'cmp t = 'cmp end) + (Without_comparator) + (M) + +module Check_creators2_with_comparator (M : Creators2_with_comparator) = + Check_creators + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type 'cmp t = 'cmp end) + (With_comparator) + (M) + +module type Creators_and_accessors_generic = sig + include Accessors_generic + include Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) options + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree + with type 'a elt := 'a elt + with type 'cmp cmp := 'cmp cmp +end + +module type Creators_and_accessors0 = sig + include Accessors0 + include Creators0 + with type t := t + with type tree := tree + with type elt := elt + with type comparator_witness := comparator_witness +end + +module type Creators_and_accessors1 = sig + include Accessors1 + include Creators1 + with type 'a t := 'a t + with type 'a tree := 'a tree + with type comparator_witness := comparator_witness +end + +module type Creators_and_accessors2 = sig + include Accessors2 + include Creators2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree +end + +module type Creators_and_accessors2_with_comparator 
= sig + include Accessors2_with_comparator + include Creators2_with_comparator + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree +end + +module type S_poly = Creators_and_accessors1 + +module type For_deriving = sig + type ('a, 'b) t + + module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t + end + module type Compare_m = sig end + module type Hash_fold_m = Hasher.S + + val sexp_of_m__t + : (module Sexp_of_m with type t = 'elt) + -> ('elt, 'cmp) t + -> Sexp.t + + val m__t_of_sexp + : (module M_of_sexp with type t = 'elt and type comparator_witness = 'cmp) + -> Sexp.t + -> ('elt, 'cmp) t + + val compare_m__t : (module Compare_m) -> ('elt, 'cmp) t -> ('elt, 'cmp) t -> int + + val hash_fold_m__t + : (module Hash_fold_m with type t = 'elt) + -> (Hash.state -> ('elt, _) t -> Hash.state) + + val hash_m__t + : (module Hash_fold_m with type t = 'elt) + -> (('elt, _) t -> int) +end + +module type Set = sig + (** This module defines the [Set] module for [Base]. Functions that construct a set take + as an argument the comparator for the element type. *) + + (** The type of a set. The first type parameter identifies the type of the element, and + the second identifies the comparator, which determines the comparison function that + is used for ordering elements in this set. Many operations (e.g., {!union}), + require that they be passed sets with the same element type and the same comparator + type. 
*) + type ('elt, 'cmp) t [@@deriving_inline compare] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('elt -> 'elt -> int) -> + ('cmp -> 'cmp -> int) -> ('elt, 'cmp) t -> ('elt, 'cmp) t -> int + end[@@ocaml.doc "@inline"] + [@@@end] + + type ('k, 'cmp) comparator = + (module Comparator.S with type t = 'k and type comparator_witness = 'cmp) + + (** Tests internal invariants of the set data structure. Returns true on success. *) + val invariants : (_, _) t -> bool + + (** Returns a first-class module that can be used to build other map/set/etc + with the same notion of comparison. *) + val comparator_s : ('a, 'cmp) t -> ('a, 'cmp) comparator + (** Returns the comparator of the set as a [Comparator.t] value (see {!comparator_s} for the first-class-module form). *) + val comparator : ('a, 'cmp) t -> ('a, 'cmp) Comparator.t + + (** Creates an empty set based on the provided comparator. *) + val empty : ('a, 'cmp) comparator -> ('a, 'cmp) t + + (** Creates a set based on the provided comparator that contains only the provided + element. *) + val singleton : ('a, 'cmp) comparator -> 'a -> ('a, 'cmp) t + + (** Returns the cardinality of the set. [O(1)]. *) + val length : (_, _) t -> int + + (** [is_empty t] is [true] iff [t] is empty. [O(1)]. *) + val is_empty : (_, _) t -> bool + + (** [mem t a] returns [true] iff [a] is in [t]. [O(log n)]. *) + val mem : ('a, _) t -> 'a -> bool + + (** [add t a] returns a new set with [a] added to [t], or returns [t] if [mem t a]. + [O(log n)]. *) + val add : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + + (** [remove t a] returns a new set with [a] removed from [t] if [mem t a], or returns [t] + otherwise. [O(log n)]. *) + val remove : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + + (** [union t1 t2] returns the union of the two sets. [O(length t1 + length t2)]. *) + val union : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + + (** [union_list c list] returns the union of all the sets in [list]. The + comparator argument [c] is required for the case where [list] is empty. + [O(max(List.length list, n log n))], where [n] is the sum of sizes of the input sets. 
*) + val union_list : ('a, 'cmp) comparator -> ('a, 'cmp) t list -> ('a, 'cmp) t + + (** [inter t1 t2] computes the intersection of sets [t1] and [t2]. [O(length t1 + + length t2)]. *) + val inter : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + + (** [diff t1 t2] computes the set difference [t1 - t2], i.e., the set containing all + elements in [t1] that are not in [t2]. [O(length t1 + length t2)]. *) + val diff : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + + (** [symmetric_diff t1 t2] returns a sequence of changes between [t1] and [t2]. It is + intended to be efficient in the case where [t1] and [t2] share a large amount of + structure. *) + val symmetric_diff + : ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Either.t Sequence.t + + (** [compare_direct t1 t2] compares the sets [t1] and [t2]. It returns the same result + as [compare], but unlike compare, doesn't require arguments to be passed in for the + type parameters of the set. [O(length t1 + length t2)]. *) + val compare_direct : ('a, 'cmp) t -> ('a, 'cmp) t -> int + + (** Hash function: a building block to use when hashing data structures containing sets in + them. [hash_fold_direct hash_fold_key] is compatible with [compare_direct] iff + [hash_fold_key] is compatible with [(comparator s).compare] of the set [s] being + hashed. *) + val hash_fold_direct + : 'a Hash.folder + -> ('a, 'cmp) t Hash.folder + + (** [equal t1 t2] returns [true] iff the two sets have the same elements. [O(length t1 + + length t2)] *) + val equal : ('a, 'cmp) t -> ('a, 'cmp) t -> bool + + (** [exists t ~f] returns [true] iff there exists an [a] in [t] for which [f a]. [O(n)], + but returns as soon as it finds an [a] for which [f a]. *) + val exists : ('a, _) t -> f:('a -> bool) -> bool + + (** [for_all t ~f] returns [true] iff for all [a] in [t], [f a]. [O(n)], but returns as + soon as it finds an [a] for which [not (f a)]. 
*) + val for_all : ('a, _) t -> f:('a -> bool) -> bool + + (** [count t ~f] returns the number of elements of [t] for which [f] returns [true]. + [O(n)]. *) + val count : ('a, _) t -> f:('a -> bool) -> int + + (** [sum (module M) t ~f] returns the sum of [f a] for each [a] in [t]. + [O(n)]. *) + val sum + : (module Container.Summable with type t = 'sum) + -> ('a, _) t -> f:('a -> 'sum) -> 'sum + + (** [find t ~f] returns an element of [t] for which [f] returns true, with no guarantee as + to which element is returned. [O(n)], but returns as soon as a suitable element is + found. *) + val find : ('a, _) t -> f:('a -> bool) -> 'a option + + (** [find_map t ~f] returns [b] for some [a] in [t] for which [f a = Some b]. If no such + [a] exists, then [find_map] returns [None]. [O(n)], but returns as soon as a suitable + element is found. *) + val find_map : ('a, _) t -> f:('a -> 'b option) -> 'b option + + (** Like {!find}, but throws an exception on failure. *) + val find_exn : ('a, _) t -> f:('a -> bool) -> 'a + + (** [nth t i] returns the [i]th smallest element of [t], in [O(log n)] time. The + smallest element has [i = 0]. Returns [None] if [i < 0] or [i >= length t]. *) + val nth : ('a, _) t -> int -> 'a option + + (** [remove_index t i] returns a version of [t] with the [i]th smallest element removed, + in [O(log n)] time. The smallest element has [i = 0]. Returns [t] if [i < 0] or + [i >= length t]. *) + val remove_index : ('a, 'cmp) t -> int -> ('a, 'cmp) t + + (** [is_subset t1 ~of_:t2] returns true iff [t1] is a subset of [t2]. *) + val is_subset : ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + + (** [Named] allows the validation of subset and equality relationships between sets. A + [Named.t] is a record of a set and a name, where the name is used in error messages, + and [Named.is_subset] and [Named.equal] validate subset and equality relationships + respectively. 
+ + The error message for, e.g., + {[ + Named.is_subset { set = set1; name = "set1" } ~of_:{set = set2; name = "set2" } + ]} + + looks like + {v + ("set1 is not a subset of set2" (invalid_elements (...elements of set1 - set2...))) + v} + + so [name] should be a noun phrase that doesn't sound awkward in the above error + message. Even though it adds verbosity, choosing [name]s that start with the phrase + "the set of" often makes the error message sound more natural. + *) + module Named : sig + type nonrec ('a, 'cmp) t = { + set : ('a, 'cmp) t; + name : string; + } + + (** [is_subset t1 ~of_:t2] returns [Ok ()] if [t1] is a subset of [t2] and a + human-readable error otherwise. *) + val is_subset : ('a, 'cmp) t -> of_:('a, 'cmp) t -> unit Or_error.t + + (** [equal t1 t2] returns [Ok ()] if [t1] is equal to [t2] and a human-readable + error otherwise. *) + val equal : ('a, 'cmp) t -> ('a, 'cmp) t -> unit Or_error.t + end + + (** The list or array given to [of_list] and [of_array] need not be sorted. *) + val of_list : ('a, 'cmp) comparator -> 'a list -> ('a, 'cmp) t + val of_array : ('a, 'cmp) comparator -> 'a array -> ('a, 'cmp) t + + (** [to_list] and [to_array] produce sequences sorted in ascending order according to the + comparator. *) + val to_list : ('a, _) t -> 'a list + val to_array : ('a, _) t -> 'a array + + (** Create set from sorted array. The input must be sorted (either in ascending or + descending order as given by the comparator) and contain no duplicates, otherwise the + result is an error. The complexity of this function is [O(n)]. *) + val of_sorted_array + : ('a, 'cmp) comparator + -> 'a array + -> ('a, 'cmp) t Or_error.t + + (** Similar to [of_sorted_array], but without checking the input array. 
*) + val of_sorted_array_unchecked + : ('a, 'cmp) comparator + -> 'a array + -> ('a, 'cmp) t + + (** [of_increasing_iterator_unchecked c ~len ~f] behaves like [of_sorted_array_unchecked c + (Array.init len ~f)], with the additional restriction that a decreasing order is not + supported. The advantage is not requiring you to allocate an intermediate array. [f] + will be called with 0, 1, ... [len - 1], in order. *) + val of_increasing_iterator_unchecked + : ('a, 'cmp) comparator + -> len:int + -> f:(int -> 'a) + -> ('a, 'cmp) t + + (** [stable_dedup_list] is here rather than in the [List] module because the + implementation relies crucially on sets, and because doing so allows one to avoid uses + of polymorphic comparison by instantiating the functor at a different implementation + of [Comparator] and using the resulting [stable_dedup_list]. *) + val stable_dedup_list : ('a, _) comparator -> 'a list -> 'a list + + (** [map c t ~f] returns a new set created by applying [f] to every element in + [t]. The returned set is based on the provided [comparator]. [O(n log n)]. *) + val map + : ('b, 'cmp) comparator + -> ('a, _) t + -> f:('a -> 'b) + -> ('b, 'cmp) t + + (** Like {!map}, except elements for which [f] returns [None] will be dropped. *) + val filter_map + : ('b, 'cmp) comparator + -> ('a, _) t + -> f:('a -> 'b option) + -> ('b, 'cmp) t + + (** [filter t ~f] returns the subset of [t] for which [f] evaluates to true. [O(n log + n)]. *) + val filter : ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t + + (** [fold t ~init ~f] folds over the elements of the set from smallest to largest. 
*) + val fold + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> 'accum) + -> 'accum + + (** [fold_result t ~init ~f] folds over the elements of the set from smallest to + largest, short circuiting the fold if [f accum x] is an [Error _]. *) + val fold_result + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result is computed by [finish]. *) + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Like {!fold}, except that it goes from the largest to the smallest element. *) + val fold_right + : ('a, _) t + -> init:'accum + -> f:('a -> 'accum -> 'accum) + -> 'accum + + (** [iter t ~f] calls [f] on every element of [t], going in order from the smallest to + largest. *) + val iter : ('a, _) t -> f:('a -> unit) -> unit + + (** Iterate two sets side by side. Complexity is [O(m+n)] where [m] and [n] are the sizes + of the two input sets. As an example, with the inputs [0; 1] and [1; 2], [f] will be + called with [`Left 0]; [`Both (1, 1)]; and [`Right 2]. *) + val iter2 + : ('a, 'cmp) t + -> ('a, 'cmp) t + -> f:([`Left of 'a | `Right of 'a | `Both of 'a * 'a] -> unit) + -> unit + + (** If [a, b = partition_tf set ~f] then [a] is the elements on which [f] produced [true], + and [b] is the elements on which [f] produced [false]. *) + val partition_tf + : ('a, 'cmp) t + -> f:('a -> bool) + -> ('a, 'cmp) t * ('a, 'cmp) t + + (** Same as {!to_list}. *) + val elements : ('a, _) t -> 'a list + + (** Returns the smallest element of the set. [O(log n)]. *) + val min_elt : ('a, _) t -> 'a option + + (** Like {!min_elt}, but throws an exception when given an empty set. *) + val min_elt_exn : ('a, _) t -> 'a + + (** Returns the largest element of the set. 
[O(log n)]. *) + val max_elt : ('a, _) t -> 'a option + + (** Like {!max_elt}, but throws an exception when given an empty set. *) + val max_elt_exn : ('a, _) t -> 'a + + (** returns an arbitrary element, or [None] if the set is empty. *) + val choose : ('a, _) t -> 'a option + + (** Like {!choose}, but throws an exception on an empty set. *) + val choose_exn : ('a, _) t -> 'a + + (** [split t x] produces a triple [(t1, maybe_x, t2)] where [t1] is the set of elements + strictly less than [x], [maybe_x] is the member (if any) of [t] which compares equal + to [x], and [t2] is the set of elements strictly larger than [x]. *) + val split : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t * 'a option * ('a, 'cmp) t + + (** if [equiv] is an equivalence predicate, then [group_by set ~equiv] produces a list + of equivalence classes (i.e., a set-theoretic quotient). E.g., + + {[ + let chars = Set.of_list ['A'; 'a'; 'b'; 'c'] in + let equiv c c' = Char.equal (Char.uppercase c) (Char.uppercase c') in + group_by chars ~equiv + ]} + + produces: + + {[ + [Set.of_list ['A';'a']; Set.singleton 'b'; Set.singleton 'c'] + ]} + + [group_by] runs in O(n^2) time, so if you have a comparison function, it's usually + much faster to use [Set.of_list]. *) + val group_by : ('a, 'cmp) t -> equiv:('a -> 'a -> bool) -> ('a, 'cmp) t list + + (** [to_sequence t] converts the set [t] to a sequence of the elements between + [greater_or_equal_to] and [less_or_equal_to] inclusive in the order indicated by + [order]. If [greater_or_equal_to > less_or_equal_to] the sequence is empty. Cost is + O(log n) up front and amortized O(1) for each element produced. *) + val to_sequence + : ?order : [ `Increasing (** default *) | `Decreasing ] + -> ?greater_or_equal_to : 'a + -> ?less_or_equal_to : 'a + -> ('a, 'cmp) t + -> 'a Sequence.t + + (** Produces the elements of the two sets between [greater_or_equal_to] and + [less_or_equal_to] in [order], noting whether each element appears in the left set, + the right set, or both. 
In the [Both] case, both elements are returned, in case the + caller can distinguish between elements that are equal according to the sets' comparator. Runs + in O(length t + length t'). *) + module Merge_to_sequence_element : sig + type ('a, 'b) t = ('a, 'b) Sequence.Merge_with_duplicates_element.t = + | Left of 'a + | Right of 'b + | Both of 'a * 'b + [@@deriving_inline compare, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] + end + + val merge_to_sequence + : ?order : [ `Increasing (** default *) | `Decreasing ] + -> ?greater_or_equal_to : 'a + -> ?less_or_equal_to : 'a + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_set = Set.M(String).t + ]} + + which stands for: + + {[ + type string_set = (String.t, String.comparator_witness) Set.t + ]} + + The point is that [Set.M(String).t] supports deriving, whereas the second syntax + doesn't (because there is no such thing as, say, String.sexp_of_comparator_witness, + instead you would want to pass the comparator directly). *) + module M (Elt : sig type t type comparator_witness end) : sig + type nonrec t = (Elt.t, Elt.comparator_witness) t + end + + include For_deriving with type ('a, 'b) t := ('a, 'b) t + + (** A polymorphic Set. *) + module Poly : S_poly with type 'elt t = ('elt, Comparator.Poly.comparator_witness) t + + (** [Using_comparator] offers a similar interface to the toplevel of [Set], except the functions + take a [~comparator:('elt, 'cmp) Comparator.t] where the functions at the toplevel of + [Set] take a [('elt, 'cmp) comparator]. 
*) + module Using_comparator : sig + type nonrec ('elt, 'cmp) t = ('elt, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('elt -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('elt, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('elt, 'cmp) Comparator.t + -> (Sexp.t -> 'elt) + -> Sexp.t + -> ('elt, 'cmp) t + + module Tree : sig + (** A [Tree.t] contains just the tree data structure that a set is based on, without + including the comparator. Accordingly, any operation on a [Tree.t] must also take + as an argument the corresponding comparator. *) + type ('a, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('elt, 'cmp) Comparator.t + -> (Sexp.t -> 'elt) + -> Sexp.t + -> ('elt, 'cmp) t + + module Named : sig + type nonrec ('a, 'cmp) t = { + tree : ('a, 'cmp) t; + name : string; + } + + val is_subset + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t + -> of_:('a, 'cmp) t + -> unit Or_error.t + + val equal + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> unit Or_error.t + end + + include Creators_and_accessors2_with_comparator + with type ('a, 'b) set := ('a, 'b) t + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) t + with type ('a, 'b) named := ('a, 'b) Named.t + with module Named := Named + + val empty_without_value_restriction : (_, _) t + end + + include Accessors2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with type ('a, 'b) named := ('a, 'b) Named.t + + include Creators2_with_comparator + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with 
type ('a, 'b) set := ('a, 'b) t + + val comparator : ('a, 'cmp) t -> ('a, 'cmp) Comparator.t + + val hash_fold_direct + : 'elt Hash.folder + -> ('elt, 'cmp) t Hash.folder + + module Empty_without_value_restriction (Elt : Comparator.S1) : sig + val empty : ('a Elt.t, Elt.comparator_witness) t + end + end + + (** {2 Modules and module types for extending [Set]} + + For use in extensions of Base, like [Core_kernel]. *) + + module With_comparator = With_comparator + module With_first_class_module = With_first_class_module + module Without_comparator = Without_comparator + + module type For_deriving = For_deriving + + module type S_poly = S_poly + module type Accessors0 = Accessors0 + module type Accessors1 = Accessors1 + module type Accessors2 = Accessors2 + module type Accessors2_with_comparator = Accessors2_with_comparator + module type Accessors_generic = Accessors_generic + module type Creators0 = Creators0 + module type Creators1 = Creators1 + module type Creators2 = Creators2 + module type Creators2_with_comparator = Creators2_with_comparator + module type Creators_and_accessors0 = Creators_and_accessors0 + module type Creators_and_accessors1 = Creators_and_accessors1 + module type Creators_and_accessors2 = Creators_and_accessors2 + module type Creators_and_accessors2_with_comparator = Creators_and_accessors2_with_comparator + module type Creators_generic = Creators_generic + module type Elt_plain = Elt_plain +end diff --git a/src/sexp.ml b/src/sexp.ml new file mode 100644 index 0000000..079acef --- /dev/null +++ b/src/sexp.ml @@ -0,0 +1,42 @@ +open Hash.Builtin +open Ppx_compare_lib.Builtin + +include (Sexplib0.Sexp : module type of Sexplib0.Sexp with type t := Sexplib0.Sexp.t) + +(** Type of S-expressions *) +type t = Sexplib0.Sexp.t = Atom of string | List of t list +[@@deriving_inline compare, hash] +let rec compare : t -> t -> int = + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) 
with + | (Atom _a__003_, Atom _b__004_) -> compare_string _a__003_ _b__004_ + | (Atom _, _) -> (-1) + | (_, Atom _) -> 1 + | (List _a__005_, List _b__006_) -> + compare_list compare _a__005_ _b__006_) +let rec (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Atom _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in hash_fold_string hsv _a0 + | List _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in hash_fold_list hash_fold_t hsv _a0 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) +and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x +[@@@end] + +let of_string = () diff --git a/src/sexp.mli b/src/sexp.mli new file mode 100644 index 0000000..282c13b --- /dev/null +++ b/src/sexp.mli @@ -0,0 +1,20 @@ +(** Type of S-expressions *) +type t = Sexplib0.Sexp.t = Atom of string | List of t list +[@@deriving_inline compare, hash] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +include module type of Sexplib0.Sexp with type t := Sexplib0.Sexp.t + +(** Base has never had an [of_string] function. We expose a deprecated [of_string] here + so that people can find it (e.g. with merlin), and learn what we recommend. This + [of_string] has type [unit] because we don't want it to be accidentally used. 
*) +val of_string : unit +[@@deprecated "[since 2018-02] Use [Parsexp.Single.parse_string_exn]"] diff --git a/src/sexp_with_comparable.ml b/src/sexp_with_comparable.ml new file mode 100644 index 0000000..87005aa --- /dev/null +++ b/src/sexp_with_comparable.ml @@ -0,0 +1,3 @@ +include Sexp + +include Comparable.Make (Sexp) diff --git a/src/sexp_with_comparable.mli b/src/sexp_with_comparable.mli new file mode 100644 index 0000000..a92ccd4 --- /dev/null +++ b/src/sexp_with_comparable.mli @@ -0,0 +1,6 @@ +(*_ This module is separated from Sexp to avoid circular dependencies as many things use + s-expressions *) + +include module type of struct include Sexp end (** @inline *) + +include Comparable.S with type t := t diff --git a/src/sexpable.ml b/src/sexpable.ml new file mode 100644 index 0000000..ad37db9 --- /dev/null +++ b/src/sexpable.ml @@ -0,0 +1,79 @@ +open! Import + +include Sexplib0.Sexpable + +module Of_sexpable + (Sexpable : S) + (M : sig + type t + val to_sexpable : t -> Sexpable.t + val of_sexpable : Sexpable.t -> t + end) + : S with type t := M.t = +struct + let t_of_sexp sexp = + let s = Sexpable.t_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t t = Sexpable.sexp_of_t (M.to_sexpable t) +end + +module Of_sexpable1 + (Sexpable : S1) + (M : sig + type 'a t + val to_sexpable : 'a t -> 'a Sexpable.t + val of_sexpable : 'a Sexpable.t -> 'a t + end) + : S1 with type 'a t := 'a M.t = +struct + let t_of_sexp a_of_sexp sexp = + let s = Sexpable.t_of_sexp a_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t sexp_of_a t = Sexpable.sexp_of_t sexp_of_a (M.to_sexpable t) +end + +module Of_sexpable2 (Sexpable : S2) + (M : sig + type ('a, 'b) t + val to_sexpable : ('a, 'b) t -> ('a, 'b) Sexpable.t + val of_sexpable : ('a, 'b) Sexpable.t -> ('a, 'b) t + end) + : S2 with type ('a, 'b) t := ('a, 'b) M.t = +struct + let t_of_sexp a_of_sexp b_of_sexp sexp = + let s = 
Sexpable.t_of_sexp a_of_sexp b_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t sexp_of_a sexp_of_b t = + Sexpable.sexp_of_t sexp_of_a sexp_of_b (M.to_sexpable t) +end + +module Of_sexpable3 (Sexpable : S3) + (M : sig + type ('a, 'b, 'c) t + val to_sexpable : ('a, 'b, 'c) t -> ('a, 'b, 'c) Sexpable.t + val of_sexpable : ('a, 'b, 'c) Sexpable.t -> ('a, 'b, 'c) t + end) + : S3 with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t = +struct + let t_of_sexp a_of_sexp b_of_sexp c_of_sexp sexp = + let s = Sexpable.t_of_sexp a_of_sexp b_of_sexp c_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t sexp_of_a sexp_of_b sexp_of_c t = + Sexpable.sexp_of_t sexp_of_a sexp_of_b sexp_of_c (M.to_sexpable t) +end + +module Of_stringable (M : Stringable.S) : S with type t := M.t = struct + let t_of_sexp sexp = + match sexp with + | Sexp.Atom s -> + (try M.of_string s with exn -> of_sexp_error_exn exn sexp) + | Sexp.List _ -> + of_sexp_error + "Sexpable.Of_stringable.t_of_sexp expected an atom, but got a list" sexp + + let sexp_of_t t = Sexp.Atom (M.to_string t) +end diff --git a/src/sexpable.mli b/src/sexpable.mli new file mode 100644 index 0000000..9c1ef50 --- /dev/null +++ b/src/sexpable.mli @@ -0,0 +1,78 @@ +(** Provides functors for making modules sexpable. New code should use the [[@@deriving + sexp]] syntax directly. These module types ([S], [S1], [S2], and [S3]) are exported + for backwards compatibility only. *) + +open! 
Import + +module type S = sig + type t + + val t_of_sexp : Sexp.t -> t + val sexp_of_t : t -> Sexp.t +end + +module type S1 = sig + type 'a t + + val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t + val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t +end + +module type S2 = sig + type ('a, 'b) t + + val t_of_sexp : (Sexp.t -> 'a) -> (Sexp.t -> 'b) -> Sexp.t -> ('a, 'b) t + val sexp_of_t : ('a -> Sexp.t) -> ('b -> Sexp.t) -> ('a, 'b) t -> Sexp.t +end + +module type S3 = sig + type ('a, 'b, 'c) t + + val t_of_sexp + : (Sexp.t -> 'a) -> (Sexp.t -> 'b) -> (Sexp.t -> 'c) -> Sexp.t + -> ('a, 'b, 'c) t + + val sexp_of_t + : ('a -> Sexp.t) -> ('b -> Sexp.t) -> ('c -> Sexp.t) -> ('a, 'b, 'c) t + -> Sexp.t +end + +(** For when you want the sexp representation of one type to be the same as that for + some other isomorphic type. *) +module Of_sexpable + (Sexpable : S) + (M : sig + type t + val to_sexpable : t -> Sexpable.t + val of_sexpable : Sexpable.t -> t + end) + : S with type t := M.t + +module Of_sexpable1 + (Sexpable : S1) + (M : sig + type 'a t + val to_sexpable : 'a t -> 'a Sexpable.t + val of_sexpable : 'a Sexpable.t -> 'a t + end) + : S1 with type 'a t := 'a M.t + +module Of_sexpable2 + (Sexpable : S2) + (M : sig + type ('a, 'b) t + val to_sexpable : ('a, 'b) t -> ('a, 'b) Sexpable.t + val of_sexpable : ('a, 'b) Sexpable.t -> ('a, 'b) t + end) + : S2 with type ('a, 'b) t := ('a, 'b) M.t + +module Of_sexpable3 + (Sexpable : S3) + (M : sig + type ('a, 'b, 'c) t + val to_sexpable : ('a, 'b, 'c) t -> ('a, 'b, 'c) Sexpable.t + val of_sexpable : ('a, 'b, 'c) Sexpable.t -> ('a, 'b, 'c) t + end) + : S3 with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t + +module Of_stringable (M : Stringable.S) : S with type t := M.t diff --git a/src/sexplib.ml b/src/sexplib.ml new file mode 100644 index 0000000..148fd14 --- /dev/null +++ b/src/sexplib.ml @@ -0,0 +1,10 @@ +(** This module is for use by ppx_sexp_conv, and is thus not in the interface of + Base. 
*) +module Conv_error = Sexplib0.Sexp_conv_error +module Conv = Sexplib0.Sexp_conv + +(** @canonical Base.Sexp *) +module Sexp = Sexp + +(** @canonical Base.Sexpable *) +module Sexpable = Sexpable diff --git a/src/sign.ml b/src/sign.ml new file mode 100644 index 0000000..5b9848b --- /dev/null +++ b/src/sign.ml @@ -0,0 +1,27 @@ +open! Import + +include Sign0 +include Identifiable.Make(Sign0) + +(* Open [Replace_polymorphic_compare] after including functor applications so + they do not shadow its definitions. This is here so that efficient versions + of the comparison functions are available within this module. *) +open! Replace_polymorphic_compare + +let to_float = function + | Neg -> -1. + | Zero -> 0. + | Pos -> 1. + +let flip = function + | Neg -> Pos + | Zero -> Zero + | Pos -> Neg + +let ( * ) t t' = of_int (to_int t * to_int t') + +(* Include type-specific [Replace_polymorphic_compare] at the end, after any + functor applications that could shadow its definitions. This is here so + that efficient versions of the comparison functions are exported by this + module. *) +include Replace_polymorphic_compare diff --git a/src/sign.mli b/src/sign.mli new file mode 100644 index 0000000..89a9d7d --- /dev/null +++ b/src/sign.mli @@ -0,0 +1,32 @@ +(** A type for representing the sign of a numeric value. *) + +open! Import + +type t = Sign0.t = Neg | Zero | Pos [@@deriving_inline enumerate, hash] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +(** This provides [to_string]/[of_string], sexp conversion, Map, Hashtbl, etc. *) +include Identifiable.S with type t := t + +val of_int : int -> t + +(** Map [Neg/Zero/Pos] to [-1/0/1] respectively. *) +val to_int : t -> int + +(** Map [Neg/Zero/Pos] to [-1./0./1.] respectively. + (There is no [of_float] here, but see {!Float.sign_exn}.) 
*) +val to_float : t -> float + +(** Map [Neg/Zero/Pos] to [Pos/Zero/Neg] respectively. *) +val flip : t -> t + +(** [Neg * Neg = Pos], etc. *) +val ( * ) : t -> t -> t diff --git a/src/sign0.ml b/src/sign0.ml new file mode 100644 index 0000000..a7b8b9c --- /dev/null +++ b/src/sign0.ml @@ -0,0 +1,85 @@ +(* This is broken off to avoid circular dependency between Sign and Comparable. *) + +open! Import + +type t = Neg | Zero | Pos [@@deriving_inline sexp, compare, hash, enumerate] +let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/sign0.ml.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("neg"|"Neg") -> Neg + | Ppx_sexp_conv_lib.Sexp.Atom ("zero"|"Zero") -> Zero + | Ppx_sexp_conv_lib.Sexp.Atom ("pos"|"Pos") -> Pos + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("neg"|"Neg"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("zero"|"Zero"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("pos"|"Pos"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Neg -> Ppx_sexp_conv_lib.Sexp.Atom "Neg" + | Zero -> Ppx_sexp_conv_lib.Sexp.Atom "Zero" + | Pos -> Ppx_sexp_conv_lib.Sexp.Atom "Pos" +let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare +let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Neg -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | Zero -> 
Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Pos -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) +let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x +let all : t list = [Neg; Zero; Pos] +[@@@end] + +module Replace_polymorphic_compare = struct + let ( < ) (x : t) y = Poly.( < ) x y + let ( <= ) (x : t) y = Poly.( <= ) x y + let ( <> ) (x : t) y = Poly.( <> ) x y + let ( = ) (x : t) y = Poly.( = ) x y + let ( > ) (x : t) y = Poly.( > ) x y + let ( >= ) (x : t) y = Poly.( >= ) x y + + let ascending (x : t) y = Poly.ascending x y + let descending (x : t) y = Poly.descending x y + let compare (x : t) y = Poly.compare x y + let equal (x : t) y = Poly.equal x y + let max (x : t) y = if x >= y then x else y + let min (x : t) y = if x <= y then x else y +end + +let of_string s = t_of_sexp (sexp_of_string s) +let to_string t = string_of_sexp (sexp_of_t t) + +let to_int = function + | Neg -> -1 + | Zero -> 0 + | Pos -> 1 + +let _ = hash (* Ignore the hash function produced by [@@deriving_inline hash][@@@end] *) +let hash = to_int + +let module_name = "Base.Sign" + +let of_int n = + if n < 0 + then Neg + else if n = 0 + then Zero + else Pos diff --git a/src/sign_or_nan.ml b/src/sign_or_nan.ml new file mode 100644 index 0000000..33f2268 --- /dev/null +++ b/src/sign_or_nan.ml @@ -0,0 +1,117 @@ +open! 
Import + +module T = struct + type t = Neg | Zero | Pos | Nan [@@deriving_inline sexp, compare, hash, enumerate] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/sign_or_nan.ml.T.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("neg"|"Neg") -> Neg + | Ppx_sexp_conv_lib.Sexp.Atom ("zero"|"Zero") -> Zero + | Ppx_sexp_conv_lib.Sexp.Atom ("pos"|"Pos") -> Pos + | Ppx_sexp_conv_lib.Sexp.Atom ("nan"|"Nan") -> Nan + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("neg"|"Neg"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("zero"|"Zero"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("pos"|"Pos"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("nan"|"Nan"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Neg -> Ppx_sexp_conv_lib.Sexp.Atom "Neg" + | Zero -> Ppx_sexp_conv_lib.Sexp.Atom "Zero" + | Pos -> Ppx_sexp_conv_lib.Sexp.Atom "Pos" + | Nan -> Ppx_sexp_conv_lib.Sexp.Atom "Nan" + let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Neg -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | Zero -> Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Pos -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 + | Nan -> 
Ppx_hash_lib.Std.Hash.fold_int hsv 3 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) + let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x + let all : t list = [Neg; Zero; Pos; Nan] + [@@@end] + + let of_string s = t_of_sexp (sexp_of_string s) + let to_string t = string_of_sexp (sexp_of_t t) + + let module_name = "Base.Sign_or_nan" +end + +module Replace_polymorphic_compare = struct + let ( < ) (x : T.t) y = Poly.( < ) x y + let ( <= ) (x : T.t) y = Poly.( <= ) x y + let ( <> ) (x : T.t) y = Poly.( <> ) x y + let ( = ) (x : T.t) y = Poly.( = ) x y + let ( > ) (x : T.t) y = Poly.( > ) x y + let ( >= ) (x : T.t) y = Poly.( >= ) x y + + let ascending (x : T.t) y = Poly.ascending x y + let descending (x : T.t) y = Poly.descending x y + let compare (x : T.t) y = Poly.compare x y + let equal (x : T.t) y = Poly.equal x y + let max (x : T.t) y = if x >= y then x else y + let min (x : T.t) y = if x <= y then x else y +end + +include T +include Identifiable.Make(T) + +(* Open [Replace_polymorphic_compare] after including functor applications so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! 
Replace_polymorphic_compare + +let of_sign = function + | Sign.Neg -> Neg + | Sign.Zero -> Zero + | Sign.Pos -> Pos + +let to_sign_exn = function + | Neg -> Sign.Neg + | Zero -> Sign.Zero + | Pos -> Sign.Pos + | Nan -> invalid_arg "Base.Sign_or_nan.to_sign_exn: Nan" + +let of_int n = of_sign (Sign.of_int n) + +let to_int_exn t = Sign.to_int (to_sign_exn t) + +let flip = function + | Neg -> Pos + | Zero -> Zero + | Pos -> Neg + | Nan -> Nan + +let ( * ) t t' = + match t, t' with + | Nan, _ | _, Nan -> + Nan + | _ -> + of_sign (Sign.( * ) (to_sign_exn t) (to_sign_exn t')) + +(* Include [Replace_polymorphic_compare] at the end, after any functor applications that + could shadow its definitions. This is here so that efficient versions of the comparison + functions are exported by this module. *) +include Replace_polymorphic_compare diff --git a/src/sign_or_nan.mli b/src/sign_or_nan.mli new file mode 100644 index 0000000..83d1e6f --- /dev/null +++ b/src/sign_or_nan.mli @@ -0,0 +1,34 @@ +(** An extension to [Sign] with a [Nan] constructor, for representing the sign + of float-like numeric values. *) + +open! Import + +type t = Neg | Zero | Pos | Nan [@@deriving_inline enumerate, hash] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +(** This provides [to_string]/[of_string], sexp conversion, Map, Hashtbl, etc. *) +include Identifiable.S with type t := t + +val of_int : int -> t + +(** Map [Neg/Zero/Pos] to [-1/0/1] respectively. [Nan] raises. *) +val to_int_exn : t -> int + +val of_sign : Sign.t -> t + +(** [Nan] raises. *) +val to_sign_exn : t -> Sign.t + +(** Map [Neg/Zero/Pos/Nan] to [Pos/Zero/Neg/Nan] respectively. *) +val flip : t -> t + +(** [Neg * Neg = Pos], etc. If either argument is [Nan] then the result is [Nan]. 
*) +val ( * ) : t -> t -> t diff --git a/src/source_code_position.ml b/src/source_code_position.ml new file mode 100644 index 0000000..9b4edad --- /dev/null +++ b/src/source_code_position.ml @@ -0,0 +1,20 @@ +open! Import + + +(* This is lifted out of [M] because [Source_code_position0] exports [String0] + as [String], which does not export a hash function. *) +let hash_override { Caml.Lexing. pos_fname; pos_lnum; pos_bol; pos_cnum } = + String.hash pos_fname + lxor Int.hash pos_lnum + lxor Int.hash pos_bol + lxor Int.hash pos_cnum +;; + +module M = struct + include Source_code_position0 + + let hash = hash_override +end + +include M +include Comparable.Make_using_comparator(M) diff --git a/src/source_code_position.mli b/src/source_code_position.mli new file mode 100644 index 0000000..8b9d424 --- /dev/null +++ b/src/source_code_position.mli @@ -0,0 +1,32 @@ +(** One typically obtains a [Source_code_position.t] using a [[%here]] expression, which + is implemented by the [ppx_here] preprocessor. *) + +open! Import + +(** See INRIA's OCaml documentation for a description of these fields. + + [sexp_of_t] uses the form ["FILE:LINE:COL"], and does not have a corresponding + [of_sexp]. *) +type t + = Caml.Lexing.position + = { pos_fname : string + ; pos_lnum : int + ; pos_bol : int + ; pos_cnum : int + } +[@@deriving_inline hash, sexp_of] +include +sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Comparable.S with type t := t + +(** [to_string t] converts [t] to the form ["FILE:LINE:COL"]. *) +val to_string : t -> string + diff --git a/src/source_code_position0.ml b/src/source_code_position0.ml new file mode 100644 index 0000000..92f8fd6 --- /dev/null +++ b/src/source_code_position0.ml @@ -0,0 +1,179 @@ +open! 
Import + +module Int = Int0 +module String = String0 + +module T = struct + type t = Caml.Lexing.position = + { pos_fname : string; + pos_lnum : int; + pos_bol : int; + pos_cnum : int; + } + [@@deriving_inline compare, hash, sexp] + let compare : t -> t -> int = + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match compare_string a__001_.pos_fname b__002_.pos_fname with + | 0 -> + (match compare_int a__001_.pos_lnum b__002_.pos_lnum with + | 0 -> + (match compare_int a__001_.pos_bol b__002_.pos_bol with + | 0 -> compare_int a__001_.pos_cnum b__002_.pos_cnum + | n -> n) + | n -> n) + | n -> n) + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + fun hsv -> + fun arg -> + let hsv = + let hsv = + let hsv = let hsv = hsv in hash_fold_string hsv arg.pos_fname in + hash_fold_int hsv arg.pos_lnum in + hash_fold_int hsv arg.pos_bol in + hash_fold_int hsv arg.pos_cnum + let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/source_code_position0.ml.T.t" in + function + | Ppx_sexp_conv_lib.Sexp.List field_sexps as sexp -> + let pos_fname_field = ref None + and pos_lnum_field = ref None + and pos_bol_field = ref None + and pos_cnum_field = ref None + and duplicates = ref [] + and extra = ref [] in + let rec iter = + function + | (Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + field_name)::_field_sexp::[]))::tail -> + ((match field_name with + | "pos_fname" -> + (match !pos_fname_field with + | None -> + let fvalue = string_of_sexp _field_sexp in + pos_fname_field := (Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | "pos_lnum" -> + (match !pos_lnum_field with + | None -> + let fvalue = int_of_sexp _field_sexp in + pos_lnum_field := 
(Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | "pos_bol" -> + (match !pos_bol_field with + | None -> + let fvalue = int_of_sexp _field_sexp in + pos_bol_field := (Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | "pos_cnum" -> + (match !pos_cnum_field with + | None -> + let fvalue = int_of_sexp _field_sexp in + pos_cnum_field := (Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | _ -> + if !Ppx_sexp_conv_lib.Conv.record_check_extra_fields + then extra := (field_name :: (!extra)) + else ()); + iter tail) + | (Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + field_name)::[]))::tail -> + ((let _ = field_name in + if !Ppx_sexp_conv_lib.Conv.record_check_extra_fields + then extra := (field_name :: (!extra)) + else ()); + iter tail) + | (Ppx_sexp_conv_lib.Sexp.Atom _|Ppx_sexp_conv_lib.Sexp.List _ as + sexp)::_ + -> + Ppx_sexp_conv_lib.Conv_error.record_only_pairs_expected _tp_loc + sexp + | [] -> () in + (iter field_sexps; + (match !duplicates with + | _::_ -> + Ppx_sexp_conv_lib.Conv_error.record_duplicate_fields _tp_loc + (!duplicates) sexp + | [] -> + (match !extra with + | _::_ -> + Ppx_sexp_conv_lib.Conv_error.record_extra_fields _tp_loc + (!extra) sexp + | [] -> + (match ((!pos_fname_field), (!pos_lnum_field), + (!pos_bol_field), (!pos_cnum_field)) + with + | (Some pos_fname_value, Some pos_lnum_value, Some + pos_bol_value, Some pos_cnum_value) -> + { + pos_fname = pos_fname_value; + pos_lnum = pos_lnum_value; + pos_bol = pos_bol_value; + pos_cnum = pos_cnum_value + } + | _ -> + Ppx_sexp_conv_lib.Conv_error.record_undefined_elements + _tp_loc sexp + [((Ppx_sexp_conv_lib.Conv.(=) (!pos_fname_field) None), + "pos_fname"); + ((Ppx_sexp_conv_lib.Conv.(=) (!pos_lnum_field) None), + "pos_lnum"); + ((Ppx_sexp_conv_lib.Conv.(=) (!pos_bol_field) None), + "pos_bol"); + ((Ppx_sexp_conv_lib.Conv.(=) (!pos_cnum_field) None), + "pos_cnum")])))) + | Ppx_sexp_conv_lib.Sexp.Atom _ as sexp -> 
+ Ppx_sexp_conv_lib.Conv_error.record_list_instead_atom _tp_loc sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | { pos_fname = v_pos_fname; pos_lnum = v_pos_lnum; pos_bol = v_pos_bol; + pos_cnum = v_pos_cnum } -> + let bnds = [] in + let bnds = + let arg = sexp_of_int v_pos_cnum in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_cnum"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_pos_bol in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_bol"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_pos_lnum in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_lnum"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_string v_pos_fname in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_fname"; arg]) + :: bnds in + Ppx_sexp_conv_lib.Sexp.List bnds + [@@@end] +end + +include T +include Comparator.Make(T) + +(* This is the same function as Ppx_here.lift_position_as_string. *) +let make_location_string ~pos_fname ~pos_lnum ~pos_cnum ~pos_bol = + String.concat + [ pos_fname + ; ":"; Int.to_string pos_lnum + ; ":"; Int.to_string (pos_cnum - pos_bol) + ] +(* Renders a lexing position as file name, line number and column, separated by colons; the column is [pos_cnum - pos_bol]. *) +let to_string {Caml.Lexing.pos_fname; pos_lnum; pos_cnum; pos_bol} = + make_location_string ~pos_fname ~pos_lnum ~pos_cnum ~pos_bol +(* Shadows the derived [sexp_of_t] included from [T]: a position is rendered as a single atom holding [to_string t]. *) +let sexp_of_t t = Sexp.Atom (to_string t) diff --git a/src/stack.ml b/src/stack.ml new file mode 100644 index 0000000..721df1a --- /dev/null +++ b/src/stack.ml @@ -0,0 +1,209 @@ +open! Import + +include Stack_intf + +let raise_s = Error.raise_s + +(* This implementation is similar to [Deque] in that it uses an array of ['a] and + a mutable [int] to indicate what in the array is used. We choose to implement [Stack] + directly rather than on top of [Deque] for performance reasons. E.g. a simple + microbenchmark shows that push/pop is about 20% faster.
*) +type 'a t = + { mutable length : int; + mutable elts : 'a Option_array.t; + } +[@@deriving_inline sexp_of] +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | { length = v_length; elts = v_elts } -> + let bnds = [] in + let bnds = + let arg = Option_array.sexp_of_t _of_a v_elts in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "elts"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_length in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "length"; arg]) + :: bnds in + Ppx_sexp_conv_lib.Sexp.List bnds +[@@@end] +(* The derived record printer is kept under another name for the error message in [invariant]; [sexp_of_t] is masked with a placeholder until the public, list-based version is defined further down. *) +let sexp_of_t_internal = sexp_of_t +let sexp_of_t = `Rebound_later +let _ = sexp_of_t +(* Length of the backing array, which may exceed [length t]. *) +let capacity t = Option_array.length t.elts +(* Checks that [length] fits in the backing array, that slots [0 .. length - 1] are set and satisfy [invariant_a], and that every slot from [length] onwards is unset. Any failure is re-raised through [raise_s] together with the exception and the stack. *) +let invariant invariant_a ({ length; elts } as t) : unit = + try + assert (0 <= length && length <= Option_array.length elts); + for i = 0 to length - 1 do + invariant_a (Option_array.get_some_exn elts i); + done; + (* We maintain the invariant that unused elements are unset to avoid a space + leak.
*) + for i = length to Option_array.length elts - 1 do + assert (not (Option_array.is_some elts i)) + done; + with exn -> + raise_s (Sexp.message "Stack.invariant failed" + [ "exn", exn |> Exn.sexp_of_t + ; "stack", t |> sexp_of_t_internal sexp_of_opaque ]) +;; +(* A new stack has [length = 0] and [Option_array.empty] as its backing array. *) +let create (type a) () : a t = + { length = 0; + elts = Option_array.empty; + } +;; + +let length t = t.length + +let is_empty t = length t = 0 + +(* The order in which elements are visited has been chosen so as to be backwards + compatible with both [Linked_stack] and [Caml.Stack] *) +let fold t ~init ~f = + let r = ref init in + for i = t.length - 1 downto 0 do + r := f !r (Option_array.get_some_exn t.elts i) + done; + !r +;; +(* Visits indices [length - 1] down to [0], the same order as [fold]. *) +let iter t ~f = + for i = t.length - 1 downto 0 do + f (Option_array.get_some_exn t.elts i) + done; +;; +(* The remaining container operations are derived from [fold], with the hand-written [iter] and [length] supplied as custom implementations. *) +module C = + Container.Make (struct + type nonrec 'a t = 'a t + let fold = fold + let iter = `Custom iter + let length = `Custom length + end) + +let mem = C.mem +let exists = C.exists +let for_all = C.for_all +let count = C.count +let sum = C.sum +let find = C.find +let find_map = C.find_map +let to_list = C.to_list +let to_array = C.to_array +let min_elt = C.min_elt +let max_elt = C.max_elt +let fold_result = C.fold_result +let fold_until = C.fold_until +(* The head of [l] becomes the top of the stack: it is stored at index [length - 1] and the last element of [l] at index [0]. A non-empty list gets a backing array of twice its length. *) +let of_list (type a) (l : a list) = + if List.is_empty l then + create () + else begin + let length = List.length l in + let elts = Option_array.create ~len:(2 * length) in + let r = ref l in + for i = length - 1 downto 0 do + match !r with + | [] -> assert false + | a :: l -> + Option_array.set_some elts i a; + r := l + done; + { length; elts } + end +;; +(* Public sexp conversion: a stack is converted through the list of its elements, using [to_list] and [of_list]. *) +let sexp_of_t sexp_of_a t = List.sexp_of_t sexp_of_a (to_list t) + +let t_of_sexp a_of_sexp sexp = of_list (List.t_of_sexp a_of_sexp sexp) +(* Replaces the backing array with a fresh one of [size] slots and copies the first [length] elements into it. *) +let resize t size = + let arr = Option_array.create ~len:size in + Option_array.blit ~src:t.elts ~dst:arr ~src_pos:0 ~dst_pos:0 ~len:t.length; + t.elts <- arr +;; +(* The requested capacity is first raised to at least [length t]; nothing happens if the result equals the current capacity. *) +let set_capacity t new_capacity = + let new_capacity = max
new_capacity (length t) in + if new_capacity <> capacity t then + resize t new_capacity +;; +(* When the backing array is full it is first replaced by one of [2 * (length + 1)] slots. *) +let push t a = + if t.length = Option_array.length t.elts then + resize t (2 * (t.length + 1)); + Option_array.set_some t.elts t.length a; + t.length <- t.length + 1; +;; +(* The caller must ensure that [t] is non-empty. The vacated slot is unset again, in line with the invariant that unused slots hold nothing. *) +let pop_nonempty t = + let i = t.length - 1 in + let result = Option_array.get_some_exn t.elts i in + Option_array.set_none t.elts i; + t.length <- i; + result +;; +(* Built once at module initialization and shared by every failing [pop_exn]. *) +let pop_error = Error.of_string "Stack.pop of empty stack" + +let pop t = + if is_empty t + then None + else Some (pop_nonempty t) +;; + +let pop_exn t = + if is_empty t + then Error.raise pop_error + else pop_nonempty t +;; +(* The caller must ensure that [t] is non-empty. *) +let top_nonempty t = Option_array.get_some_exn t.elts (t.length - 1) + +let top_error = Error.of_string "Stack.top of empty stack" + +let top t = + if is_empty t + then None + else Some (top_nonempty t) +;; + +let top_exn t = + if is_empty t + then Error.raise top_error + else top_nonempty t; +;; +(* The new stack gets its own backing array via [Option_array.copy]. *) +let copy { length; elts } = + { length; + elts = Option_array.copy elts; + } +;; +(* Unsets every used slot and resets [length] to 0; the backing array itself is kept. *) +let clear t = + if t.length > 0 then begin + for i = 0 to t.length - 1 do + Option_array.set_none t.elts i; + done; + t.length <- 0; + end; +;; +(* [t.length] is re-read after every call to [f], so elements pushed by [f] are processed as well. *) +let until_empty t f = + let rec loop () = if t.length > 0 then (f (pop_nonempty t); loop ()) in + loop () +;; + +let singleton x = + let t = create () in + push t x; + t +;; diff --git a/src/stack.mli b/src/stack.mli new file mode 100644 index 0000000..937ae3c --- /dev/null +++ b/src/stack.mli @@ -0,0 +1 @@ +include Stack_intf.Stack (** @inline *) diff --git a/src/stack_intf.ml b/src/stack_intf.ml new file mode 100644 index 0000000..f2c3ebd --- /dev/null +++ b/src/stack_intf.ml @@ -0,0 +1,79 @@ +(** An interface for stacks that follows [Core]'s conventions, as opposed to OCaml's + standard [Stack] module. *) + +open!
Import + +module type S = sig + + type 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Invariant.S1 with type 'a t := 'a t + + (** [fold], [iter], [find], and [find_map] visit the elements in order from the top of + the stack to the bottom. [to_list] and [to_array] return the elements in order from + the top of the stack to the bottom. + + Iteration functions ([iter], [fold], etc.) have unspecified behavior (although they + should still be memory-safe) when the stack is mutated while they are running (e.g. + by having the passed-in function call [push] or [pop] on the stack). + *) + include Container.S1 with type 'a t := 'a t + + (** [of_list l] returns a stack whose top is the first element of [l] and bottom is the + last element of [l]. *) + val of_list : 'a list -> 'a t + + (** [create ()] returns an empty stack. *) + val create : unit -> _ t + + (** [singleton a] creates a new stack containing only [a]. *) + val singleton : 'a -> 'a t + + (** [push t a] adds [a] to the top of stack [t]. *) + val push : 'a t -> 'a -> unit + + (** [pop t] removes and returns the top element of [t] as [Some a], or returns [None] if + [t] is empty. [pop_exn t] does the same but raises instead of returning [None]. *) + val pop : 'a t -> 'a option + val pop_exn : 'a t -> 'a + + (** [top t] returns [Some a], where [a] is the top of [t], unless [is_empty t], in which + case [top] returns [None]. [top_exn t] returns [a] directly and raises if [is_empty t]. *) + val top : 'a t -> 'a option + val top_exn : 'a t -> 'a + + (** [clear t] discards all elements from [t]. *) + val clear : _ t -> unit + + (** [copy t] returns a copy of [t]. *) + val copy : 'a t -> 'a t + + (** [until_empty t f] repeatedly pops an element [a] off of [t] and runs [f a], until + [t] becomes empty. It is fine if [f] adds more elements to [t], in which case the + most-recently-added element will be processed next.
*) + val until_empty : 'a t -> ('a -> unit) -> unit +end + +(** A stack implemented with an array. + + The implementation will grow the array as necessary, and will never automatically + shrink the array. One can use [set_capacity] to explicitly resize the array. *) +module type Stack = sig + module type S = S + + include S (** @open *) + + (** [capacity t] returns the length of the array backing [t]. *) + val capacity : _ t -> int + + (** [set_capacity t capacity] sets the length of the array backing [t] to [max capacity + (length t)]. To shrink as much as possible, do [set_capacity t 0]. *) + val set_capacity : _ t -> int -> unit +end + diff --git a/src/staged.ml b/src/staged.ml new file mode 100644 index 0000000..5ce96e3 --- /dev/null +++ b/src/staged.ml @@ -0,0 +1,6 @@ +open! Import +(* A staged value is represented by the value itself, so [stage] and [unstage] are both the identity; the distinction is made only by the abstract type in the interface. *) +type 'a t = 'a + +let stage = Fn.id +let unstage = Fn.id diff --git a/src/staged.mli b/src/staged.mli new file mode 100644 index 0000000..902e290 --- /dev/null +++ b/src/staged.mli @@ -0,0 +1,43 @@ +(** A type for making staging explicit in the type of a function. + + For example, you might want to have a function that creates a function for allocating + unique identifiers. Rather than using the type: + + {[ + val make_id_allocator : unit -> unit -> int + ]} + + you would have + + {[ + val make_id_allocator : unit -> (unit -> int) Staged.t + ]} + + Such a function could be defined as follows: + + {[ + let make_id_allocator () = + let ctr = ref 0 in + stage (fun () -> incr ctr; !ctr) + ]} + + and could be invoked as follows: + + {[ + let (id1,id2) = + let alloc = unstage (make_id_allocator ()) in + (alloc (), alloc ()) + ]} + + Both the {!stage} and {!unstage} functions are available in the toplevel namespace. + + (Note that in many cases, including perhaps the one above, it's preferable to create a + custom type rather than use [Staged].) *) + +open!
Import + +type +'a t + +val stage : 'a -> 'a t +val unstage : 'a t -> 'a + diff --git a/src/string.ml b/src/string.ml new file mode 100644 index 0000000..4a2467f --- /dev/null +++ b/src/string.ml @@ -0,0 +1,1267 @@ +open! Import + +module Array = Array0 +module Bytes = Bytes0 + +include String0 + +let invalid_argf = Printf.invalid_argf + +let raise_s = Error.raise_s + +let stage = Staged.stage + +module T = struct + type t = string [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_string + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_string in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = string_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_string + [@@@end] + let compare = compare +end + +include T +include Comparator.Make(T) + +type elt = char +(* [is_substring_at_gen str ~pos ~substring ~char_equal] tests whether [substring] occurs in [str] starting exactly at [pos], comparing characters with [char_equal]. A [pos] that is negative or greater than [length str] is rejected through [invalid_argf]; [pos = length str] is accepted and can only match the empty substring. The inner loop uses [unsafe_get], which is safe because the length test runs before it. *) +let is_substring_at_gen = + let rec loop ~str ~str_pos ~sub ~sub_pos ~sub_len ~char_equal = + if sub_pos = sub_len + then true + else if char_equal (unsafe_get str str_pos) (unsafe_get sub sub_pos) + then loop ~str ~str_pos:(str_pos + 1) ~sub ~sub_pos:(sub_pos + 1) ~sub_len ~char_equal + else false + in + fun str ~pos:str_pos ~substring:sub ~char_equal -> + let str_len = length str in + let sub_len = length sub in + if str_pos < 0 || str_pos > str_len + then begin + invalid_argf "String.is_substring_at: invalid index %d for string of length %d" + str_pos str_len () + end; + str_pos + sub_len <= str_len + && loop ~str ~str_pos ~sub ~sub_pos:0 ~sub_len ~char_equal +(* The length test comes first, so the [pos] handed to [is_substring_at_gen] is never negative. *) +let is_suffix_gen string ~suffix ~char_equal = + let string_len = length string in + let suffix_len = length suffix in + string_len >= suffix_len + && is_substring_at_gen string + ~pos:(string_len - suffix_len) + ~substring:suffix + ~char_equal +;; + +let is_prefix_gen string ~prefix ~char_equal = + let string_len = length string in + let prefix_len = length prefix in + string_len >= prefix_len + &&
is_substring_at_gen string ~pos:0 ~substring:prefix ~char_equal +;; +(* Comparison, hashing and prefix/suffix tests that ignore case by passing every character through [Char.lowercase]. *) +module Caseless = struct + module T = struct + type t = string [@@deriving_inline sexp] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = string_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_string + [@@@end] + + let char_compare_caseless c1 c2 = Char.compare (Char.lowercase c1) (Char.lowercase c2) + let char_equal_caseless c1 c2 = Char.equal (Char.lowercase c1) (Char.lowercase c2) + (* Lexicographic comparison on lowercased characters; when one string is a prefix of the other, the shorter one sorts first. *) + let rec compare_loop ~pos ~string1 ~len1 ~string2 ~len2 = + if pos = len1 + then if pos = len2 + then 0 + else -1 + else if pos = len2 + then 1 + else begin + let c = char_compare_caseless (unsafe_get string1 pos) (unsafe_get string2 pos) in + match c with + | 0 -> compare_loop ~pos:(pos + 1) ~string1 ~len1 ~string2 ~len2 + | _ -> c + end + (* Physically equal strings compare equal without being scanned. *) + let compare string1 string2 = + if phys_equal string1 string2 + then 0 + else begin + compare_loop ~pos:0 + ~string1 ~len1:(String.length string1) + ~string2 ~len2:(String.length string2) + end + (* Folds in the length and then every lowercased character, so strings that are equal under [compare] produce the same hash state. *) + let hash_fold_t state t = + let len = length t in + let state = ref (hash_fold_int state len) in + for pos = 0 to len - 1 do + state := hash_fold_char !state (Char.lowercase (unsafe_get t pos)) + done; + !state + + let hash t = Hash.run hash_fold_t t + + let is_suffix s ~suffix = is_suffix_gen s ~suffix ~char_equal:char_equal_caseless + let is_prefix s ~prefix = is_prefix_gen s ~prefix ~char_equal:char_equal_caseless + end + + include T + include Comparable.Make(T) +end + +(* This is copied/adapted from 'blit.ml'. + [sub], [subo] could be implemented using [Blit.Make(Bytes)] plus unsafe casts to/from + string but were inlined here to avoid using [Bytes.unsafe_of_string] as much as possible. + Also note that [blit] and [blito] will be deprecated and removed in the future.
+*) +let sub src ~pos ~len = + if pos = 0 && len = String.length src + then src (* the whole string was requested: return it without copying *) + else begin + Ordered_collection_common.check_pos_len_exn ~pos ~len ~total_length:(length src); + let dst = Bytes.create len in + if len > 0 then Bytes.unsafe_blit_string ~src ~src_pos:pos ~dst ~dst_pos:0 ~len; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + end +(* [pos] defaults to 0 and [len] to everything from [pos] to the end of [src]. *) +let subo ?(pos = 0) ?len src = + sub src ~pos ~len:(match len with Some i -> i | None -> length src - pos) + +let blit = Bytes.blit_string +let blito ~src ?(src_pos = 0) ?(src_len = length src - src_pos) ~dst ?(dst_pos = 0) () = + blit ~src ~src_pos ~len:src_len ~dst ~dst_pos +(* Scans [t] from [pos] up to, but not including, [end_]. Uses [unsafe_get], so the caller is responsible for the bounds. *) +let rec contains_unsafe t ~pos ~end_ char = + pos < end_ + && (Char.equal (unsafe_get t pos) char + || contains_unsafe t ~pos:(pos + 1) ~end_ char) +(* Validates [pos] and [len] (the latter defaulting to the rest of the string) before scanning. *) +let contains ?(pos = 0) ?len t char = + let total_length = String.length t in + let len = Option.value len ~default:(total_length - pos) in + Ordered_collection_common.check_pos_len_exn ~pos ~len ~total_length; + contains_unsafe t ~pos ~end_:(pos + len) char +;; + +let is_empty t = length t = 0 +(* Option-returning wrappers around the [_exn] lookups: both [Not_found_s] and [Caml.Not_found] are turned into [None]. *) +let index t char = + try Some (index_exn t char) + with Not_found_s _ | Caml.Not_found -> None + +let rindex t char = + try Some (rindex_exn t char) + with Not_found_s _ | Caml.Not_found -> None + +let index_from t pos char = + try Some (index_from_exn t pos char) + with Not_found_s _ | Caml.Not_found -> None + +let rindex_from t pos char = + try Some (rindex_from_exn t pos char) + with Not_found_s _ | Caml.Not_found -> None +(* Substring search based on the KMP algorithm. A compiled pattern is the pattern string paired with its pre-computed table. *) +module Search_pattern = struct + + type t = string * int array [@@deriving_inline sexp_of] + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | (v0, v1) -> + let v0 = sexp_of_string v0 + and v1 = sexp_of_array sexp_of_int v1 in + Ppx_sexp_conv_lib.Sexp.List [v0; v1] + [@@@end] + + (* Find max number of matched characters at [next_text_char], given the current + [matched_chars].
Try to extend the current match; if chars don't match, try to match + fewer chars. If chars match then extend the match. The caller must keep [matched_chars] below [length pattern], because [pattern] is read with [unsafe_get]. *) + let kmp_internal_loop ~matched_chars ~next_text_char ~pattern ~kmp_arr = + let matched_chars = ref matched_chars in + while !matched_chars > 0 + && Char.( <> ) next_text_char (unsafe_get pattern !matched_chars) do + matched_chars := Array.unsafe_get kmp_arr (!matched_chars - 1) + done; + if Char.equal next_text_char (unsafe_get pattern !matched_chars) then + matched_chars := !matched_chars + 1; + !matched_chars + ;; + + (* Classic KMP pre-processing of the pattern: build the int array, which, for each i, + contains the length of the longest non-trivial prefix of s which is equal to a suffix + ending at s.[i] *) + let create pattern = + let n = length pattern in + let kmp_arr = Array.create ~len:n (-1) in + if n > 0 then begin + Array.unsafe_set kmp_arr 0 0; + let matched_chars = ref 0 in + for i = 1 to n - 1 do + matched_chars := + kmp_internal_loop + ~matched_chars:!matched_chars + ~next_text_char:(unsafe_get pattern i) + ~pattern + ~kmp_arr; + Array.unsafe_set kmp_arr i !matched_chars + done + end; + (pattern, kmp_arr) + ;; + + (* Classic KMP: use the pre-processed pattern to optimize look-behinds on non-matches. + We return int to avoid allocation in [index_exn]. -1 means no match.
*) + let index_internal ?(pos=0) (pattern, kmp_arr) ~in_:text = + if pos < 0 || pos > length text - length pattern then + -1 (* [pos] is out of range or leaves too little text for the pattern: report no match instead of raising *) + else begin + let j = ref pos in + let matched_chars = ref 0 in + let k = length pattern in + let n = length text in + while !j < n && !matched_chars < k do + let next_text_char = unsafe_get text !j in + matched_chars := + kmp_internal_loop + ~matched_chars:!matched_chars + ~next_text_char + ~pattern + ~kmp_arr; + j := !j + 1 + done; + if !matched_chars = k then + !j - k + else + -1 + end + ;; + (* [matches t str] holds iff the pattern occurs somewhere in [str]. *) + let matches t str = index_internal t ~in_:str >= 0 + + let index ?pos t ~in_ = + let p = index_internal ?pos t ~in_ in + if p < 0 then + None + else + Some p + ;; + + let index_exn ?pos t ~in_ = + let p = index_internal ?pos t ~in_ in + if p >= 0 then + p + else + raise_s (Sexp.message "Substring not found" + ["substring", sexp_of_string (fst t)]) + ;; + (* Returns the start positions of all matches in increasing order. An empty pattern matches at every position from 0 to [length text] inclusive. With [~may_overlap:false] matching restarts from scratch after each match; with [~may_overlap:true] it continues from the table entry for the full pattern. *) + let index_all (pattern, kmp_arr) ~may_overlap ~in_:text = + if length pattern = 0 then + List.init (1 + length text) ~f:Fn.id + else begin + let matched_chars = ref 0 in + let k = length pattern in + let n = length text in + let found = ref [] in + for j = 0 to n do + if !matched_chars = k then begin + found := (j - k)::!found; + (* we just found a match in the previous iteration *) + match may_overlap with + | true -> matched_chars := Array.unsafe_get kmp_arr (k - 1) + | false -> matched_chars := 0 + end; + if j < n then begin + let next_text_char = unsafe_get text j in + matched_chars := + kmp_internal_loop + ~matched_chars:!matched_chars + ~next_text_char + ~pattern + ~kmp_arr + end + done; + List.rev !found + end + ;; + (* Replaces the first match found by [index ?pos]; returns [s] itself, without copying, when there is none. *) + let replace_first ?pos t ~in_:s ~with_ = + match index ?pos t ~in_:s with + | None -> s + | Some i -> + let len_s = length s in + let len_t = length (fst t) in + let len_with = length with_ in + let dst = Bytes.create (len_s + len_with - len_t) in + blit ~src:s ~src_pos:0 ~dst ~dst_pos:0 ~len:i; + blit ~src:with_ ~src_pos:0 ~dst ~dst_pos:i ~len:len_with; + blit ~src:s
~src_pos:(i + len_t) ~dst ~dst_pos:(i + len_with) ~len:(len_s - i - len_t); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + ;; + + (* Replaces every non-overlapping match, working left to right; returns [s] itself when nothing matches. The result buffer is sized up front from the number of matches. *) + let replace_all t ~in_:s ~with_ = + let matches = index_all t ~may_overlap:false ~in_:s in + match matches with + | [] -> s + | _::_ -> + let len_s = length s in + let len_t = length (fst t) in + let len_with = length with_ in + let num_matches = List.length matches in + let dst = Bytes.create (len_s + (len_with - len_t) * num_matches) in + let next_dst_pos = ref 0 in + let next_src_pos = ref 0 in + List.iter matches ~f:(fun i -> + let len = i - !next_src_pos in + blit ~src:s ~src_pos:!next_src_pos ~dst ~dst_pos:!next_dst_pos ~len; + blit ~src:with_ ~src_pos:0 ~dst ~dst_pos:(!next_dst_pos + len) ~len:len_with; + next_dst_pos := !next_dst_pos + len + len_with; + next_src_pos := !next_src_pos + len + len_t; + ); + blit ~src:s ~src_pos:!next_src_pos ~dst ~dst_pos:!next_dst_pos + ~len:(len_s - !next_src_pos); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + ;; +end +(* The [substr_*] functions below call [Search_pattern.create] on every invocation; use [Search_pattern] directly to compile a pattern once and reuse it. *) +let substr_index ?pos t ~pattern = + Search_pattern.index ?pos (Search_pattern.create pattern) ~in_:t +;; + +let substr_index_exn ?pos t ~pattern = + Search_pattern.index_exn ?pos (Search_pattern.create pattern) ~in_:t +;; + +let substr_index_all t ~may_overlap ~pattern = + Search_pattern.index_all (Search_pattern.create pattern) ~may_overlap ~in_:t +;; + +let substr_replace_first ?pos t ~pattern = + Search_pattern.replace_first ?pos (Search_pattern.create pattern) ~in_:t +;; + +let substr_replace_all t ~pattern = + Search_pattern.replace_all (Search_pattern.create pattern) ~in_:t +;; + +let is_substring t ~substring = + Option.is_some (substr_index t ~pattern:substring) +;; + +let of_string = Fn.id +let to_string = Fn.id +(* [init n ~f] builds the string of length [n] whose character at index [i] is [f i]; a negative [n] is rejected through [invalid_argf]. *) +let init n ~f = + if n < 0 then invalid_argf "String.init %d" n (); + let t = Bytes.create n in + for i = 0 to n - 1 do + Bytes.set t i (f i); + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t +;;
+(* Built from the last index down to the first, so the list comes out in string order without a reversal. *) +let to_list s = + let rec loop acc i = + if i < 0 then + acc + else + loop (s.[i] :: acc) (i-1) + in + loop [] (length s - 1) +(* The characters of [s], last character first. *) +let to_list_rev s = + let len = length s in + let rec loop acc i = + if i = len then + acc + else + loop (s.[i] :: acc) (i+1) + in + loop [] 0 + +let rev t = + let len = length t in + let res = Bytes.create len in + for i = 0 to len - 1 do + unsafe_set res i (unsafe_get t (len - 1 - i)) + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:res +;; + +(** Efficient string splitting *) + +let lsplit2_exn line ~on:delim = (* splits at the position returned by [index_exn]; the delimiter itself ends up in neither half *) + let pos = index_exn line delim in + (sub line ~pos:0 ~len:pos, + sub line ~pos:(pos+1) ~len:(length line - pos - 1) + ) +(* Same as [lsplit2_exn], but splits at the position returned by [rindex_exn]. *) +let rsplit2_exn line ~on:delim = + let pos = rindex_exn line delim in + (sub line ~pos:0 ~len:pos, + sub line ~pos:(pos+1) ~len:(length line - pos - 1) + ) + +let lsplit2 line ~on = + try Some (lsplit2_exn line ~on) with Not_found_s _ | Caml.Not_found -> None + +let rsplit2 line ~on = + try Some (rsplit2_exn line ~on) with Not_found_s _ | Caml.Not_found -> None + +let rec char_list_mem l (c:char) = + match l with + | [] -> false + | hd::tl -> Char.equal hd c || char_list_mem tl c +(* Scans from right to left so that the pieces can be consed up directly in order. [last_pos] is the index of the delimiter seen most recently (initially [len]); each piece is the text strictly between two delimiters, so a string with k delimiters always yields k + 1 pieces, some of which may be empty. *) +let split_gen str ~on = + let is_delim = + match on with + | `char c' -> (fun c -> Char.equal c c') + | `char_list l -> (fun c -> char_list_mem l c) + in + let len = length str in + let rec loop acc last_pos pos = + if pos = -1 then + sub str ~pos:0 ~len:last_pos :: acc + else + if is_delim str.[pos] then + let pos1 = pos + 1 in + let sub_str = sub str ~pos:pos1 ~len:(last_pos - pos1) in + loop (sub_str :: acc) pos (pos - 1) + else loop acc last_pos (pos - 1) + in + loop [] len (len - 1) +;; + +let split str ~on = split_gen str ~on:(`char on) ;; + +let split_on_chars str ~on:chars = + split_gen str ~on:(`char_list chars) +;; +(* Splits [t] at every line feed, also dropping a carriage return that directly precedes it. The string is walked backwards: [eol] is the exclusive end of the line being collected and [pos] the scanning cursor. The empty string yields the empty list. *) +let split_lines = + let back_up_at_newline ~t ~pos ~eol = + pos := !pos - (if !pos > 0 && Char.equal t.[!pos - 1] '\r' then 2 else 1); + eol := !pos + 1; + in + fun t -> + let n =
length t in + if n = 0 + then [] + else + (* Invariant: [-1 <= pos < eol]. *) + let pos = ref (n - 1) in + let eol = ref n in + let ac = ref [] in + (* We treat the end of the string specially, because if the string ends with a + newline, we don't want an extra empty string at the end of the output. *) + if Char.equal t.[!pos] '\n' then back_up_at_newline ~t ~pos ~eol; + while !pos >= 0 do + if Char.( <> ) t.[!pos] '\n' + then decr pos + else + (* Because [pos < eol], we know that [start <= eol]. *) + let start = !pos + 1 in + ac := sub t ~pos:start ~len:(!eol - start) :: !ac; + back_up_at_newline ~t ~pos ~eol + done; + sub t ~pos:0 ~len:!eol :: !ac +;; + +let is_suffix s ~suffix = is_suffix_gen s ~suffix ~char_equal:Char.equal +let is_prefix s ~prefix = is_prefix_gen s ~prefix ~char_equal:Char.equal + +let is_substring_at s ~pos ~substring = + is_substring_at_gen s ~pos ~substring ~char_equal:Char.equal +(* Shared body of [drop_prefix], [drop_suffix], [prefix] and [suffix]. A negative [n] is an error; if [sub] raises (as it does when [n] exceeds the length of [t]) the result is [on_error] instead. NOTE(review): the handler is a catch-all, so it also swallows exceptions that have nothing to do with bounds. *) +let wrap_sub_n t n ~name ~pos ~len ~on_error = + if n < 0 then + invalid_arg (name ^ " expecting nonnegative argument") + else + try + sub t ~pos ~len + with _ -> + on_error + +let drop_prefix t n = wrap_sub_n ~name:"drop_prefix" t n ~pos:n ~len:(length t - n) ~on_error:"" +let drop_suffix t n = wrap_sub_n ~name:"drop_suffix" t n ~pos:0 ~len:(length t - n) ~on_error:"" +let prefix t n = wrap_sub_n ~name:"prefix" t n ~pos:0 ~len:n ~on_error:t +let suffix t n = wrap_sub_n ~name:"suffix" t n ~pos:(length t - n) ~len:n ~on_error:t +(* Index of the first character at or after [pos] (default 0) for which [f i c] holds. *) +let lfindi ?(pos=0) t ~f = + let n = length t in + let rec loop i = + if i = n then None + else if f i t.[i] then Some i + else loop (i + 1) + in + loop pos +;; + +let find t ~f = + match lfindi t ~f:(fun _ c -> f c) with + | None -> None | Some i -> Some t.[i] +(* Returns the first [Some] produced by [f], scanning from index 0. *) +let find_map t ~f = + let n = length t in + let rec loop i = + if i = n then None + else + match f t.[i] with + | None -> loop (i + 1) + | Some _ as res -> res + in + loop 0 +;; +(* Index of the last character at or before [pos] (default: the last index) for which [f i c] holds. *) +let rfindi ?pos t ~f = + let rec loop i = + if i < 0 then None + else begin + if f i t.[i] then
Some i + else loop (i - 1) + end + in + let pos = + match pos with + | Some pos -> pos + | None -> length t - 1 + in + loop pos +;; + +let last_non_drop ~drop t = rfindi t ~f:(fun _ c -> not (drop c)) +(* Returns [t] itself when its last character is kept, and the empty string when every character is dropped. *) +let rstrip ?(drop=Char.is_whitespace) t = + match last_non_drop t ~drop with + | None -> "" + | Some i -> + if i = length t - 1 + then t + else prefix t (i + 1) +;; + +let first_non_drop ~drop t = lfindi t ~f:(fun _ c -> not (drop c)) +(* Returns [t] itself when its first character is kept, and the empty string when every character is dropped. *) +let lstrip ?(drop=Char.is_whitespace) t = + match first_non_drop t ~drop with + | None -> "" + | Some 0 -> t + | Some n -> drop_prefix t n +;; + +(* [strip t] could be implemented as [lstrip (rstrip t)]. The implementation + below saves (at least) a factor of two allocation, by only allocating the + final result. This also saves some amount of time. *) +let strip ?(drop=Char.is_whitespace) t = + let length = length t in + if length = 0 || not (drop t.[0] || drop t.[length - 1]) + then t + else + match first_non_drop t ~drop with + | None -> "" + | Some first -> + match last_non_drop t ~drop with + | None -> assert false (* unreachable: [first_non_drop] has just found a character that is kept *) + | Some last -> sub t ~pos:first ~len:(last - first + 1) +;; +(* Like [map], but [f] also receives the index of each character. *) +let mapi t ~f = + let l = length t in + let t' = Bytes.create l in + for i = 0 to l - 1 do + Bytes.unsafe_set t' i (f i t.[i]) + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t' + +(* repeated code to avoid requiring an extra allocation for a closure on each call.
*) +let map t ~f = + let l = length t in + let t' = Bytes.create l in + for i = 0 to l - 1 do + Bytes.unsafe_set t' i (f t.[i]) + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t' + +let to_array s = Array.init (length s) ~f:(fun i -> s.[i]) +(* [exists] and [for_all] stop at the first character that decides the answer. *) +let exists = + let rec loop s i ~len ~f = i < len && (f s.[i] || loop s (i + 1) ~len ~f) in + fun s ~f -> loop s 0 ~len:(length s) ~f +;; + +let for_all = + let rec loop s i ~len ~f = i = len || (f s.[i] && loop s (i + 1) ~len ~f) in + fun s ~f -> loop s 0 ~len:(length s) ~f +;; +(* Left-to-right fold over the characters of [t]. *) +let fold t ~init ~f = + let n = length t in + let rec loop i ac = if i = n then ac else loop (i + 1) (f ac t.[i]) in + loop 0 init +;; + +let foldi t ~init ~f = + let n = length t in + let rec loop i ac = if i = n then ac else loop (i + 1) (f i ac t.[i]) in + loop 0 init +;; +(* The generic container operations below are all derived from [fold]. *) +let count t ~f = Container.count ~fold t ~f +let sum m t ~f = Container.sum ~fold m t ~f + +let min_elt t = Container.min_elt ~fold t +let max_elt t = Container.max_elt ~fold t +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t + +let mem = + let rec loop t c ~pos:i ~len = + i < len && (Char.equal c (unsafe_get t i) || loop t c ~pos:(i + 1) ~len) + in + fun t c -> + loop t c ~pos:0 ~len:(length t) +;; +(* Replaces every [target] character with [replacement]. [s] itself is returned, without copying, when the two characters are equal or [target] does not occur in [s]. *) +let tr ~target ~replacement s = + if Char.equal target replacement + then s + else if mem s target + then map s ~f:(fun c -> if Char.equal c target then replacement else c) + else s +;; + +let tr_inplace ~target ~replacement s = (* destructive version of tr *) + for i = 0 to Bytes.length s - 1 do + if Char.equal (Bytes.unsafe_get s i) target then Bytes.unsafe_set s i replacement + done +(* Staged so that the translation map is computed once per [target]/[replacement] pair rather than once per translated string. An empty [target] gives the identity, an empty [replacement] is an error, and the returned function gives back its argument unchanged when the map would alter none of its characters. *) +let tr_multi ~target ~replacement = + if is_empty target + then stage Fn.id + else if is_empty replacement + then invalid_arg "tr_multi replacement is empty string" + else + match Bytes_tr.tr_create_map ~target ~replacement with + | None -> stage Fn.id + | Some tr_map -> + stage (fun s
-> + if exists s ~f:(fun c -> Char.(<>) c (unsafe_get tr_map (Char.to_int c))) + then map s ~f:(fun c -> unsafe_get tr_map (Char.to_int c)) + else s) + +(* fast version, if we ever need it: + {[ + let concat_array ~sep ar = + let ar_len = Array.length ar in + if ar_len = 0 then "" + else + let sep_len = length sep in + let res_len_ref = ref (sep_len * (ar_len - 1)) in + for i = 0 to ar_len - 1 do + res_len_ref := !res_len_ref + length ar.(i) + done; + let res = create !res_len_ref in + let str_0 = ar.(0) in + let len_0 = length str_0 in + blit ~src:str_0 ~src_pos:0 ~dst:res ~dst_pos:0 ~len:len_0; + let pos_ref = ref len_0 in + for i = 1 to ar_len - 1 do + let pos = !pos_ref in + blit ~src:sep ~src_pos:0 ~dst:res ~dst_pos:pos ~len:sep_len; + let new_pos = pos + sep_len in + let str_i = ar.(i) in + let len_i = length str_i in + blit ~src:str_i ~src_pos:0 ~dst:res ~dst_pos:new_pos ~len:len_i; + pos_ref := new_pos + len_i + done; + res + ]} *) +(* Current implementation: converts the array to a list and delegates to [concat]. *) +let concat_array ?sep ar = concat ?sep (Array.to_list ar) +(* Applies [f] to every character of [s] and joins the resulting strings with [concat_array], passing [sep] along. *) +let concat_map ?sep s ~f = concat_array ?sep (Array.map (to_array s) ~f) + +(* [filter t f] is implemented by the following algorithm. + + Let [n = length t]. + + 1. Find the lowest [i] such that [not (f t.[i])]. + + 2. If there is no such [i], then return [t]. + + 3. If there is such an [i], allocate a string, [out], to hold the result. [out] has + length [n - 1], which is the maximum possible output size given that there is at least + one character not satisfying [f]. + + 4. Copy characters at indices 0 ... [i - 1] from [t] to [out]. + + 5. Walk through characters at indices [i+1] ... [n-1] of [t], copying those that + satisfy [f] from [t] to [out]. + + 6. If we completely filled [out], then return it. If not, return the prefix of [out] + that we did fill in. + + This algorithm has the property that it doesn't allocate a new string if there's + nothing to filter, which is a common case.
*) +let filter t ~f = + let n = length t in + let i = ref 0 in + while !i < n && f t.[!i]; do + incr i + done; + if !i = n then + t + else begin + let out = Bytes.create (n - 1) in + blit ~src:t ~src_pos:0 ~dst:out ~dst_pos:0 ~len:!i; + let out_pos = ref !i in + incr i; + while !i < n; do + let c = t.[!i] in + if f c then (Bytes.set out !out_pos c; incr out_pos); + incr i + done; + let out = Bytes.unsafe_to_string ~no_mutation_while_string_reachable:out in + if !out_pos = n - 1 then + out + else + sub out ~pos:0 ~len:!out_pos + end +;; +(* [Some] of [s] without its leading [prefix], or [None] when [s] does not start with [prefix]. *) +let chop_prefix s ~prefix = + if is_prefix s ~prefix then + Some (drop_prefix s (length prefix)) + else + None +(* Like [chop_prefix], but raises [Invalid_argument] when [s] does not start with [prefix]. *) +let chop_prefix_exn s ~prefix = + match chop_prefix s ~prefix with + | Some str -> str + | None -> + raise (Invalid_argument + (Printf.sprintf "String.chop_prefix_exn %S %S" s prefix)) +(* [Some] of [s] without its trailing [suffix], or [None] when [s] does not end with [suffix]. *) +let chop_suffix s ~suffix = + if is_suffix s ~suffix then + Some (drop_suffix s (length suffix)) + else + None +(* Like [chop_suffix], but raises [Invalid_argument] when [s] does not end with [suffix]. *) +let chop_suffix_exn s ~suffix = + match chop_suffix s ~suffix with + | Some str -> str + | None -> + raise (Invalid_argument + (Printf.sprintf "String.chop_suffix_exn %S %S" s suffix)) + +(* There used to be a custom implementation that was faster for very short strings + (peaking at 40% faster for 4-6 char long strings). + This new function is around 20% faster than the default hash function, but slower + than the previous custom implementation. However, the new OCaml function is well + behaved, and this implementation is less likely to diverge from what the default OCaml + implementation does, which is a desirable property. (The only way to avoid the + divergence is to expose the macro redefined in hash_stubs.c in the hash.h header of + the OCaml compiler.)
*) +module Hash = struct + external hash : string -> int = "Base_hash_string" [@@noalloc] +end + +(* [include Hash] to make the [external] version override the [hash] from + [Hashable.Make_binable], so that we get a little bit of a speedup by exposing it as + external in the mli. *) +let _ = hash +include Hash + +include Comparable.Validate (T) + +(* for interactive top-levels -- modules deriving from String should have String's pretty + printer. *) +let pp = Caml.Format.pp_print_string + +let of_char c = make 1 c + +let of_char_list l = + let t = Bytes.create (List.length l) in + List.iteri l ~f:(fun i c -> Bytes.set t i c); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t + +module Escaping = struct + (* If this is changed, make sure to update [escape], which attempts to ensure all the + invariants checked here. *) + let build_and_validate_escapeworthy_map escapeworthy_map escape_char func = + let escapeworthy_map = + if List.Assoc.mem escapeworthy_map ~equal:Char.equal escape_char + then escapeworthy_map + else (escape_char, escape_char) :: escapeworthy_map + in + let arr = Array.create ~len:256 (-1) in + let vals = Array.create ~len:256 false in + let rec loop = function + | [] -> Ok arr + | (c_from, c_to) :: l -> + let k, v = match func with + | `Escape -> Char.to_int c_from, c_to + | `Unescape -> Char.to_int c_to, c_from + in + if arr.(k) <> -1 || vals.(Char.to_int v) then + Or_error.error_s + (Sexp.message "escapeworthy_map not one-to-one" + [ "c_from", sexp_of_char c_from + ; "c_to", sexp_of_char c_to + ; "escapeworthy_map", + sexp_of_list (sexp_of_pair sexp_of_char sexp_of_char) + escapeworthy_map + ]) + else (arr.(k) <- Char.to_int v; vals.(Char.to_int v) <- true; loop l) + in + loop escapeworthy_map + ;; + + let escape_gen ~escapeworthy_map ~escape_char = + match + build_and_validate_escapeworthy_map escapeworthy_map escape_char `Escape + with + | Error _ as x -> x + | Ok escapeworthy -> + Ok (fun src -> + (* calculate a list of (index of 
char to escape * escaped char) first, the order + is from tail to head *) + let to_escape_len = ref 0 in + let to_escape = + foldi src ~init:[] ~f:(fun i acc c -> + match escapeworthy.(Char.to_int c) with + | -1 -> acc + | n -> + (* (index of char to escape * escaped char) *) + incr to_escape_len; + (i, Char.unsafe_of_int n) :: acc) + in + match to_escape with + | [] -> src + | _ -> + (* [to_escape] divides [src] into [List.length to_escape + 1] pieces separated by + the chars to escape. + + Let us take + {[ + escape_gen_exn + ~escapeworthy_map:[('a', 'A'); ('b', 'B'); ('c', 'C')] + ~escape_char:'_' + ]} + for example, and assume the string to escape is + + "000a111b222c333" + + then [to_escape] is [(11, 'C'); (7, 'B'); (3, 'A')]. + + Then we create a [dst] of length [length src + 3] to store the + result, copy piece "333" to [dst] directly, then copy '_' and 'C' to [dst]; + then move on to next; after 3 iterations, copy piece "000" and we are done. + + Finally the result will be + + "000_A111_B222_C333" *) + let src_len = length src in + let dst_len = src_len + !to_escape_len in + let dst = Bytes.create dst_len in + let rec loop last_idx last_dst_pos = function + | [] -> + (* copy "000" at last *) + blit ~src ~src_pos:0 ~dst ~dst_pos:0 ~len:last_idx + | (idx, escaped_char) :: to_escape -> (*[idx] = index of the char to escape*) + (* take first iteration for example *) + (* calculate length of "333", minus 1 because we don't copy 'c' *) + let len = last_idx - idx - 1 in + (* set the dst_pos to copy to *) + let dst_pos = last_dst_pos - len in + (* copy "333", set [src_pos] to [idx + 1] to skip 'c' *) + blit ~src ~src_pos:(idx + 1) ~dst ~dst_pos ~len; + (* backoff [dst_pos] by 2 to copy '_' and 'C' *) + let dst_pos = dst_pos - 2 in + Bytes.set dst dst_pos escape_char; + Bytes.set dst (dst_pos + 1) escaped_char; + loop idx dst_pos to_escape + in + (* set [last_dst_pos] and [last_idx] to length of [dst] and [src] first *) + loop src_len dst_len to_escape; + Bytes.unsafe_to_string 
~no_mutation_while_string_reachable:dst + ) + ;; + + let escape_gen_exn ~escapeworthy_map ~escape_char = + Or_error.ok_exn (escape_gen ~escapeworthy_map ~escape_char) |> stage + ;; + + let escape ~escapeworthy ~escape_char = + (* For [escape_gen_exn], we don't know how to fix invalid escapeworthy_map so we have + to raise exception; but in this case, we know how to fix duplicated elements in + escapeworthy list, so we just fix it instead of raising exception to make this + function easier to use. *) + let escapeworthy_map = + escapeworthy + |> List.dedup_and_sort ~compare:Char.compare + |> List.map ~f:(fun c -> (c, c)) + in + escape_gen_exn ~escapeworthy_map ~escape_char + ;; + + (* In an escaped string, any char is either `Escaping, `Escaped or `Literal. For + example, the escape statuses of chars in string "a_a__" with escape_char = '_' are + + a : `Literal + _ : `Escaping + a : `Escaped + _ : `Escaping + _ : `Escaped + + [update_escape_status str ~escape_char i previous_status] gets escape status of + str.[i] based on escape status of str.[i - 1] *) + let update_escape_status str ~escape_char i = function + | `Escaping -> `Escaped + | `Literal + | `Escaped -> if Char.equal str.[i] escape_char then `Escaping else `Literal + ;; + + let unescape_gen ~escapeworthy_map ~escape_char = + match + build_and_validate_escapeworthy_map escapeworthy_map escape_char `Unescape + with + | Error _ as x -> x + | Ok escapeworthy -> + Ok (fun src -> + (* Continue the example in [escape_gen_exn], now we unescape + + "000_A111_B222_C333" + + back to + + "000a111b222c333" + + Then [to_unescape] is [13; 8; 3], which are the indexes of the '_'s. + + Then we create a string [dst] to store the result, copy "333" to it, then copy + 'c', then move on to next iteration. After 3 iterations copy "000" and we are + done. 
*) + (* indexes of escape chars *) + let to_unescape = + let rec loop i status acc = + if i >= length src then acc + else + let status = update_escape_status src ~escape_char i status in + loop (i + 1) status + (match status with + | `Escaping -> i :: acc + | `Escaped | `Literal -> acc) + in + loop 0 `Literal [] + in + match to_unescape with + | [] -> src + | idx::to_unescape' -> + let dst = Bytes.create (length src - List.length to_unescape) in + let rec loop last_idx last_dst_pos = function + | [] -> + (* copy "000" at last *) + blit ~src ~src_pos:0 ~dst ~dst_pos:0 ~len:last_idx + | idx::to_unescape -> (* [idx] = index of escaping char *) + (* take 1st iteration as example, calculate the length of "333", minus 2 to + skip '_C' *) + let len = last_idx - idx - 2 in + (* point [dst_pos] to the position to copy "333" to *) + let dst_pos = last_dst_pos - len in + (* copy "333" *) + blit ~src ~src_pos:(idx + 2) ~dst ~dst_pos ~len; + (* backoff [dst_pos] by 1 to copy 'c' *) + let dst_pos = dst_pos - 1 in + Bytes.set dst dst_pos ( match escapeworthy.(Char.to_int src.[idx + 1]) with + | -1 -> src.[idx + 1] + | n -> Char.unsafe_of_int n); + (* update [last_dst_pos] and [last_idx] *) + loop idx dst_pos to_unescape + in + ( if idx < length src - 1 then + (* set [last_dst_pos] and [last_idx] to length of [dst] and [src] *) + loop (length src) (Bytes.length dst) to_unescape + else + (* for escaped string ending with an escaping char like "000_", just ignore + the last escaping char *) + loop (length src - 1) (Bytes.length dst) to_unescape' + ); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + ) + ;; + + let unescape_gen_exn ~escapeworthy_map ~escape_char = + Or_error.ok_exn (unescape_gen ~escapeworthy_map ~escape_char) |> stage + ;; + + let unescape ~escape_char = + unescape_gen_exn ~escapeworthy_map:[] ~escape_char + + let preceding_escape_chars str ~escape_char pos = + let rec loop p cnt = + if (p < 0) || (Char.( <> ) str.[p] escape_char) then + cnt + else 
+ loop (p - 1) (cnt + 1) + in + loop (pos - 1) 0 + ;; + + (* In an escaped string, any char is either `Escaping, `Escaped or `Literal. For + example, the escape statuses of chars in string "a_a__" with escape_char = '_' are + + a : `Literal + _ : `Escaping + a : `Escaped + _ : `Escaping + _ : `Escaped + + [update_escape_status str ~escape_char i previous_status] gets escape status of + str.[i] basing on escape status of str.[i - 1] *) + let update_escape_status str ~escape_char i = function + | `Escaping -> `Escaped + | `Literal + | `Escaped -> if Char.equal str.[i] escape_char then `Escaping else `Literal + ;; + + let escape_status str ~escape_char pos = + let odd = (preceding_escape_chars str ~escape_char pos) mod 2 = 1 in + match odd, Char.equal str.[pos] escape_char with + | true, (true|false) -> `Escaped + | false, true -> `Escaping + | false, false -> `Literal + ;; + + let check_bound str pos function_name = + if pos >= length str || pos < 0 then + invalid_argf "%s: out of bounds" function_name () + ;; + + let is_char_escaping str ~escape_char pos = + check_bound str pos "is_char_escaping"; + match escape_status str ~escape_char pos with + | `Escaping -> true + | `Escaped | `Literal -> false + ;; + + let is_char_escaped str ~escape_char pos = + check_bound str pos "is_char_escaped"; + match escape_status str ~escape_char pos with + | `Escaped -> true + | `Escaping | `Literal -> false + ;; + + let is_char_literal str ~escape_char pos = + check_bound str pos "is_char_literal"; + match escape_status str ~escape_char pos with + | `Literal -> true + | `Escaped | `Escaping -> false + ;; + + let index_from str ~escape_char pos char = + check_bound str pos "index_from"; + let rec loop i status = + if i >= pos + && (match status with `Literal -> true | `Escaped | `Escaping -> false) + && Char.equal str.[i] char + then Some i + else ( + let i = i + 1 in + if i >= length str then None + else loop i (update_escape_status str ~escape_char i status)) + in + loop pos 
(escape_status str ~escape_char pos) + ;; + + let index_from_exn str ~escape_char pos char = + match index_from str ~escape_char pos char with + | None -> + raise_s + (Sexp.message "index_from_exn: not found" + [ "str" , sexp_of_t str + ; "escape_char" , sexp_of_char escape_char + ; "pos" , sexp_of_int pos + ; "char" , sexp_of_char char + ]) + | Some pos -> pos + ;; + + let index str ~escape_char char = index_from str ~escape_char 0 char + let index_exn str ~escape_char char = index_from_exn str ~escape_char 0 char + + let rindex_from str ~escape_char pos char = + check_bound str pos "rindex_from"; + (* if the target char is the same as [escape_char], we have no way to determine which + escape_char is literal, so just return None *) + if Char.equal char escape_char then None + else + let rec loop pos = + if pos < 0 then None + else ( + let escape_chars = preceding_escape_chars str ~escape_char pos in + if escape_chars mod 2 = 0 && Char.equal str.[pos] char + then Some pos else loop (pos - escape_chars - 1)) + in + loop pos + ;; + + let rindex_from_exn str ~escape_char pos char = + match rindex_from str ~escape_char pos char with + | None -> + raise_s + (Sexp.message "rindex_from_exn: not found" + [ "str" , sexp_of_t str + ; "escape_char" , sexp_of_char escape_char + ; "pos" , sexp_of_int pos + ; "char" , sexp_of_char char + ]) + | Some pos -> pos + ;; + + let rindex str ~escape_char char = + if is_empty str + then None + else rindex_from str ~escape_char (length str - 1) char + ;; + + let rindex_exn str ~escape_char char = + rindex_from_exn str ~escape_char (length str - 1) char + ;; + + (* [split_gen str ~escape_char ~on] works similarly to [String.split_gen], with an + additional requirement: only split on literal chars, not escaping or escaped *) + let split_gen str ~escape_char ~on = + let is_delim = match on with + | `char c' -> (fun c -> Char.equal c c') + | `char_list l -> (fun c -> char_list_mem l c) + in + let len = length str in + let rec loop acc status 
last_pos pos = + if pos = len then + List.rev (sub str ~pos:last_pos ~len:(len - last_pos) :: acc) + else + let status = update_escape_status str ~escape_char pos status in + if (match status with `Literal -> true | `Escaped | `Escaping -> false) + && is_delim str.[pos] + then ( + let sub_str = sub str ~pos:last_pos ~len:(pos - last_pos) in + loop (sub_str :: acc) status (pos + 1) (pos + 1)) + else loop acc status last_pos (pos + 1) + in + loop [] `Literal 0 0 + ;; + + let split str ~on = split_gen str ~on:(`char on) ;; + + let split_on_chars str ~on:chars = + split_gen str ~on:(`char_list chars) + ;; + + let split_at str pos = + sub str ~pos:0 ~len:pos, + sub str ~pos:(pos + 1) ~len:(length str - pos - 1) + ;; + + let lsplit2 str ~on ~escape_char = + Option.map (index str ~escape_char on) ~f:(fun x -> split_at str x) + ;; + + let rsplit2 str ~on ~escape_char = + Option.map (rindex str ~escape_char on) ~f:(fun x -> split_at str x) + ;; + + let lsplit2_exn str ~on ~escape_char = + split_at str (index_exn str ~escape_char on) + ;; + let rsplit2_exn str ~on ~escape_char = + split_at str (rindex_exn str ~escape_char on) + ;; + + (* [last_non_drop_literal] and [first_non_drop_literal] are either both [None] or both + [Some]. If [Some], then the former is >= the latter. 
*) + let last_non_drop_literal ~drop ~escape_char t = + rfindi t ~f:(fun i c -> + not (drop c) + || is_char_escaping t ~escape_char i + || is_char_escaped t ~escape_char i) + let first_non_drop_literal ~drop ~escape_char t = + lfindi t ~f:(fun i c -> + not (drop c) + || is_char_escaping t ~escape_char i + || is_char_escaped t ~escape_char i) + + let rstrip_literal ?(drop=Char.is_whitespace) t ~escape_char = + match last_non_drop_literal t ~drop ~escape_char with + | None -> "" + | Some i -> + if i = length t - 1 + then t + else prefix t (i + 1) + ;; + + let lstrip_literal ?(drop=Char.is_whitespace) t ~escape_char = + match first_non_drop_literal t ~drop ~escape_char with + | None -> "" + | Some 0 -> t + | Some n -> drop_prefix t n + ;; + + (* [strip t] could be implemented as [lstrip (rstrip t)]. The implementation + below saves (at least) a factor of two allocation, by only allocating the + final result. This also saves some amount of time. *) + let strip_literal ?(drop=Char.is_whitespace) t ~escape_char = + let length = length t in + (* performance hack: avoid copying [t] in common cases *) + if length = 0 || not (drop t.[0] || drop t.[length - 1]) + then t + else + match first_non_drop_literal t ~drop ~escape_char with + | None -> "" + | Some first -> + match last_non_drop_literal t ~drop ~escape_char with + | None -> assert false + | Some last -> sub t ~pos:first ~len:(last - first + 1) + ;; +end + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! 
String_replace_polymorphic_compare + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let create = Bytes.create +let fill = Bytes.fill + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include String_replace_polymorphic_compare diff --git a/src/string.mli b/src/string.mli new file mode 100644 index 0000000..3bf1d2c --- /dev/null +++ b/src/string.mli @@ -0,0 +1,461 @@ +(** An extension of the standard [StringLabels]. If you [open Base], you'll get these + extensions in the [String] module. *) + +open! Import + +type t = string [@@deriving_inline hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + + +val blit : (t, bytes) Blit.blit [@@deprecated "[since 2017-10] Use [Bytes.blit] instead"] +val blito : (t, bytes) Blit.blito [@@deprecated "[since 2017-10] Use [Bytes.blito] instead"] +val unsafe_blit : (t, bytes) Blit.blit [@@deprecated "[since 2017-10] Use [Bytes.unsafe_blit] instead"] + +val sub : (t, t) Blit.sub +val subo : (t, t) Blit.subo + +include Container.S0 with type t := t with type elt = char +include Identifiable.S with type t := t + +(** Maximum length of a string. 
*) +val max_length : int + +external length : t -> int = "%string_length" +external get : t -> int -> char = "%string_safe_get" + +(** [unsafe_get t i] is like [get t i] but does not perform bounds checking. The caller + must ensure that it is a memory-safe operation. *) +external unsafe_get : string -> int -> char = "%string_unsafe_get" + +val create : int -> bytes [@@deprecated "[since 2017-10] Use [Bytes.create] instead"] + +val make : int -> char -> t + +(** Assuming you haven't passed -unsafe-string to the compiler, strings are immutable, so + there'd be no motivation to make a copy. *) +val copy : t -> t [@@deprecated "[since 2018-03] Use [Bytes.copy] instead"] + +val init : int -> f:(int -> char) -> t + +val fill : bytes -> pos:int -> len:int -> char -> unit [@@deprecated "[since 2017-10] Use [Bytes.fill] instead"] + +(** String append. Also available unqualified, but re-exported here for documentation + purposes. + + Note that [a ^ b] must copy both [a] and [b] into a newly-allocated result string, so + [a ^ b ^ c ^ ... ^ z] is quadratic in the number of strings. [String.concat] does not + have this problem -- it allocates the result buffer only once. *) +val ( ^ ) : t -> t -> t + +(** Concatenates all strings in the list using separator [sep] (with a default separator + [""]). *) +val concat : ?sep:t -> t list -> t + +(** Special characters are represented by escape sequences, following the lexical + conventions of OCaml. *) +val escaped : t -> t + +val contains : ?pos:int -> ?len:int -> t -> char -> bool + +(** Operates on the whole string using the US-ASCII character set, + e.g. [uppercase "foo" = "FOO"]. *) +val uppercase : t -> t +val lowercase : t -> t + +(** Operates on just the first character using the US-ASCII character set, + e.g. [capitalize "foo" = "Foo"]. 
*) +val capitalize : t -> t +val uncapitalize : t -> t + +(** [index] gives the index of the first appearance of [char] in the string when + searching from left to right, or [None] if it's not found. [rindex] does the same but + searches from the right. + + For example, [String.index "Foo" 'o'] is [Some 1] while [String.rindex "Foo" 'o'] is + [Some 2]. + + The [_exn] versions return the actual index (instead of an option) when [char] is + found, and throw an exception otherwise. +*) + +(** [Caseless] compares and hashes strings ignoring case, so that for example + [Caseless.equal "OCaml" "ocaml"] and [Caseless.("apple" < "Banana")] are [true], and + [Caseless.Map], [Caseless.Table] lookup and [Caseless.Set] membership is + case-insensitive. + + [Caseless] also provides case-insensitive [is_suffix] and [is_prefix] functions, so + that for example [Caseless.is_suffix "OCaml" ~suffix:"AmL"] and [Caseless.is_prefix + "OCaml" ~prefix:"oc"] are [true]. *) +module Caseless : sig + type nonrec t = t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparable.S with type t := t + + val is_suffix : t -> suffix:t -> bool + val is_prefix : t -> prefix:t -> bool +end + +(** [index_exn] and [index_from_exn] raise [Caml.Not_found] or [Not_found_s] when [char] + cannot be found in [s]. *) +val index : t -> char -> int option +val index_exn : t -> char -> int +val index_from : t -> int -> char -> int option +val index_from_exn : t -> int -> char -> int + +(** [rindex_exn] and [rindex_from_exn] raise [Caml.Not_found] or [Not_found_s] when [char] + cannot be found in [s]. 
*) +val rindex : t -> char -> int option +val rindex_exn : t -> char -> int +val rindex_from : t -> int -> char -> int option +val rindex_from_exn : t -> int -> char -> int + +(** Substring search and replace functions. They use the Knuth-Morris-Pratt algorithm + (KMP) under the hood. + + The functions in the [Search_pattern] module allow the program to preprocess the + searched pattern once and then use it many times without further allocations. *) +module Search_pattern : sig + + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** [create pattern] preprocesses [pattern] as per KMP, building an [int array] of + length [length pattern]. All inputs are valid. *) + val create : string -> t + + (** [matches pat str] returns true if [str] matches [pat] *) + val matches : t -> string -> bool + + (** [pos < 0] or [pos >= length string] result in no match (hence [index] returns + [None] and [index_exn] raises). *) + val index : ?pos:int -> t -> in_:string -> int option + val index_exn : ?pos:int -> t -> in_:string -> int + + (** [may_overlap] determines whether after a successful match, [index_all] should start + looking for another one at the very next position ([~may_overlap:true]), or jump to + the end of that match and continue from there ([~may_overlap:false]), e.g.: + + - [index_all (create "aaa") ~may_overlap:false ~in_:"aaaaBaaaaaa" = [0; 5; 8]] + - [index_all (create "aaa") ~may_overlap:true ~in_:"aaaaBaaaaaa" = [0; 1; 5; 6; 7; + 8]] + + E.g., [replace_all] internally calls [index_all ~may_overlap:false]. 
*) + val index_all : t -> may_overlap:bool -> in_:string -> int list + + (** Note that the result of [replace_all pattern ~in_:text ~with_:r] may still + contain [pattern], e.g., + + {[ + replace_all (create "bc") ~in_:"aabbcc" ~with_:"cb" = "aabcbc" + ]} *) + val replace_first : ?pos:int -> t -> in_:string -> with_:string -> string + val replace_all : t -> in_:string -> with_:string -> string +end + +(** Substring search and replace convenience functions. They call [Search_pattern.create] + and then forget the preprocessed pattern when the search is complete. [pos < 0] or + [pos >= length t] result in no match (hence [substr_index] returns [None] and + [substr_index_exn] raises). [may_overlap] indicates whether to report overlapping + matches, see [Search_pattern.index_all]. *) +val substr_index : ?pos:int -> t -> pattern:t -> int option +val substr_index_exn : ?pos:int -> t -> pattern:t -> int +val substr_index_all : t -> may_overlap:bool -> pattern:t -> int list + +val substr_replace_first : ?pos:int -> t -> pattern:t -> with_:t -> t + +(** As with [Search_pattern.replace_all], the result may still contain [pattern]. *) +val substr_replace_all : t -> pattern:t -> with_:t -> t + +(** [is_substring ~substring:"bar" "foo bar baz"] is true. *) +val is_substring : t -> substring:t -> bool + +(** [is_substring_at "foo bar baz" ~pos:4 ~substring:"bar"] is true. *) +val is_substring_at : t -> pos:int -> substring:t -> bool + +(** Returns the reversed list of characters contained in a string. *) +val to_list_rev : t -> char list + +(** [rev t] returns [t] in reverse order. *) +val rev : t -> t + +(** [is_suffix s ~suffix] returns [true] if [s] ends with [suffix]. *) +val is_suffix : t -> suffix:t -> bool + +(** [is_prefix s ~prefix] returns [true] if [s] starts with [prefix]. 
*) +val is_prefix : t -> prefix:t -> bool + +(** If the string [s] contains the character [on], then [lsplit2_exn s ~on] returns a pair + containing [s] split around the first appearance of [on] (from the left). Raises + [Caml.Not_found] or [Not_found_s] when [on] cannot be found in [s]. *) +val lsplit2_exn : t -> on:char -> t * t + +(** If the string [s] contains the character [on], then [rsplit2_exn s ~on] returns a pair + containing [s] split around the first appearance of [on] (from the right). Raises + [Caml.Not_found] or [Not_found_s] when [on] cannot be found in [s]. *) +val rsplit2_exn : t -> on:char -> t * t + +(** [lsplit2 s ~on] optionally returns [s] split into two strings around the + first appearance of [on] from the left. *) +val lsplit2 : t -> on:char -> (t * t) option + +(** [rsplit2 s ~on] optionally returns [s] split into two strings around the first + appearance of [on] from the right. *) +val rsplit2 : t -> on:char -> (t * t) option + +(** [split s ~on] returns a list of substrings of [s] that are separated by [on]. + Consecutive [on] characters will cause multiple empty strings in the result. + Splitting the empty string returns a list of the empty string, not the empty list. *) +val split : t -> on:char -> t list + +(** [split_on_chars s ~on] returns a list of all substrings of [s] that are separated by + one of the chars from [on]. [on] are not grouped. So a grouping of [on] in the + source string will produce multiple empty string splits in the result. *) +val split_on_chars : t -> on:char list -> t list + +(** [split_lines t] returns the list of lines that comprise [t]. The lines do not include + the trailing ["\n"] or ["\r\n"]. *) +val split_lines : t -> t list + +(** [lfindi ?pos t ~f] returns the smallest [i >= pos] such that [f i t.[i]], if there is + such an [i]. By default, [pos = 0]. 
*) +val lfindi : ?pos : int -> t -> f:(int -> char -> bool) -> int option + +(** [rfindi ?pos t ~f] returns the largest [i <= pos] such that [f i t.[i]], if there is + such an [i]. By default [pos = length t - 1]. *) +val rfindi : ?pos : int -> t -> f:(int -> char -> bool) -> int option + +(** [lstrip ?drop s] returns a string with consecutive chars satisfying [drop] (by default + white space, e.g. tabs, spaces, newlines, and carriage returns) stripped from the + beginning of [s]. *) +val lstrip : ?drop:(char -> bool) -> t -> t + +(** [rstrip ?drop s] returns a string with consecutive chars satisfying [drop] (by default + white space, e.g. tabs, spaces, newlines, and carriage returns) stripped from the end + of [s]. *) +val rstrip : ?drop:(char -> bool) -> t -> t + +(** [strip ?drop s] returns a string with consecutive chars satisfying [drop] (by default + white space, e.g. tabs, spaces, newlines, and carriage returns) stripped from the + beginning and end of [s]. *) +val strip : ?drop:(char -> bool) -> t -> t + +val map : t -> f : (char -> char) -> t + +(** Like [map], but passes each character's index to [f] along with the char. *) +val mapi : t -> f : (int -> char -> char) -> t + +(** [foldi] works similarly to [fold], but also passes the index of each character to + [f]. *) +val foldi : t -> init : 'a -> f : (int -> 'a -> char -> 'a) -> 'a + +(** Like [map], but allows the replacement of a single character with zero or two or more + characters. *) +val concat_map : ?sep:t -> t -> f : (char -> t) -> t + +(** [filter s ~f:predicate] discards characters not satisfying [predicate]. *) +val filter : t -> f : (char -> bool) -> t + +(** [tr ~target ~replacement s] replaces every instance of [target] in [s] with + [replacement]. *) +val tr : target:char -> replacement:char -> t -> t + +(** [tr_inplace ~target ~replacement s] destructively modifies [s] (in place!), replacing + every instance of [target] in [s] with [replacement]. 
*) +val tr_inplace : target:char -> replacement:char -> bytes -> unit +[@@deprecated "[since 2017-10] Use [Bytes.tr] instead"] + +(** [tr_multi ~target ~replacement] returns a function that replaces every + instance of a character in [target] with the corresponding character in + [replacement]. + + If [replacement] is shorter than [target], it is lengthened by repeating + its last character. Empty [replacement] is illegal unless [target] also is. + + If [target] contains multiple copies of the same character, the last + corresponding [replacement] character is used. Note that character ranges + are {b not} supported, so [~target:"a-z"] means the literal characters ['a'], + ['-'], and ['z']. *) +val tr_multi : target:t -> replacement:t -> (t -> t) Staged.t + +(** [chop_suffix_exn s ~suffix] returns [s] without the trailing [suffix], + raising [Invalid_argument] if [suffix] is not a suffix of [s]. *) +val chop_suffix_exn : t -> suffix:t -> t + +(** [chop_prefix_exn s ~prefix] returns [s] without the leading [prefix], + raising [Invalid_argument] if [prefix] is not a prefix of [s]. *) +val chop_prefix_exn : t -> prefix:t -> t + +val chop_suffix : t -> suffix:t -> t option + +val chop_prefix : t -> prefix:t -> t option + +(** [suffix s n] returns the longest suffix of [s] of length less than or equal to [n]. *) +val suffix : t -> int -> t + +(** [prefix s n] returns the longest prefix of [s] of length less than or equal to [n]. *) +val prefix : t -> int -> t + +(** [drop_suffix s n] drops the longest suffix of [s] of length less than or equal to + [n]. *) +val drop_suffix : t -> int -> t + +(** [drop_prefix s n] drops the longest prefix of [s] of length less than or equal to + [n]. *) +val drop_prefix : t -> int -> t + +(** [concat_array sep ar] like {!String.concat}, but operates on arrays. *) +val concat_array : ?sep : t -> t array -> t + +(** Slightly faster hash function on strings. 
*) +external hash : t -> int = "Base_hash_string" [@@noalloc] + +(** Fast equality function on strings, doesn't use [compare_val]. *) +val equal : t -> t -> bool + +val of_char : char -> t + +val of_char_list : char list -> t + +(** Operations for escaping and unescaping strings, with parameterized escape and + escapeworthy characters. Escaping/unescaping using this module is more efficient than + using Pcre. Benchmark code can be found in core/benchmarks/string_escaping.ml. *) +module Escaping : sig + (** [escape_gen_exn escapeworthy_map escape_char] returns a function that will escape a + string [s] as follows: if [(c1,c2)] is in [escapeworthy_map], then all occurrences + of [c1] are replaced by [escape_char] concatenated to [c2]. + + Raises an exception if [escapeworthy_map] is not one-to-one. If [escape_char] is + not in [escapeworthy_map], then it will be escaped to itself.*) + val escape_gen_exn + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Staged.t + + val escape_gen + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Or_error.t + + (** [escape ~escapeworthy ~escape_char s] is + {[ + escape_gen_exn ~escapeworthy_map:(List.zip_exn escapeworthy escapeworthy) + ~escape_char + ]} + Duplicates and [escape_char] will be removed from [escapeworthy]. So, no + exception will be raised *) + val escape : escapeworthy:char list -> escape_char:char -> (string -> string) Staged.t + + (** [unescape_gen_exn] is the inverse operation of [escape_gen_exn]. That is, + {[ + let escape = Staged.unstage (escape_gen_exn ~escapeworthy_map ~escape_char) in + let unescape = Staged.unstage (unescape_gen_exn ~escapeworthy_map ~escape_char) in + assert (s = unescape (escape s)) + ]} + always succeeds when ~escapeworthy_map is not causing exceptions. 
*) + val unescape_gen_exn + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Staged.t + + val unescape_gen + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Or_error.t + + (** [unescape ~escape_char] is defined as [unescape_gen_exn ~escapeworthy_map:\[\] ~escape_char] *) + val unescape : escape_char:char -> (string -> string) Staged.t + + (** Any char in an escaped string is either escaping, escaped, or literal. For example, + for escaped string ["0_a0__0"] with [escape_char] as ['_'], pos 1 and 4 are + escaping, 2 and 5 are escaped, and the rest are literal. + + [is_char_escaping s ~escape_char pos] returns true if the char at [pos] is escaping, + false otherwise. *) + val is_char_escaping : string -> escape_char:char -> int -> bool + + (** [is_char_escaped s ~escape_char pos] returns true if the char at [pos] is escaped, + false otherwise. *) + val is_char_escaped : string -> escape_char:char -> int -> bool + + (** [is_char_literal s ~escape_char pos] returns true if the char at [pos] is not + escaped or escaping. *) + val is_char_literal : string -> escape_char:char -> int -> bool + + (** [index s ~escape_char char] finds the first literal (not escaped) instance of [char] + in [s] starting from 0. *) + val index : string -> escape_char:char -> char -> int option + val index_exn : string -> escape_char:char -> char -> int + + (** [rindex s ~escape_char char] finds the first literal (not escaped) instance of + [char] in [s] starting from the end of [s] and proceeding towards 0. *) + val rindex : string -> escape_char:char -> char -> int option + val rindex_exn : string -> escape_char:char -> char -> int + + (** [index_from s ~escape_char pos char] finds the first literal (not escaped) instance + of [char] in [s] starting from [pos] and proceeding towards the end of [s]. 
*) + val index_from : string -> escape_char:char -> int -> char -> int option + val index_from_exn : string -> escape_char:char -> int -> char -> int + + (** [rindex_from s ~escape_char pos char] finds the first literal (not escaped) + instance of [char] in [s] starting from [pos] and towards 0. *) + val rindex_from : string -> escape_char:char -> int -> char -> int option + val rindex_from_exn : string -> escape_char:char -> int -> char -> int + + (** [split s ~escape_char ~on] returns a list of substrings of [s] that are separated by + literal versions of [on]. Consecutive [on] characters will cause multiple empty + strings in the result. Splitting the empty string returns a list of the empty + string, not the empty list. + + E.g., [split ~escape_char:'_' ~on:',' "foo,bar_,baz" = ["foo"; "bar_,baz"]]. *) + val split : string -> on:char -> escape_char:char -> string list + + (** [split_on_chars s ~on] returns a list of all substrings of [s] that are separated by + one of the literal chars from [on]. [on] are not grouped. So a grouping of [on] in + the source string will produce multiple empty string splits in the result. + + E.g., [split_on_chars ~escape_char:'_' ~on:[',';'|'] "foo_|bar,baz|0" -> + ["foo_|bar"; "baz"; "0"]]. *) + val split_on_chars : string -> on:char list -> escape_char:char -> string list + + (** [lsplit2 s ~on ~escape_char] splits s into a pair on the first literal instance of + [on] (meaning the first unescaped instance) starting from the left. *) + val lsplit2 : string -> on:char -> escape_char:char -> (string * string) option + val lsplit2_exn : string -> on:char -> escape_char:char -> (string * string) + + (** [rsplit2 s ~on ~escape_char] splits [s] into a pair on the first literal + instance of [on] (meaning the first unescaped instance) starting from the + right. 
*) + val rsplit2 : string -> on:char -> escape_char:char -> (string * string) option + val rsplit2_exn : string -> on:char -> escape_char:char -> (string * string) + + (** These are the same as [lstrip], [rstrip], and [strip] for generic strings, except + that they only drop literal characters -- they do not drop characters that are + escaping or escaped. This makes sense if you're trying to get rid of junk + whitespace (for example), because escaped whitespace seems more likely to be + deliberate and not junk. *) + val lstrip_literal : ?drop:(char -> bool) -> t -> escape_char:char -> t + val rstrip_literal : ?drop:(char -> bool) -> t -> escape_char:char -> t + val strip_literal : ?drop:(char -> bool) -> t -> escape_char:char -> t +end + +val set : bytes -> int -> char -> unit [@@deprecated "[since 2017-10] Use [Bytes.set] instead"] +val unsafe_set : bytes -> int -> char -> unit [@@deprecated "[since 2017-10] Use [Bytes.unsafe_set] instead"] diff --git a/src/string0.ml b/src/string0.ml new file mode 100644 index 0000000..1cb6a9b --- /dev/null +++ b/src/string0.ml @@ -0,0 +1,62 @@ +(* [String0] defines string functions that are primitives or can be simply defined in + terms of [Caml.String]. [String0] is intended to completely express the part of + [Caml.String] that [Base] uses -- no other file in Base other than string0.ml should + use [Caml.String]. [String0] has few dependencies, and so is available early in Base's + build order. + + All Base files that need to use strings, including the subscript syntax + [x.[i]] or [x.[i] <- e] which the OCaml parser desugars into calls to + [String], and come before [Base.String] in build order should do + + {[ + module String = String0 + ]} + + Defining [module String = String0] is also necessary because it prevents + ocamldep from mistakenly causing a file to depend on [Base.String].
*) + +let capitalize = Caml.String.capitalize_ascii +let lowercase = Caml.String.lowercase_ascii +let uncapitalize = Caml.String.uncapitalize_ascii +let uppercase = Caml.String.uppercase_ascii + +open! Import0 + +module Sys = Sys0 + +module String = struct + external get : string -> int -> char = "%string_safe_get" + external length : string -> int = "%string_length" + external unsafe_get : string -> int -> char = "%string_unsafe_get" + + include Bytes_set_primitives +end + +include String + +let max_length = Sys.max_string_length + +let (^) = (^) + +let blit = Caml.String.blit +let compare = Caml.String.compare +let copy = Caml.String.copy [@@warning "-3"] +let escaped = Caml.String.escaped +let index_exn = Caml.String.index +let index_from_exn = Caml.String.index_from +let make = Caml.String.make +let rindex_exn = Caml.String.rindex +let rindex_from_exn = Caml.String.rindex_from +let sub = Caml.String.sub +let unsafe_blit = Caml.String.unsafe_blit + +let concat ?(sep = "") l = + match l with + | [] -> "" + (* The stdlib does not specialize this case because it could break existing projects. *) + | [x] -> x + | l -> Caml.String.concat ~sep l + +(* These are eta expanded in order to permute parameter order to follow Base + conventions. *) +let iter t ~f = Caml.String.iter t ~f diff --git a/src/stringable.ml b/src/stringable.ml new file mode 100644 index 0000000..f5efcd0 --- /dev/null +++ b/src/stringable.ml @@ -0,0 +1,10 @@ +(** Provides type-specific conversion functions to and from [string]. *) + +open! Import + +module type S = sig + type t + + val of_string : string -> t + val to_string : t -> string +end diff --git a/src/sys.ml b/src/sys.ml new file mode 100644 index 0000000..ad91468 --- /dev/null +++ b/src/sys.ml @@ -0,0 +1,3 @@ +open! Import + +include Sys0 diff --git a/src/sys.mli b/src/sys.mli new file mode 100644 index 0000000..5999112 --- /dev/null +++ b/src/sys.mli @@ -0,0 +1,95 @@ +(** Cross-platform system configuration values. 
*) + +(** The command line arguments given to the process. + The first element is the command name used to invoke the program. + The following elements are the command-line arguments given to the program. + + When running in JavaScript in the browser, it is [[| "a.out" |]]. *) +val argv : string array + +(** [interactive] is set to [true] when being executed in the [ocaml] REPL, and [false] + otherwise. *) +val interactive : bool ref + +(** [os_type] describes the operating system that the OCaml program is running on. Its + value is one of ["Unix"], ["Win32"], or ["Cygwin"]. When running in JavaScript, it is + ["Unix"]. *) +val os_type : string + +(** [unix] is [true] if [os_type = "Unix"]. *) +val unix : bool + +(** [win32] is [true] if [os_type = "Win32"]. *) +val win32 : bool + +(** [cygwin] is [true] if [os_type = "Cygwin"]. *) +val cygwin : bool + +(** Currently, the official distribution only supports [Native] and [Bytecode], + but it can be other backends with alternative compilers, for example, + JavaScript. *) +type backend_type = Sys0.backend_type = + | Native + | Bytecode + | Other of string + +(** Backend type currently executing the OCaml program. *) +val backend_type : backend_type + +(** [word_size_in_bits] is the number of bits in one word on the machine currently + executing the OCaml program. Generally speaking it will be either [32] or [64]. When + running in JavaScript, it will be [32]. *) +val word_size_in_bits : int + +(** [int_size_in_bits] is the number of bits in the [int] type. Generally, on + 32-bit platforms, its value will be [31], and on 64 bit platforms its value + will be [63]. When running in JavaScript, it will be [32]. {!Int.num_bits} + is the same as this value. *) +val int_size_in_bits : int + +(** [big_endian] is true when the program is running on a big-endian + architecture. When running in JavaScript, it will be [false]. 
*) +val big_endian : bool + +(** [max_string_length] is the maximum allowed length of a [string] or [Bytes.t]. + {!String.max_length} is the same as this value. *) +val max_string_length : int + +(** [max_array_length] is the maximum allowed length of an ['a array]. + {!Array.max_length} is the same as this value. *) +val max_array_length : int + +(** Returns the name of the runtime variant the program is running on. This is normally + the argument given to [-runtime-variant] at compile time, but for byte-code it can be + changed after compilation. When running in JavaScript, it will be [""]. *) +val runtime_variant : unit -> string + +(** Returns the value of the runtime parameters, in the same format as the contents of the + [OCAMLRUNPARAM] environment variable. When running in JavaScript, it will be [""]. *) +val runtime_parameters : unit -> string + +(** [ocaml_version] is the OCaml version with which the program was compiled. It is a + string of the form ["major.minor[.patchlevel][+additional-info]"], where major, minor, + and patchlevel are integers, and additional-info is an arbitrary string. The + [[.patchlevel]] and [[+additional-info]] parts may be absent. *) +val ocaml_version : string + +(** Controls whether the OCaml runtime system can emit warnings on stderr. Currently, the + only supported warning is triggered when a channel created by [open_*] functions is + finalized without being closed. Runtime warnings are enabled by default. *) +val enable_runtime_warnings : bool -> unit + +(** Returns whether runtime warnings are currently enabled. *) +val runtime_warnings_enabled : unit -> bool + +(** For the purposes of optimization, [opaque_identity] behaves like an unknown (and thus + possibly side-effecting) function. At runtime, [opaque_identity] disappears + altogether. A typical use of this function is to prevent pure computations from being + optimized away in benchmarking loops. 
For example: + + {[ + for _round = 1 to 100_000 do + ignore (Sys.opaque_identity (my_pure_computation ())) + done + ]} *) +external opaque_identity : 'a -> 'a = "%opaque" diff --git a/src/sys0.ml b/src/sys0.ml new file mode 100644 index 0000000..e2f69ac --- /dev/null +++ b/src/sys0.ml @@ -0,0 +1,43 @@ + +(* [Sys0] defines functions that are primitives or can be simply defined in + terms of [Caml.Sys]. [Sys0] is intended to completely express the part of + [Caml.Sys] that [Base] uses -- no other file in Base other than sys0.ml + should use [Caml.Sys]. [Sys0] has few dependencies, and so is available + early in Base's build order. All Base files that need to use these + functions and come before [Base.Sys] in build order should do + [module Sys = Sys0]. Defining [module Sys = Sys0] is also necessary because + it prevents ocamldep from mistakenly causing a file to depend on [Base.Sys]. *) + +open! Import0 + +type backend_type = Caml.Sys.backend_type = + | Native + | Bytecode + | Other of string + +let backend_type = Caml.Sys.backend_type + +let interactive = Caml.Sys.interactive +let os_type = Caml.Sys.os_type +let unix = Caml.Sys.unix +let win32 = Caml.Sys.win32 +let cygwin = Caml.Sys.cygwin + +let word_size_in_bits = Caml.Sys.word_size +let int_size_in_bits = Caml.Sys.int_size +let big_endian = Caml.Sys.big_endian +let max_string_length = Caml.Sys.max_string_length +let max_array_length = Caml.Sys.max_array_length +let runtime_variant = Caml.Sys.runtime_variant +let runtime_parameters = Caml.Sys.runtime_parameters + +let argv = Caml.Sys.argv +let getenv = Caml.Sys.getenv + +let ocaml_version = Caml.Sys.ocaml_version +let enable_runtime_warnings = Caml.Sys.enable_runtime_warnings +let runtime_warnings_enabled = Caml.Sys.runtime_warnings_enabled + +external opaque_identity : 'a -> 'a = "%opaque" + +exception Break = Caml.Sys.Break diff --git a/src/t.ml b/src/t.ml new file mode 100644 index 0000000..9dd5970 --- /dev/null +++ b/src/t.ml @@ -0,0 +1,10 @@ +(** This module
defines various abstract interfaces that are convenient when one needs a + module that matches a bare signature with just a type. This sometimes occurs in + functor arguments and in interfaces. *) + +open! Import + +module type T = sig type t end +module type T1 = sig type 'a t end +module type T2 = sig type ('a, 'b) t end +module type T3 = sig type ('a, 'b, 'c) t end diff --git a/src/type_equal.ml b/src/type_equal.ml new file mode 100644 index 0000000..a57271b --- /dev/null +++ b/src/type_equal.ml @@ -0,0 +1,172 @@ +open! Import + +type ('a, 'b) t = T : ('a, 'a) t [@@deriving_inline sexp_of] +let sexp_of_t : type a b. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (b -> Ppx_sexp_conv_lib.Sexp.t) -> (a, b) t -> Ppx_sexp_conv_lib.Sexp.t + = fun _of_a -> fun _of_b -> function | T -> Ppx_sexp_conv_lib.Sexp.Atom "T" +[@@@end] +type ('a, 'b) equal = ('a, 'b) t + +let refl = T + +let sym (type a) (type b) (T : (a, b) t) = (T : (b, a) t) + +let trans (type a) (type b) (type c) (T : (a, b) t) (T : (b, c) t) = (T : (a, c) t) + +let conv (type a) (type b) (T : (a, b) t) (a : a) = (a : b) + +module Lift (X : sig type 'a t end) = struct + let lift (type a) (type b) (T : (a, b) t) = (T : (a X.t, b X.t) t) +end + +module Lift2 (X : sig type ('a1, 'a2) t end) = struct + let lift (type a1) (type b1) (type a2) (type b2) (T : (a1, b1) t) (T : (a2, b2) t) = + (T : ((a1, a2) X.t, (b1, b2) X.t) t) + ;; +end + +module Lift3 (X : sig type ('a1, 'a2, 'a3) t end) = struct + let lift (type a1) (type b1) (type a2) (type b2) (type a3) (type b3) + (T : (a1, b1) t) (T : (a2, b2) t) (T : (a3, b3) t) = + (T : ((a1, a2, a3) X.t, (b1, b2, b3) X.t) t) + ;; +end + +let detuple2 (type a1) (type a2) (type b1) (type b2) + (T : (a1 * a2, b1 * b2) t) : (a1, b1) t * (a2, b2) t = + T, T +;; + +let tuple2 (type a1) (type a2) (type b1) (type b2) + (T : (a1, b1) t) (T : (a2, b2) t) : (a1 * a2, b1 * b2) t = + T +;; + +module type Injective = sig + type 'a t + val strip : ('a t, 'b t) equal -> ('a, 'b) equal +end + 
+module type Injective2 = sig + type ('a1, 'a2) t + val strip : (('a1, 'a2) t, ('b1, 'b2) t) equal -> ('a1, 'b1) equal * ('a2, 'b2) equal +end + +module Composition_preserves_injectivity (M1 : Injective) (M2 : Injective) = struct + type 'a t = 'a M1.t M2.t + let strip e = M1.strip (M2.strip e) +end + +module Obj = struct + module Extension_constructor = struct + [@@@ocaml.warning "-3"] + let id = Caml.Obj.extension_id + let of_val = Caml.Obj.extension_constructor + end +end + +module Id = struct + module Uid = Int + + module Witness = struct + module Key = struct + type _ t = .. + + type type_witness_int = [ `type_witness of int ] [@@deriving_inline sexp_of] + let sexp_of_type_witness_int : type_witness_int -> Ppx_sexp_conv_lib.Sexp.t = + function + | `type_witness v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "type_witness"; sexp_of_int v0] + [@@@end] + + let sexp_of_t _sexp_of_a t = + (`type_witness (Obj.Extension_constructor.id (Obj.Extension_constructor.of_val t))) + |> sexp_of_type_witness_int + ;; + end + + module type S = sig + type t + type _ Key.t += Key : t Key.t + end + + type 'a t = (module S with type t = 'a) + + let sexp_of_t (type a) sexp_of_a (module M : S with type t = a) = + M.Key |> Key.sexp_of_t sexp_of_a + ;; + + let create (type t) () = + let module M = struct + type nonrec t = t + type _ Key.t += Key : t Key.t + end in + (module M : S with type t = t) + ;; + + let uid (type a) (module M : S with type t = a) = + Obj.Extension_constructor.id (Obj.Extension_constructor.of_val M.Key) + + (* We want a constant allocated once that [same] can return whenever it gets the same + witnesses. If we write the constant inside the body of [same], the native-code + compiler will do the right thing and lift it out. But for clarity and robustness, + we do it ourselves. 
*) + let some_t = Some T + + let same (type a) (type b) (a : a t) (b : b t) : (a, b) equal option = + let module A = (val a : S with type t = a) in + let module B = (val b : S with type t = b) in + match A.Key with + | B.Key -> some_t + | _ -> None + ;; + end + + + type 'a t = + { witness : 'a Witness.t + ; name : string + ; to_sexp : 'a -> Sexp.t + } + + let sexp_of_t _ { witness; name; to_sexp } : Sexp.t = + if am_testing + then Atom name + else List [ List [ Atom "name"; Atom name ] + ; List [ Atom "witness"; witness |> Witness.sexp_of_t to_sexp ]] + ;; + + let to_sexp t x = t.to_sexp x + let name t = t.name + + let create ~name to_sexp = + { witness = Witness.create () + ; name + ; to_sexp + } + ;; + + let uid t = Witness.uid t.witness + + let hash t = uid t + + let hash_fold_t s t = hash_fold_int s (uid t) + + let same_witness t1 t2 = Witness.same t1.witness t2.witness + + let same t1 t2 = Option.is_some (same_witness t1 t2) + + let same_witness_exn t1 t2 = + match same_witness t1 t2 with + | Some w -> w + | None -> + Error.raise_s + (Sexp.message "Type_equal.Id.same_witness_exn got different ids" + [ "", + sexp_of_pair (sexp_of_t sexp_of_opaque) (sexp_of_t sexp_of_opaque) (t1, t2) + ]) + ;; +end + diff --git a/src/type_equal.mli b/src/type_equal.mli new file mode 100644 index 0000000..eb33bf4 --- /dev/null +++ b/src/type_equal.mli @@ -0,0 +1,236 @@ +(** The purpose of [Type_equal] is to represent type equalities that the type checker + otherwise would not know, perhaps because the type equality depends on dynamic data, + or perhaps because the type system isn't powerful enough. + + A value of type [(a, b) Type_equal.t] represents that types [a] and [b] are equal. + One can think of such a value as a proof of type equality. The [Type_equal] module + has operations for constructing and manipulating such proofs. For example, the + functions [refl], [sym], and [trans] express the usual properties of reflexivity, + symmetry, and transitivity of equality. 
+ + If one has a value [t : (a, b) Type_equal.t] that proves types [a] and [b] are equal, + there are two ways to use [t] to safely convert a value of type [a] to a value of type + [b]: [Type_equal.conv] or pattern matching on [Type_equal.T]: + + {[ + let f (type a) (type b) (t : (a, b) Type_equal.t) (a : a) : b = + Type_equal.conv t a + + let f (type a) (type b) (t : (a, b) Type_equal.t) (a : a) : b = + let Type_equal.T = t in a + ]} + + At runtime, conversion by either means is just the identity -- nothing is changing + about the value. Consistent with this, a value of type [Type_equal.t] is always just + a constructor [Type_equal.T]; the value has no interesting semantic content. + [Type_equal] gets its power from the ability to, in a type-safe way, prove to the type + checker that two types are equal. The [Type_equal.t] value that is passed is + necessary for the type-checker's rules to be correct, but the compiler could, in + principle, not pass around values of type [Type_equal.t] at runtime. +*) + +open! Import +open T + +type ('a, 'b) t = T : ('a, 'a) t [@@deriving_inline sexp_of] +include +sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] +type ('a, 'b) equal = ('a, 'b) t (** just an alias, needed when [t] gets shadowed below *) + +(** [refl], [sym], and [trans] construct proofs that type equality is reflexive, + symmetric, and transitive. *) + +val refl : ('a, 'a) t +val sym : ('a, 'b) t -> ('b, 'a) t +val trans : ('a, 'b) t -> ('b, 'c) t -> ('a, 'c) t + +(** [conv t x] uses the type equality [t : (a, b) t] as evidence to safely cast [x] + from type [a] to type [b]. [conv] is semantically just the identity function. 
+ + In a program that has [t : (a, b) t] where one has a value of type [a] that one wants + to treat as a value of type [b], it is often sufficient to pattern match on + [Type_equal.T] rather than use [conv]. However, there are situations where OCaml's + type checker will not use the type equality [a = b], and one must use [conv]. For + example: + + {[ + module F (M1 : sig type t end) (M2 : sig type t end) : sig + val f : (M1.t, M2.t) equal -> M1.t -> M2.t + end = struct + let f equal (m1 : M1.t) = conv equal m1 + end + ]} + + If one wrote the body of [F] using pattern matching on [T]: + + {[ + let f (T : (M1.t, M2.t) equal) (m1 : M1.t) = (m1 : M2.t) + ]} + + this would give a type error. *) +val conv : ('a, 'b) t -> 'a -> 'b + +(** It is always safe to conclude that if type [a] equals [b], then for any type ['a t], + type [a t] equals [b t]. The OCaml type checker uses this fact when it can. However, + sometimes, e.g., when using [conv], one needs to explicitly use this fact to construct + an appropriate [Type_equal.t]. The [Lift*] functors do this. *) + +module Lift (X : T1) : sig + val lift : ('a, 'b) t -> ('a X.t, 'b X.t) t +end + +module Lift2 (X : T2) : sig + val lift : ('a1, 'b1) t -> ('a2, 'b2) t -> (('a1, 'a2) X.t, ('b1, 'b2) X.t) t +end + +module Lift3 (X : T3) : sig + val lift + : ('a1, 'b1) t + -> ('a2, 'b2) t + -> ('a3, 'b3) t + -> (('a1, 'a2, 'a3) X.t, ('b1, 'b2, 'b3) X.t) t +end + +(** [tuple2] and [detuple2] convert between equality on a 2-tuple and its components. *) + +val detuple2 : ('a1 * 'a2, 'b1 * 'b2) t -> ('a1, 'b1) t * ('a2, 'b2) t +val tuple2 : ('a1, 'b1) t -> ('a2, 'b2) t -> ('a1 * 'a2, 'b1 * 'b2) t + +(** [Injective] is an interface that states that a type is injective, where the type is + viewed as a function from types to other types. The typical usage is: + + {[ + type 'a t + include Injective with type 'a t := 'a t + ]} + + For example, ['a list] is an injective type, because whenever ['a list = 'b list], we + know that ['a] = ['b]. 
On the other hand, if we define: + + {[ + type 'a t = unit + ]} + + then clearly [t] isn't injective, because, e.g., [int t = bool t], but [int <> bool]. + + If [module M : Injective], then [M.strip] provides a way to get a proof that two types + are equal from a proof that both types transformed by [M.t] are equal. + + OCaml has no built-in language feature to state that a type is injective, which is why + we have [module type Injective]. However, OCaml can infer that a type is injective, + and we can use this to match [Injective]. A typical implementation will look like + this: + + {[ + let strip (type a) (type b) + (Type_equal.T : (a t, b t) Type_equal.t) : (a, b) Type_equal.t = + Type_equal.T + ]} + + This will not type check for all type constructors (certainly not for non-injective + ones!), but it's always safe to try the above implementation if you are unsure. If + OCaml accepts this definition, then the type is injective. On the other hand, if + OCaml doesn't, then the type may or may not be injective. For example, if the + definition of the type depends on abstract types that match [Injective], OCaml will + not automatically use their injectivity, and one will have to write a more complicated + definition of [strip] that causes OCaml to use that fact. For example: + + {[ + module F (M : Type_equal.Injective) : Type_equal.Injective = struct + type 'a t = 'a M.t * int + + let strip (type a) (type b) + (e : (a t, b t) Type_equal.t) : (a, b) Type_equal.t = + let e1, _ = Type_equal.detuple2 e in + M.strip e1 + ;; + end + ]} + + If in the definition of [F] we had written the simpler implementation of [strip] that + didn't use [M.strip], then OCaml would have reported a type error. +*) +module type Injective = sig + type 'a t + val strip : ('a t, 'b t) equal -> ('a, 'b) equal +end + +(** [Injective2] is for a binary type that is injective in both type arguments. 
*) +module type Injective2 = sig + type ('a1, 'a2) t + val strip : (('a1, 'a2) t, ('b1, 'b2) t) equal -> ('a1, 'b1) equal * ('a2, 'b2) equal +end + +(** [Composition_preserves_injectivity] is a functor that proves that composition of + injective types is injective. *) +module Composition_preserves_injectivity (M1 : Injective) (M2 : Injective) + : Injective with type 'a t = 'a M1.t M2.t + +(** [Id] provides identifiers for types, and the ability to test (via [Id.same]) at + runtime if two identifiers are equal, and if so to get a proof of equality of their + types. Unlike values of type [Type_equal.t], values of type [Id.t] do have semantic + content and must have a nontrivial runtime representation. *) +module Id : sig + type 'a t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** Every [Id.t] contains a unique id that is distinct from the [Uid.t] in any other + [Id.t]. *) + module Uid : sig + type t [@@deriving_inline hash] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + include Sexpable.S with type t := t + include Comparable.S with type t := t + end + + val uid : _ t -> Uid.t + + (** [create ~name] defines a new type identity. Two calls to [create] will result in + two distinct identifiers, even for the same arguments with the same type. If the + type ['a] doesn't support sexp conversion, then a good practice is to have the + converter be [<:sexp_of< _ >>], (or [sexp_of_opaque], if not using pa_sexp). 
*) + val create + : name:string + -> ('a -> Sexp.t) + -> 'a t + + (** Accessors *) + + val hash : _ t -> int + val name : _ t -> string + val to_sexp : 'a t -> 'a -> Sexp.t + + val hash_fold_t : Hash.state -> _ t -> Hash.state + + (** [same_witness t1 t2] and [same_witness_exn t1 t2] return a type equality proof iff + the two identifiers are the same (i.e., physically equal, resulting from the same + call to [create]). This is a useful way to achieve a sort of dynamic typing. + [same_witness] does not allocate a [Some] every time it is called. + + [same t1 t2 = is_some (same_witness t1 t2)]. + *) + + val same : _ t -> _ t -> bool + val same_witness : 'a t -> 'b t -> ('a, 'b) equal option + val same_witness_exn : 'a t -> 'b t -> ('a, 'b) equal +end diff --git a/src/uchar.ml b/src/uchar.ml new file mode 100644 index 0000000..57bbe4d --- /dev/null +++ b/src/uchar.ml @@ -0,0 +1,81 @@ +open! Import + +let failwithf = Printf.failwithf + +module T = struct + include Uchar0 + + let module_name = "Base.Uchar" + + let hash_fold_t state t = Hash.fold_int state (to_int t) + let hash t = Hash.run hash_fold_t t + + let to_string t = + Printf.sprintf "U+%04X" (to_int t) + (* Do not actually export this. See discussion in the .mli *) + + let sexp_of_t t = Sexp.Atom (to_string t) + let t_of_sexp sexp = + match sexp with + | Sexp.List _ -> of_sexp_error "Uchar.t_of_sexp: atom needed" sexp + | Sexp.Atom s -> + try + Caml.Scanf.sscanf s "U+%X" (fun i -> Uchar0.of_int i) + with _ -> + of_sexp_error "Uchar.t_of_sexp: atom of the form U+XXXX needed" sexp +end + +include T +include Pretty_printer.Register(T) +include Comparable.Make(T) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! 
Uchar_replace_polymorphic_compare + +let int_is_scalar = is_valid + +let succ_exn c = + try Uchar0.succ c + with Invalid_argument msg -> failwithf "Uchar.succ_exn: %s" msg () + +let succ c = + try Some (Uchar0.succ c) + with Invalid_argument _ -> None + +let pred_exn c = + try Uchar0.pred c + with Invalid_argument msg -> failwithf "Uchar.pred_exn: %s" msg () + +let pred c = + try Some (Uchar0.pred c) + with Invalid_argument _ -> None + +let of_scalar i = + if int_is_scalar i + then Some (unsafe_of_int i) + else None + +let of_scalar_exn i = + if int_is_scalar i + then unsafe_of_int i + else failwithf "Uchar.of_scalar_exn got an invalid Unicode scalar value: %04X" i () + +let to_scalar t = Uchar0.to_int t + +let to_char c = + if is_char c + then Some (unsafe_to_char c) + else None + +let to_char_exn c = + if is_char c + then unsafe_to_char c + else failwithf "Uchar.to_char_exn got a non latin-1 character: U+%04X" (to_int c) () + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Uchar_replace_polymorphic_compare diff --git a/src/uchar.mli b/src/uchar.mli new file mode 100644 index 0000000..541b22f --- /dev/null +++ b/src/uchar.mli @@ -0,0 +1,55 @@ +(** Unicode character operations. *) + +open! Import + +type t = Uchar0.t [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Comparable.S with type t := t +include Pretty_printer.S with type t := t + + +(** [succ_exn t] is the scalar value after [t] in the set of Unicode scalar values, and + raises if [t = max_value].
*) +val succ : t -> t option +val succ_exn : t -> t + +(** [pred_exn t] is the scalar value before [t] in the set of Unicode scalar values, and + raises if [t = min_value]. *) +val pred : t -> t option +val pred_exn : t -> t + +(** [is_char t] is [true] iff [t] is in the latin-1 character set. *) +val is_char : t -> bool + +(** [to_char_exn t] is [t] as a [char] if it is in the latin-1 character set, and raises + otherwise. *) +val to_char : t -> char option +val to_char_exn : t -> char + +(** [of_char c] is [c] as a Unicode character. *) +val of_char : char -> t + +(** [int_is_scalar n] is [true] iff [n] is a Unicode scalar value (i.e., in the ranges + [0x0000]...[0xD7FF] or [0xE000]...[0x10FFFF]). *) +val int_is_scalar : int -> bool + +(** [of_scalar_exn n] is [n] as a Unicode character. Raises if [not (int_is_scalar + n)]. *) +val of_scalar : int -> t option +val of_scalar_exn : int -> t + +(** [to_scalar t] is [t] as an integer scalar value. *) +val to_scalar : t -> int + +val min_value : t +val max_value : t diff --git a/src/uchar0.ml b/src/uchar0.ml new file mode 100644 index 0000000..59baacd --- /dev/null +++ b/src/uchar0.ml @@ -0,0 +1,21 @@ +open! Import0 + +type t = Caml.Uchar.t + +let succ = Caml.Uchar.succ +let pred = Caml.Uchar.pred +let is_valid = Caml.Uchar.is_valid +let is_char = Caml.Uchar.is_char +let unsafe_to_char = Caml.Uchar.unsafe_to_char +let unsafe_of_int = Caml.Uchar.unsafe_of_int + +let of_int = Caml.Uchar.of_int +let to_int = Caml.Uchar.to_int + +let of_char = Caml.Uchar.of_char + +let compare = Caml.Uchar.compare +let equal = Caml.Uchar.equal + +let min_value = Caml.Uchar.min +let max_value = Caml.Uchar.max diff --git a/src/uniform_array.ml b/src/uniform_array.ml new file mode 100644 index 0000000..1e5deae --- /dev/null +++ b/src/uniform_array.ml @@ -0,0 +1,119 @@ +open! Import + +(* WARNING: + We use non-memory-safe things throughout the [Trusted] module. + Most of it is only safe in combination with the type signature (e.g.
exposing + [val copy : 'a t -> 'b t] would be a big mistake). *) +module Trusted : sig + + type 'a t + val empty : 'a t + val unsafe_create_uninitialized : len:int -> 'a t + val create_obj_array : len:int -> 'a t + val create : len:int -> 'a -> 'a t + val singleton : 'a -> 'a t + val get : 'a t -> int -> 'a + val set : 'a t -> int -> 'a -> unit + val swap : _ t -> int -> int -> unit + val unsafe_get : 'a t -> int -> 'a + val unsafe_set : 'a t -> int -> 'a -> unit + val unsafe_set_omit_phys_equal_check : 'a t -> int -> 'a -> unit + val unsafe_set_int : 'a t -> int -> int -> unit + val unsafe_set_int_assuming_currently_int : 'a t -> int -> int -> unit + val unsafe_set_assuming_currently_int : 'a t -> int -> 'a -> unit + val length : 'a t -> int + val unsafe_blit : ('a t, 'a t) Blit.blit + val copy : 'a t -> 'a t + val unsafe_truncate : 'a t -> len:int -> unit + val unsafe_clear_if_pointer : _ t -> int -> unit +end = struct + + type 'a t = Obj_array.t + + let empty = Obj_array.empty + + let unsafe_create_uninitialized ~len = Obj_array.create_zero ~len + let create_obj_array ~len = Obj_array.create_zero ~len + let create ~len x = Obj_array.create ~len (Caml.Obj.repr x) + let singleton x = Obj_array.singleton (Caml.Obj.repr x) + + let swap t i j = Obj_array.swap t i j + + let get arr i = Caml.Obj.obj (Obj_array.get arr i) + let set arr i x = Obj_array.set arr i (Caml.Obj.repr x) + let unsafe_get arr i = Caml.Obj.obj (Obj_array.unsafe_get arr i) + let unsafe_set arr i x = + Obj_array.unsafe_set arr i (Caml.Obj.repr x) + let unsafe_set_int arr i x = + Obj_array.unsafe_set_int arr i x + let unsafe_set_int_assuming_currently_int arr i x = + Obj_array.unsafe_set_int_assuming_currently_int arr i x + let unsafe_set_assuming_currently_int arr i x = + Obj_array.unsafe_set_assuming_currently_int arr i (Caml.Obj.repr x) + + let length = Obj_array.length + + let unsafe_blit = Obj_array.unsafe_blit + + let copy = Obj_array.copy + + let unsafe_truncate = Obj_array.truncate + + let 
unsafe_set_omit_phys_equal_check t i x = + Obj_array.unsafe_set_omit_phys_equal_check t i (Caml.Obj.repr x) + + let unsafe_clear_if_pointer = Obj_array.unsafe_clear_if_pointer +end + +include Trusted + +let init l ~f = + if l < 0 then invalid_arg "Uniform_array.init" + else + let res = unsafe_create_uninitialized ~len:l in + for i = 0 to l - 1 do + unsafe_set res i (f i) + done; + res + +let of_array arr = init ~f:(Array.unsafe_get arr) (Array.length arr) + +let map a ~f = init ~f:(fun i -> f (unsafe_get a i)) (length a) + +let iter a ~f = + for i = 0 to length a - 1 do + f (unsafe_get a i) + done + +let to_list t = List.init ~f:(get t) (length t) + +let of_list l = + let len = List.length l in + let res = unsafe_create_uninitialized ~len in + List.iteri l ~f:(fun i x -> set res i x); + res + +(* It is not safe for [to_array] to be the identity function because we have code that + relies on [float array]s being unboxed, for example in [bin_write_array]. *) +let to_array t = Array.init (length t) ~f:(fun i -> unsafe_get t i) + +include Sexpable.Of_sexpable1(Array)(struct + type nonrec 'a t = 'a t + let to_sexpable = to_array + let of_sexpable = of_array + end) + +include + Blit.Make1 + (struct + type nonrec 'a t = 'a t + let length = length + let create_like ~len t = + if len = 0 + then empty + else (assert (length t > 0); create ~len (get t 0)) + ;; + let unsafe_blit = unsafe_blit + end) + + diff --git a/src/uniform_array.mli b/src/uniform_array.mli new file mode 100644 index 0000000..5219cbf --- /dev/null +++ b/src/uniform_array.mli @@ -0,0 +1,88 @@ +(** Same semantics as ['a Array.t], except it's guaranteed that the representation array + is not tagged with [Double_array_tag], the tag for float arrays. + + This means it's safer to use in the presence of [Obj.magic], but it's slower than + normal [Array] if you use it with floats. + + It can often be faster than [Array] if you use it with non-floats. +*) + +open! Import + +(** See [Base.Array] for comments. 
*) +type 'a t [@@deriving_inline sexp] +include +sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + + +val empty : _ t + +val create : len:int -> 'a -> 'a t + +val singleton : 'a -> 'a t + +val init : int -> f:(int -> 'a) -> 'a t (** Raises [Invalid_argument] if the length is negative. *) + +val length : 'a t -> int + +val get : 'a t -> int -> 'a +val unsafe_get : 'a t -> int -> 'a + +val set : 'a t -> int -> 'a -> unit +val unsafe_set : 'a t -> int -> 'a -> unit + +val swap : _ t -> int -> int -> unit + +(** [unsafe_set_omit_phys_equal_check] is like [unsafe_set], except it doesn't do a + [phys_equal] check to try to skip [caml_modify]. It is safe to call this even if the + values are [phys_equal]. *) +val unsafe_set_omit_phys_equal_check : 'a t -> int -> 'a -> unit + +val map : 'a t -> f:('a -> 'b) -> 'b t +val iter : 'a t -> f:('a -> unit) -> unit + +(** [of_array] and [to_array] return fresh arrays with the same contents rather than + returning a reference to the underlying array. *) +val of_array : 'a array -> 'a t +val to_array : 'a t -> 'a array + +val of_list : 'a list -> 'a t +val to_list : 'a t -> 'a list + +include Blit.S1 with type 'a t := 'a t + +val copy : 'a t -> 'a t + +(** [unsafe_truncate t ~len] shortens [t]'s length to [len]. It is an error if [len <= 0] or + [len > length t]. It's unsafe to truncate in the middle of iteration. *) +val unsafe_truncate : _ t -> len:int -> unit + +(** {2 Extra lowlevel and unsafe functions} *) + +(** The behavior is undefined if you access an element before setting it. *) +val unsafe_create_uninitialized : len:int -> _ t + +(** New obj array filled with [Obj.repr 0] *) +val create_obj_array : len:int -> Caml.Obj.t t + +(** [unsafe_set_assuming_currently_int t i obj] sets index [i] of [t] to [obj], but only + works correctly if the value there is an immediate, i.e. [Caml.Obj.is_int (get t i)]. + This precondition saves a dynamic check. 
+ + [unsafe_set_int_assuming_currently_int] is similar, except the value being set is an + int. + + [unsafe_set_int] is similar but does not assume anything about the target. *) +val unsafe_set_assuming_currently_int : Caml.Obj.t t -> int -> Caml.Obj.t -> unit +val unsafe_set_int_assuming_currently_int : Caml.Obj.t t -> int -> int -> unit +val unsafe_set_int : Caml.Obj.t t -> int -> int -> unit + +(** [unsafe_clear_if_pointer t i] prevents [t.(i)] from pointing to anything to prevent + space leaks. It does this by setting [t.(i)] to [Caml.Obj.repr 0]. As a performance + hack, it only does this when [not (Caml.Obj.is_int t.(i))]. It is an error to access + the cleared index before setting it again. *) +val unsafe_clear_if_pointer : Caml.Obj.t t -> int -> unit diff --git a/src/unit.ml b/src/unit.ml new file mode 100644 index 0000000..afae207 --- /dev/null +++ b/src/unit.ml @@ -0,0 +1,29 @@ +open! Import + +module T = struct + type t = unit [@@deriving_inline enumerate, hash, sexp] + let all : t list = [()] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_unit + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_unit in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = unit_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_unit + [@@@end] + (* Both arguments are ignored and the result is always 0. *) + let compare _ _ = 0 + (* Only the two-character string [()] is accepted; anything else raises [Failure]. *) + let of_string = function + | "()" -> () + | _ -> failwith "Base.Unit.of_string: () expected" + + let to_string () = "()" + + let module_name = "Base.Unit" +end + +include T +include Identifiable.Make (T) + (* Nothing to check: the pattern [()] already matches the only value of the type. *) +let invariant () = () diff --git a/src/unit.mli b/src/unit.mli new file mode 100644 index 0000000..6a1ecc4 --- /dev/null +++ b/src/unit.mli @@ -0,0 +1,19 @@ +(** Module for the type [unit]. *) + +open! 
Import + +type t = unit [@@deriving_inline compare, enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Identifiable.S with type t := t +include Invariant.S with type t := t diff --git a/src/validate.ml b/src/validate.ml new file mode 100644 index 0000000..81fb0c6 --- /dev/null +++ b/src/validate.ml @@ -0,0 +1,184 @@ +open! Import + +module Int = Int0 +module String = String0 + +(** Each single_error is a path indicating the location within the datastructure in + question that is being validated, along with an error message. *) +type single_error = + { path : string list; + error : Error.t; + } + (* A validation result is the list of errors found; the empty list ([pass]) means success. *) +type t = single_error list + +type 'a check = 'a -> t + +let pass : t = [] + (* [fails], [fail], [failf] and [fail_s] each build a single error with an empty path. *) +let fails message a sexp_of_a = + [ { path = []; + error = Error.create message a sexp_of_a; + } ] +;; + +let fail message = [ { path = []; error = Error.of_string message } ] + +let failf format = Printf.ksprintf fail format + +let fail_s sexp = [ { path = []; error = Error.create_s sexp } ] + +let combine t1 t2 = t1 @ t2 + +let of_list = List.concat + (* Prepends [name] to the path of every error in [t]. *) +let name name t = + match t with + | [] -> [] (* when successful, avoid the allocation of a closure for [~f], below *) + | _ -> List.map t ~f:(fun { path; error } -> { path = name :: path; error }) +;; + +let name_list n l = name n (of_list l) + +let fail_fn message _ = fail message + +let pass_bool (_:bool) = pass +let pass_unit (_:unit) = pass + (* Runs [f v]; an exception raised by [f] becomes a failing result instead of escaping. *) +let protect f v = + try + f v + with exn -> + fail_s + (Sexp.message "Exception raised during validation" [ "", sexp_of_exn exn ]) +;; + +let try_with f = + protect (fun () -> f (); pass) () + (* Path components are joined with dots; [name] conses, so the outermost name comes first. *) +let path_string path = String.concat ~sep:"." 
path + (* One human-readable string per error, each tagged with its dotted path. *) +let errors t = + List.map t ~f:(fun { path; error } -> + (Error.to_string_hum (Error.tag error ~tag:(path_string path)))) +;; + (* Error case of [result]: pairs every error with its dotted path under a single message. *) +let [@inline never] result_fail t = + Or_error.error + "validation errors" + (List.map t ~f:(fun { path; error } -> (path_string path, error))) + (sexp_of_list (sexp_of_pair sexp_of_string Error.sexp_of_t)) +;; + +(** [result] is carefully implemented so that it can be inlined -- calling [result_fail], + which is not inlineable, is key to this. *) +let result t = + if List.is_empty t + then Ok () + else result_fail t +;; + +let maybe_raise t = Or_error.ok_exn (result t) + +let valid_or_error x check = + Or_error.map (result (protect check x)) ~f:(fun () -> x) +;; + (* Validates the value of one field and names any resulting errors after the field. *) +let field record fld f = + let v = Field.get fld record in + let result = protect f v in + name (Field.name fld) result +;; + (* The [();] makes this a function of two arguments that returns the folding closure, rather than one of four arguments. *) +let field_folder record check = (); fun acc fld -> field record fld check :: acc + +let field_direct_folder check = + Staged.stage (fun acc fld _record v -> + match protect check v with + | [] -> acc + | result -> name (Field.name fld) result :: acc) +;; + (* Runs every check on [v]; the accumulator is reversed at the end, so errors follow the order of [checks]. *) +let all checks v = + let rec loop checks v errs = + match checks with + | [] -> errs + | check :: checks -> + match protect check v with + | [] -> loop checks v errs + | err -> loop checks v (err :: errs) + in + of_list (List.rev (loop checks v [])) +;; + +let of_result f = + protect (fun v -> + match f v with + | Ok () -> pass + | Error error -> fail error) +;; + +let of_error f = + protect (fun v -> + match f v with + | Ok () -> pass + | Error error -> [ { path = []; error } ]) +;; + +let booltest f ~if_false = protect (fun v -> if f v then pass else fail if_false) + +let pair ~fst ~snd (fst_value,snd_value) = + of_list [ name "fst" (protect fst fst_value); + name "snd" (protect snd snd_value); + ] +;; + (* Elements are named by their 1-based position, hence the [i+1]. *) +let list_indexed check list = + List.mapi list ~f:(fun i el -> + name (Int.to_string (i+1)) (protect check el)) + |> of_list +;; + +let list ~name:extract_name check list = + List.map 
list ~f:(fun el -> + match protect check el with + | [] -> [] + | t -> + (* extra level of protection in case extract_name throws an exception *) + protect (fun t -> name (extract_name el) t) t) + |> of_list +;; + (* Validates the values of an association list; each error is named by applying [name] to its key. *) +let alist ~name f list' = + list (fun (_, x) -> f x) list' + ~name:(fun (key, _) -> name key) +;; + +let first_failure t1 t2 = if List.is_empty t1 then t2 else t1 + +let of_error_opt = function + | None -> pass + | Some error -> fail error +;; + (* The [Unbounded] arms are [assert false]: presumably [compare_to_interval_exn] never reports a violation of an absent bound -- confirm in [Maybe_bound]. *) +let bounded ~name ~lower ~upper ~compare x = + match Maybe_bound.compare_to_interval_exn ~lower ~upper ~compare x with + | In_range -> pass + | Below_lower_bound -> + begin + match lower with + | Unbounded -> assert false + | Incl incl -> fail (Printf.sprintf "value %s < bound %s" (name x) (name incl)) + | Excl excl -> fail (Printf.sprintf "value %s <= bound %s" (name x) (name excl)) + end + | Above_upper_bound -> + begin + match upper with + | Unbounded -> assert false + | Incl incl -> fail (Printf.sprintf "value %s > bound %s" (name x) (name incl)) + | Excl excl -> fail (Printf.sprintf "value %s >= bound %s" (name x) (name excl)) + end + +module Infix = struct + let (++) t1 t2 = combine t1 t2 +end diff --git a/src/validate.mli b/src/validate.mli new file mode 100644 index 0000000..aff6022 --- /dev/null +++ b/src/validate.mli @@ -0,0 +1,175 @@ +(** A module for organizing validations of data structures. + + Allows standardized ways of checking for conditions, and keeps track of the location + of errors by keeping a path to each error found. Thus, if you were validating the + following datastructure: + + {[ + { foo = 3; + bar = { snoo = 34.5; + blue = Snoot -6; } + } + ]} + + One might end up with an error with the error path: + + {v bar.blue.Snoot : value -6 <= bound 0 v} + + By convention, the validations for a type defined in module [M] appear in module [M], + and have their name prefixed by [validate_]. E.g., [Int.validate_positive]. 
+ + Here's an example of how you would use [validate] with a record: + + {[ + type t = + { foo: int; + bar: float; + } + [@@deriving_inline fields][@@@end] + + let validate t = + let module V = Validate in + let w check = V.field_folder t check in + V.of_list + (Fields.fold ~init:[] + ~foo:(w Int.validate_positive) + ~bar:(w Float.validate_non_negative) + ) + ]} + + + And here's an example of how you would use it with a variant type: + + {[ + type t = + | Foo of int + | Bar of (float * int) + | Snoo of Floogle.t + + let validate = function + | Foo i -> V.name "Foo" (Int.validate_positive i) + | Bar p -> V.name "Bar" (V.pair + ~fst:Float.validate_positive + ~snd:Int.validate_non_negative p) + | Snoo floogle -> V.name "Snoo" (Floogle.validate floogle) + ]} *) + +open! Import + +(** The result of a validation. This effectively contains the list of errors, qualified + by their location path. *) +type t +type 'a check = 'a -> t (** To make function signatures easier to read. *) + +(** A result containing no errors. *) +val pass : t + +(** A result containing a single error. *) +val fail : string -> t +(* [fails] is like [fail], with the offending value attached through the given sexp converter. *) +val fails + : string + -> 'a + -> ('a -> Sexp.t) + -> t +val fail_s : Sexp.t -> t (** This can be used with the [%sexp] extension. *) + +(** Like [sprintf] or [failwithf] but produces a [t] instead of a string or exception. *) +val failf : ('a, unit, string, t) format4 -> 'a + +val combine : t -> t -> t (** Concatenates the errors of both results, those of the first argument first. *) + +(** Combines multiple results, merging errors. *) +val of_list : t list -> t + +(** Extends location path by one name. *) +val name : string -> t -> t + +val name_list : string -> t list -> t (** [name_list n l] is [name n (of_list l)]. *) + +(** [fail_fn err] returns a function that always returns fail, with [err] as the error + message. (Note that there is no [pass_fn] so as to discourage people from ignoring + the type of the value being passed unconditionally irrespective of type.) *) +val fail_fn : string -> _ check + +(** Checks for unconditionally passing a bool. 
*) +val pass_bool : bool check + +(** Checks for unconditionally passing a unit. *) +val pass_unit : unit check + +(** [protect f x] applies the validation [f] to [x], catching any exceptions and returning + them as errors. *) +val protect : 'a check -> 'a check + +(** [try_with f] runs [f] catching any exceptions and returning them as errors. *) +val try_with : (unit -> unit) -> t + +val result : t -> unit Or_error.t (** [Ok ()] when there are no errors; otherwise one [Error] listing every error with its path. *) + +(** Returns a list of formatted error strings, which include both the error message and + the path to the error. *) +val errors : t -> string list + +(** If the result contains any errors, then raises an exception with a formatted error + message containing a message for every error. *) +val maybe_raise : t -> unit + +(** Returns an error if validation fails. *) +val valid_or_error : 'a -> 'a check -> 'a Or_error.t + +(** Used for validating an individual field. *) +val field : 'record -> ([> `Read], 'record, 'a) Field.t_with_perm -> 'a check -> t + +(** Creates a function for use in a [Fields.fold]. *) +val field_folder + : 'record + -> 'a check + -> (t list -> ([> `Read], 'record, 'a) Field.t_with_perm -> t list) + +(** Creates a function for use in a [Fields.Direct.fold]. *) +val field_direct_folder + : 'a check + -> (t list -> ([> `Read], 'record, 'a) Field.t_with_perm -> 'record -> 'a -> t list) + Staged.t + +(** Combines a list of validation functions into one that does all validations. *) +val all : 'a check list -> 'a check + +(** Creates a validation function from a function that produces a [Result.t]. *) +val of_result : ('a -> (unit, string) Result.t) -> 'a check + +val of_error : ('a -> unit Or_error.t) -> 'a check (** Like [of_result], for functions that produce an [Or_error.t]. *) + +(** Creates a validation function from a function that produces a bool. *) +val booltest : ('a -> bool) -> if_false:string -> 'a check + +(** Validation functions for particular data types. 
*) +val pair : fst:'a check -> snd:'b check -> ('a * 'b) check + +(** Validates a list, naming each element by its position in the list (where the first + position is 1, not 0). *) +val list_indexed : 'a check -> 'a list check + +(** Validates a list, naming each element using a user-defined function for computing the + name. *) +val list : name:('a -> string) -> 'a check -> 'a list check + +val first_failure : t -> t -> t (** Returns the first result if it contains errors, and the second otherwise. *) + +val of_error_opt : string option -> t (** [None] passes; [Some msg] fails with [msg]. *) + +(** Validates an association list, naming each element using a user-defined function for + computing the name. *) +val alist : name:('a -> string) -> 'b check -> ('a * 'b) list check + +val bounded + : name : ('a -> string) + -> lower : 'a Maybe_bound.t + -> upper : 'a Maybe_bound.t + -> compare : ('a -> 'a -> int) + -> 'a check (** Fails when the value lies outside the [lower]/[upper] interval; [name] renders the value and the bound in the message. *) + +module Infix : sig + val (++) : t -> t -> t (** Infix operator for [combine] above. *) +end diff --git a/src/variant.ml b/src/variant.ml new file mode 100644 index 0000000..3b7e041 --- /dev/null +++ b/src/variant.ml @@ -0,0 +1,7 @@ +type 'constructor t = { + name : string; + (* the position of the constructor in the type definition, starting from 0 *) + rank : int; + + constructor : 'constructor +} diff --git a/src/variant.mli b/src/variant.mli new file mode 100644 index 0000000..308a58e --- /dev/null +++ b/src/variant.mli @@ -0,0 +1,9 @@ +(** First-class representative of an individual variant in a variant type, used in + [[@@deriving_inline variants][@@@end]]. *) + +type 'constructor t = { + name : string; + (** The position of the constructor in the type definition, starting from 0 *) + rank : int; + constructor : 'constructor +} diff --git a/src/variantslib.ml b/src/variantslib.ml new file mode 100644 index 0000000..95994d5 --- /dev/null +++ b/src/variantslib.ml @@ -0,0 +1,3 @@ +(** This module is for use by ppx_variants_conv, and is thus not in the interface of + Base. 
*) +module Variant = Variant diff --git a/src/with_return.ml b/src/with_return.ml new file mode 100644 index 0000000..39a1531 --- /dev/null +++ b/src/with_return.ml @@ -0,0 +1,35 @@ +(* belongs in Common, but moved here to avoid circular dependencies *) + +open! Import + +type 'a return = { return : 'b. 'a -> 'b } [@@unboxed] + (* [is_alive] is cleared on every exit path, so a [return] kept past the call fails with [Failure] instead of raising [Return]. *) +let with_return (type a) f = + let module M = struct + (* Raised to indicate ~return was called. Local so that the exception is tied to a + particular call of [with_return]. *) + exception Return of a + end in + let is_alive = ref true in + let return a = + if not !is_alive + then failwith "use of [return] from a [with_return] that already returned"; + Exn.raise_without_backtrace (M.Return a); + in + try + let a = f { return } in + is_alive := false; + a + with exn -> + is_alive := false; + match exn with + | M.Return a -> a + | _ -> raise exn +;; + (* [f] finishing normally yields [None]; a call to [return a] inside [f] yields [Some a]. *) +let with_return_option f = + with_return (fun return -> + f { return = fun a -> return.return (Some a) }; None) +;; + (* Builds a [return] that applies [f] to its argument before handing it to the original one. *) +let prepend { return } ~f = { return = fun x -> return (f x) } diff --git a/src/with_return.mli b/src/with_return.mli new file mode 100644 index 0000000..6ed770a --- /dev/null +++ b/src/with_return.mli @@ -0,0 +1,54 @@ + +(** [with_return f] allows for something like the return statement in C within [f]. + + There are three ways [f] can terminate: + + + If [f] calls [r.return x], then [x] is returned by [with_return]. + + If [f] evaluates to a value [x], then [x] is returned by [with_return]. + + If [f] raises an exception, it escapes [with_return]. + + Here is a typical example: + + {[ + let find l ~f = + with_return (fun r -> + List.iter l ~f:(fun x -> if f x then r.return (Some x)); + None + ) + ]} + + It is only because of a deficiency of ML types that [with_return] doesn't have type: + + {[ val with_return : 'a. (('a -> ('b. 
'b)) -> 'a) -> 'a ]} + + but we can slightly increase the scope of ['b] without changing the meaning of the + type, and then we get: + + {[ + type 'a return = { return : 'b . 'a -> 'b } + val with_return : ('a return -> 'a) -> 'a + ]} + + But the actual reason we chose to use a record type with a polymorphic field is that + otherwise we would have to clobber the namespace of functions with [return] and that + is undesirable because [return] would get hidden as soon as we open any monad. We + considered names different than [return] but everything seemed worse than just having + [return] as a record field. We are clobbering the namespace of record fields but that + is much more acceptable. *) + +open! Import + +type -'a return = private { return : 'b. 'a -> 'b } [@@unboxed] (** [private]: clients can call [r.return] but cannot construct a [return] record themselves. *) + +val with_return : ('a return -> 'a ) -> 'a + +(** Note that [with_return_option] allocates ~5 words more than the equivalent + [with_return] call. *) +val with_return_option : ('a return -> unit) -> 'a option + +(** [prepend a ~f] returns a value [x] such that each call to [x.return] first applies [f] + before applying [a.return]. The call to [f] is "prepended" to the call to the + original [a.return]. A possible use case is to hand [x] over to another function + which returns ['b], a subtype of ['a], or to capture a common transformation [f] + applied to returned values at several call sites. *) +val prepend : 'a return -> f:('b -> 'a) -> 'b return diff --git a/src/word_size.ml b/src/word_size.ml new file mode 100644 index 0000000..3651fe6 --- /dev/null +++ b/src/word_size.ml @@ -0,0 +1,19 @@ +open! 
Import + +module Sys = Sys0 + +type t = W32 | W64 [@@deriving_inline sexp_of] +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | W32 -> Ppx_sexp_conv_lib.Sexp.Atom "W32" + | W64 -> Ppx_sexp_conv_lib.Sexp.Atom "W64" +[@@@end] +(* Width in bits denoted by each constructor. *) +let num_bits = function W32 -> 32 | W64 -> 64 +(* A plain value, so the match runs once when the module is initialised; a width other than 32 or 64 raises [Failure]. *) +let word_size = + match Sys.word_size_in_bits with + | 32 -> W32 + | 64 -> W64 + | _ -> failwith "unknown word size" +;; diff --git a/src/word_size.mli b/src/word_size.mli new file mode 100644 index 0000000..11f7a53 --- /dev/null +++ b/src/word_size.mli @@ -0,0 +1,14 @@ +(** For determining the word size that the program is using. *) + +open! Import + +type t = W32 | W64 [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +val num_bits : t -> int (** [32] for [W32] and [64] for [W64]. *) + +(** Returns the word size of this program, not necessarily of the OS. *) +val word_size : t diff --git a/test/avltree_unit_tests.ml b/test/avltree_unit_tests.ml new file mode 100644 index 0000000..443e343 --- /dev/null +++ b/test/avltree_unit_tests.ml @@ -0,0 +1,282 @@ +open! 
Import + +let%test_module _ = + (module (struct + (* This structure is sealed with [module type of Avltree] at its end, so each [Avltree] value is rebound below next to its test. *) + open Avltree + + type ('k, 'v) t = ('k, 'v) Avltree.t = private + | Empty + | Node of { mutable left : ('k, 'v) t + ; key : 'k + ; mutable value : 'v + ; mutable height : int + ; mutable right : ('k, 'v) t + } + | Leaf of { key : 'k + ; mutable value : 'v + } + + module For_quickcheck = struct + + module Key = struct + include Quickcheck.Int + let quickcheck_generator = Quickcheck.Generator.small_non_negative_int + end + module Data = struct + include Quickcheck.String + let quickcheck_generator = gen' Quickcheck.Char.gen_lowercase + end + + let compare = Key.compare + + open Quickcheck + open Generator + + module Constructor = struct + (* One step of a random script; each step can be applied both to a tree and to a reference [Map]. *) + type t = + | Add of Key.t * Data.t + | Replace of Key.t * Data.t + | Remove of Key.t + [@@deriving sexp_of] + + let add_gen = + Key.quickcheck_generator >>= fun key -> + Data.quickcheck_generator >>| fun data -> + Add (key, data) + + let replace_gen = + Key.quickcheck_generator >>= fun key -> + Data.quickcheck_generator >>| fun data -> + Replace (key, data) + + let remove_gen = + Key.quickcheck_generator >>| fun key -> + Remove key + + let quickcheck_generator = union [ add_gen ; replace_gen ; remove_gen ] + + let apply_to_tree t tree = + match t with + | Add (key, data) -> + add tree ~key ~data ~compare ~added:(ref false) ~replace:false + | Replace (key, data) -> + add tree ~key ~data ~compare ~added:(ref false) ~replace:true + | Remove key -> + remove tree key ~compare ~removed:(ref false) + (* [Add] leaves an existing binding untouched; the tree side passes [~replace:false] for the same step. *) + let apply_to_map t map = + match t with + | Add (key, data) -> + if Map.mem map key + then map + else Map.set map ~key ~data + | Replace (key, data) -> + Map.set map ~key ~data + | Remove key -> + Map.remove map key + + end + + let constructors_gen = List.quickcheck_generator Constructor.quickcheck_generator + (* Replays a script on an empty tree and an empty map in lockstep and returns both. *) + let reify constructors = + List.fold constructors + ~init:(empty, Key.Map.empty) + ~f:(fun (t, map) constructor -> + Constructor.apply_to_tree constructor t, + 
Constructor.apply_to_map constructor map) + (* Union of two maps that raises when a key is bound in both. *) + let merge map1 map2 = + Map.merge map1 map2 ~f:(fun ~key variant -> + match variant with + | `Left data | `Right data -> Some data + | `Both (data1, data2) -> + Error.raise_s ( + [%message + "duplicate data for key" (key : Key.t) (data1 : Data.t) (data2 : Data.t)])) + let rec to_map = function + | Empty -> Key.Map.empty + | Leaf { key; value = data } -> Key.Map.singleton key data + | Node { left; key; value = data; height = _; right } -> + merge (Key.Map.singleton key data) + (merge (to_map left) (to_map right)) + (* [to_map] flattens a tree into a map; a key stored twice in the tree makes [merge] raise. *) + end + + open For_quickcheck + + let empty = empty + (* [empty] must be the [Empty] constructor. *) + let%test_unit _ = + match empty with + | Empty -> () + | _ -> assert false + + let invariant = invariant + (* After any script the tree satisfies [invariant] and holds the same bindings as the reference map. *) + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + invariant t ~compare; + [%test_result: Data.t Key.Map.t] (to_map t) ~expect:map) + + let add = add + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple4 constructors_gen Key.quickcheck_generator Data.quickcheck_generator Quickcheck.Bool.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t * Data.t * bool] + ~f:(fun (constructors, key, data, replace) -> + let t, map = reify constructors in + (* test [added], other aspects of [add] are tested via [reify] in the + [invariant] test above *) + let added = ref false in + let _ = add t ~key ~data ~compare ~added ~replace in + [%test_result: bool] + !added + ~expect:(not (Map.mem map key))) + + let remove = remove + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + (* test [removed], other aspects of [remove] are tested via [reify] in the + [invariant] test above *) + let removed = ref false in + let _ = remove t key ~compare 
~removed in + [%test_result: bool] + !removed + ~expect:(Map.mem map key)) + + let find = find + (* [find] must agree with [Map.find] on the reference map. *) + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: Data.t option] + (find t key ~compare) + ~expect:(Map.find map key)) + + let mem = mem + (* [mem] must agree with [Map.mem] on the reference map. *) + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: bool] + (mem t key ~compare) + ~expect:(Map.mem map key)) + + let first = first + (* [first] is compared against [Map.min_elt]. *) + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) option] + (first t) + ~expect:(Map.min_elt map)) + + let last = last + (* [last] is compared against [Map.max_elt]. *) + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) option] + (last t) + ~expect:(Map.max_elt map)) + + let find_and_call = find_and_call + (* Exactly one callback must fire: [~if_found] with the data, or [~if_not_found] with the key searched for. *) + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: [ `Found of Data.t | `Not_found of Key.t ]] + (find_and_call t key ~compare + ~if_found: (fun data -> `Found data) + ~if_not_found: (fun key -> `Not_found key)) + ~expect:(match Map.find map key with + | None -> `Not_found key + | Some data -> `Found data)) + + let findi_and_call = findi_and_call + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: 
Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: [ `Found of (Key.t * Data.t) | `Not_found of Key.t ]] + (findi_and_call t key ~compare + ~if_found: (fun ~key ~data -> `Found (key, data)) + ~if_not_found: (fun key -> `Not_found key)) + ~expect:(match Map.find map key with + | None -> `Not_found key + | Some data -> `Found (key, data))) + + let iter = iter + (* [iter] must visit the bindings in the same order as [Map.to_alist] lists them. *) + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) list] + (let q = Queue.create () in + iter t ~f:(fun ~key ~data -> + Queue.enqueue q (key, data)); + Queue.to_list q) + ~expect:(Map.to_alist map)) + + let mapi_inplace = mapi_inplace + (* Doubles every datum in place, then reads the tree back through [fold]. *) + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) list] + (mapi_inplace t ~f:(fun ~key:_ ~data -> data ^ data); + (fold t ~init:[] ~f:(fun ~key ~data acc -> (key, data) :: acc))) + ~expect:(Map.map map ~f:(fun data -> data ^ data) + |> Map.to_alist |> List.rev)) + + let fold = fold + (* Consing inside [fold] reverses the visiting order, hence the [List.rev] on the expected side. *) + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) list] + (fold t ~init:[] ~f:(fun ~key ~data acc -> + (key, data) :: acc)) + ~expect:(Map.to_alist map |> List.rev)) + + end : module type of Avltree)) diff --git a/test/avltree_unit_tests.mli b/test/avltree_unit_tests.mli new file mode 100644 index 0000000..8b9cdbb --- /dev/null +++ b/test/avltree_unit_tests.mli @@ -0,0 +1 @@ +(* intentionally blank *) diff --git a/test/dune b/test/dune new file mode 100644 index 0000000..734364c --- /dev/null +++ b/test/dune @@ -0,0 +1,4 @@ +(library (name base_test) + (libraries base core_kernel.base_for_tests caml sexplib 
num + expect_test_helpers_kernel stdio) + (preprocess (pps ppx_jane -dont-apply=pipebang))) \ No newline at end of file diff --git a/test/hashtbl_tests.ml b/test/hashtbl_tests.ml new file mode 100644 index 0000000..6fa1f04 --- /dev/null +++ b/test/hashtbl_tests.ml @@ -0,0 +1,360 @@ +open! Base + + +module type Hashtbl_for_testing = sig + include Hashtbl.Accessors with type 'key key = 'key + include Invariant.S2 with type ('key, 'data) t := ('key, 'data) t + + (* we don't define [module Poly : Hashtbl.S_poly] because we want to require only + the minimal number of constructors necessary to implement the tests, and also avoid + conflicting with any existing names. *) + val create_poly : ?size:int -> unit -> ('key, 'data) t + + val of_alist_poly_exn : ('key * 'data) list -> ('key, 'data) t + val of_alist_poly_or_error : ('key * 'data) list -> ('key, 'data) t Or_error.t +end +(* The tests are a functor over [Hashtbl_for_testing], so they can run against any implementation of it. *) +module Make (Hashtbl : Hashtbl_for_testing) = struct + open Poly + + let test_data = [("a",1);("b",2);("c",3)] + (* Shared fixture built from [test_data]; tests that mutate first take a [Hashtbl.copy]. *) + let test_hash = begin + let h = Hashtbl.create_poly () ~size:10 in + List.iter test_data ~f:(fun (k,v) -> + Hashtbl.set h ~key:k ~data:v + ); + h + end + + (* This is a very strong notion of equality on hash tables *) + let equal t t' equal_data = + let subtable t t' = + (* [find] rather than [find_exn]: a key missing from the other table is an ordinary [None] *) + List.for_all (Hashtbl.keys t) ~f:(fun key -> + match Hashtbl.find t key, Hashtbl.find t' key with + | Some data, Some data' -> equal_data data data' + | _ -> false) + in + subtable t t' && subtable t' t + + let%test "find" = + let found = Hashtbl.find test_hash "a" in + let not_found = Hashtbl.find test_hash "A" in + Hashtbl.invariant ignore ignore test_hash; + match found,not_found with + | Some _, None -> true + | _ -> false + ;; + (* The lookup key is structurally equal to the stored key but physically distinct; [~if_found] must receive the stored one. *) + let%test "findi_and_call" = + let our_hash = Hashtbl.copy test_hash in + let test_string = "test string" in + Hashtbl.add_exn our_hash ~key:test_string ~data:10; + let test_string' = "test " ^ "string" in + assert (not (phys_equal test_string test_string')); + Hashtbl.findi_and_call our_hash 
test_string' + ~if_found:(fun ~key ~data -> phys_equal test_string key && data = 10) + ~if_not_found:(fun _ -> false) + ;; + (* [add] on an existing key must report [`Duplicate] and keep the old value; on a new key it reports [`Ok]. *) + let%test_unit "add" = + let our_hash = Hashtbl.copy test_hash in + let duplicate = Hashtbl.add our_hash ~key:"a" ~data:4 in + let no_duplicate = Hashtbl.add our_hash ~key:"d" ~data:5 in + assert (Hashtbl.find our_hash "a" = Some 1); + assert (Hashtbl.find our_hash "d" = Some 5); + Hashtbl.invariant ignore ignore our_hash; + assert (match duplicate, no_duplicate with + | `Duplicate, `Ok -> true + | _ -> false) + ;; + (* Iteration order is not assumed: both sides are sorted before being compared. *) + let%test "iter" = + let predicted = List.sort ~compare:Int.descending ( + List.map test_data ~f:(fun (_,v) -> v)) + in + let found = + let found = ref [] in + Hashtbl.iter test_hash ~f:(fun v -> found := v :: !found); + !found + |> List.sort ~compare:Int.descending + in + List.equal Int.equal predicted found + ;; + + let%test "iter_keys" = + let predicted = List.sort ~compare:String.descending ( + List.map test_data ~f:(fun (k,_) -> k)) + in + let found = + let found = ref [] in + Hashtbl.iter_keys test_hash ~f:(fun k -> found := k :: !found); + !found + |> List.sort ~compare:String.descending + in + List.equal String.equal predicted found + ;; + + let%test_module "of_alist" = + (module struct + + let%test "size" = + let predicted = List.length test_data in + let found = Hashtbl.length (Hashtbl.of_alist_poly_exn test_data) in + predicted = found + ;; + + let%test "right keys" = + let predicted = List.map test_data ~f:(fun (k,_) -> k) in + let found = Hashtbl.keys (Hashtbl.of_alist_poly_exn test_data) in + let sp = List.sort ~compare:Poly.ascending predicted in + let sf = List.sort ~compare:Poly.ascending found in + sp = sf + ;; + end) + + let%test_module "of_alist_or_error" = + (module struct + + let%test "unique" = + Result.is_ok (Hashtbl.of_alist_poly_or_error test_data) + (* Appending the list to itself repeats every key, which must be reported as an error. *) + let%test "duplicate" = + Result.is_error (Hashtbl.of_alist_poly_or_error (test_data @ test_data)) + + end) + + let%test "size and right keys" = + let 
predicted = List.map test_data ~f:(fun (k,_) -> k) in + let found = Hashtbl.keys test_hash in + let sp = List.sort ~compare:Poly.ascending predicted in + let sf = List.sort ~compare:Poly.ascending found in + sp = sf + ;; + + let%test "size and right data" = + let predicted = List.map test_data ~f:(fun (_,v) -> v) in + let found = Hashtbl.data test_hash in + let sp = List.sort ~compare:Poly.ascending predicted in + let sf = List.sort ~compare:Poly.ascending found in + sp = sf + ;; + (* [map] returns a new table; the shared [test_hash] is passed to it directly, without a copy. *) + let%test "map" = + let add1 x = x + 1 in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.map test_data ~f:(fun (k,v) -> (k,add1 v))) + in + let found_alist = + Hashtbl.map test_hash ~f:add1 + |> Hashtbl.to_alist + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + (* First with an [f] that keeps everything (result equals the input), then with one that keeps and increments even values only. *) + let%test_unit "filter_map" = + let f x = Some x in + let result = Hashtbl.filter_map test_hash ~f in + assert (equal test_hash result Int.(=)); + let is_even x = x % 2 = 0 in + let add1_to_even x = if is_even x then Some (x + 1) else None in + let predicted_data = List.filter_map test_data ~f:(fun (k,v) -> + if is_even v then Some (k, v+1) else None) + in + let found = Hashtbl.filter_map test_hash ~f:add1_to_even in + let found_alist = + List.sort ~compare:Poly.ascending (Hashtbl.to_alist found) + in + assert (List.equal Poly.equal predicted_data found_alist ) + ;; + (* The in-place tests shadow [test_hash] with a copy, so the shared fixture is not modified. *) + let%test "filter_inplace" = + let f x = x <> 2 in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.filter test_data ~f:(fun (_,v) -> f v)) + in + let test_hash = Hashtbl.copy test_hash in + Hashtbl.filter_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test "filter_keys_inplace" = + let f x = x = "c" in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.filter test_data ~f:(fun (k,_) -> f k)) + in + let test_hash = Hashtbl.copy 
test_hash in + Hashtbl.filter_keys_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test "filter_map_inplace" = + let f x = if x = 1 then None else Some (x * 2) in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.filter_map test_data ~f:(fun (k,v) -> Option.map (f v) ~f:(fun x -> (k,x)))) + in + let test_hash = Hashtbl.copy test_hash in + Hashtbl.filter_map_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test "map_inplace" = + let f x = x + 3 in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.map test_data ~f:(fun (k,v) -> (k,f v))) + in + let test_hash = Hashtbl.copy test_hash in + Hashtbl.map_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test_unit "insert-find-remove" = + let t = Hashtbl.create_poly () ~size:1 in + let inserted = ref [] in + Random.init 123; + let verify_inserted t = + let missing = + List.fold !inserted ~init:[] ~f:(fun acc (key, data) -> + match Hashtbl.find t key with + | None -> `Missing key :: acc + | Some d -> + if data = d then acc + else `Wrong_data (key, data) :: acc) + in + match missing with + | [] -> () + | _ -> + raise_s + [%message + "some inserts are missing" + (missing : [`Missing of int | `Wrong_data of int * int ] list) ] + in + let equal = Int.equal in + let rec loop i t = + if i < 2000 then begin + let k = Random.int 10_000 in + inserted := List.Assoc.add (List.Assoc.remove !inserted ~equal k) ~equal k i; + Hashtbl.set t ~key:k ~data:i; + Hashtbl.invariant ignore ignore t; + verify_inserted t; + loop (i + 1) t + end + in + loop 0 t; + List.iter !inserted ~f:(fun (x, _) -> + Hashtbl.remove t x; + 
Hashtbl.invariant ignore ignore t; + begin match Hashtbl.find t x with + | None -> () + | Some _ -> failwith (Printf.sprintf "present after removal: %d" x) + end; + inserted := List.Assoc.remove !inserted ~equal x; + verify_inserted t) + ;; + + let%test_unit "clear" = + let t = Hashtbl.create_poly () ~size:1 in + let l = List.range 0 100 in + let verify_present l = List.for_all l ~f:(Hashtbl.mem t) in + let verify_not_present l = + List.for_all l ~f:(fun i -> not (Hashtbl.mem t i)) + in + List.iter l ~f:(fun i -> Hashtbl.set t ~key:i ~data:(i * i)); + List.iter l ~f:(fun i -> Hashtbl.set t ~key:i ~data:(i * i)); + assert (Hashtbl.length t = 100); + assert (verify_present l); + Hashtbl.clear t; + Hashtbl.invariant ignore ignore t; + assert (Hashtbl.length t = 0); + assert (verify_not_present l); + let l = List.take l 42 in + List.iter l ~f:(fun i -> Hashtbl.set t ~key:i ~data:(i * i)); + assert (Hashtbl.length t = 42); + assert (verify_present l); + Hashtbl.invariant ignore ignore t + ;; + + let%test_unit "mem" = + let t = Hashtbl.create_poly () ~size:1 in + Hashtbl.invariant ignore ignore t; + assert (not (Hashtbl.mem t "Fred")); + Hashtbl.invariant ignore ignore t; + Hashtbl.set t ~key:"Fred" ~data:"Wilma"; + Hashtbl.invariant ignore ignore t; + assert (Hashtbl.mem t "Fred"); + Hashtbl.invariant ignore ignore t; + Hashtbl.remove t "Fred"; + Hashtbl.invariant ignore ignore t; + assert (not (Hashtbl.mem t "Fred")); + Hashtbl.invariant ignore ignore t + ;; + + let%test_unit "exists" = + let t = Hashtbl.create_poly () in + assert (not (Hashtbl.exists t ~f:(fun _ -> failwith "can't be called"))); + assert (not (Hashtbl.existsi t ~f:(fun ~key:_ ~data:_ -> failwith "can't be called"))); + Hashtbl.set t ~key:7 ~data:3; + assert (not (Hashtbl.exists t ~f:(Int.equal 4))); + Hashtbl.set t ~key:8 ~data:4; + assert (Hashtbl.exists t ~f:(Int.equal 4)); + Hashtbl.set t ~key:9 ~data:5; + assert (Hashtbl.existsi t ~f:(fun ~key ~data -> key + data = 14)) + + let%test_unit "for_all" 
= + let t = Hashtbl.create_poly () in + assert (Hashtbl.for_all t ~f:(fun _ -> failwith "can't be called")); + assert (Hashtbl.for_alli t ~f:(fun ~key:_ ~data:_ -> failwith "can't be called")); + Hashtbl.set t ~key:7 ~data:3; + assert (Hashtbl.for_all t ~f:(fun x -> Int.equal x 3)); + Hashtbl.set t ~key:8 ~data:4; + assert (not (Hashtbl.for_all t ~f:(fun x -> Int.equal x 3))); + Hashtbl.set t ~key:9 ~data:5; + assert (Hashtbl.for_alli t ~f:(fun ~key ~data -> key - 4 = data)) + + let%test_unit "count" = + let t = Hashtbl.create_poly () in + assert (Hashtbl.count t ~f:(fun _ -> failwith "can't be called") = 0); + assert (Hashtbl.counti t ~f:(fun ~key:_ ~data:_ -> failwith "can't be called") = 0); + Hashtbl.set t ~key:7 ~data:3; + assert (Hashtbl.count t ~f:(fun x -> Int.equal x 3) = 1); + Hashtbl.set t ~key:8 ~data:4; + assert (Hashtbl.count t ~f:(fun x -> Int.equal x 3) = 1); + Hashtbl.set t ~key:9 ~data:5; + assert (Hashtbl.counti t ~f:(fun ~key ~data -> key - 4 = data) = 3) + + let%test_unit "merge" = + let make alist = Hashtbl.of_alist_poly_exn alist in + let t1 = make [ 1, 111 ; 2, 222 ; 3, 333 ] in + let t2 = make [ 1, 123 ; 2, 222 ; 4, 444 ] in + [%test_result: (int * [`Left of int|`Right of int|`Both of int*int]) List.t] + (Hashtbl.merge t1 t2 ~f:(fun ~key:_ -> function + | `Left x -> Some (`Left x) + | `Right y -> Some (`Right y) + | `Both (x, y) -> if x=y then None else Some (`Both (x, y))) + |> Hashtbl.to_alist + |> List.sort ~compare:(fun (x,_) (y,_) -> Int.compare x y)) + ~expect:[ 1, `Both (111,123) ; 3, `Left 333 ; 4, `Right 444 ] +end diff --git a/test/hashtbl_tests.mli b/test/hashtbl_tests.mli new file mode 100644 index 0000000..9109c90 --- /dev/null +++ b/test/hashtbl_tests.mli @@ -0,0 +1,13 @@ +open! 
Base + +module type Hashtbl_for_testing = sig (* Interface a hashtable implementation must provide to be tested by [Make]. *) + include Hashtbl.Accessors with type 'key key = 'key + include Invariant.S2 with type ('key, 'data) t := ('key, 'data) t + + val create_poly : ?size:int -> unit -> ('key, 'data) t + + val of_alist_poly_exn : ('key * 'data) list -> ('key, 'data) t + val of_alist_poly_or_error : ('key * 'data) list -> ('key, 'data) t Or_error.t +end + +module Make (Hashtbl : Hashtbl_for_testing) : sig end (* Empty result signature: the functor exports nothing. *) diff --git a/test/import.ml b/test/import.ml new file mode 100644 index 0000000..3858c5e --- /dev/null +++ b/test/import.ml @@ -0,0 +1,56 @@ +include Base +include Stdio +include Base_for_tests +include Expect_test_helpers_kernel + +module Quickcheck = struct (* [Core_kernel.Quickcheck] plus the [Core_kernel] versions of a few basic modules, bound before [Core_kernel] is shadowed below. *) + include Core_kernel.Quickcheck + + module Bool = Core_kernel.Bool + module Char = Core_kernel.Char + module Int = Core_kernel.Int + module Int32 = Core_kernel.Int32 + module Int64 = Core_kernel.Int64 + module List = Core_kernel.List + module Nativeint = Core_kernel.Nativeint + module String = Core_kernel.String +end + +module Core_kernel = struct (* Deliberately empty: shadows [Core_kernel] with a module carrying a deprecation attribute. *) +end [@@deprecated "[since 1970-01] Don't use Core_kernel in Base tests. 
Use Base."] + +let () = Base.Not_exposed_properly.Int_conversions.sexp_of_int_style := `Underscores (* Sets the int s-expression style reference to [`Underscores] when this module is initialised. *) + +let is_none = Option.is_none +let is_some = Option.is_some +let ok_exn = Or_error.ok_exn +let stage = Staged.stage +let unstage = Staged.unstage + +module type Hash = sig + type t [@@deriving hash, sexp_of] +end + +let check_hash_coherence (type t) here (module T : Hash with type t = t) ts = (* For each value of [ts], requires [T.hash] to equal the derived [%hash: T.t]; on mismatch prints the value and both hashes. *) + List.iter ts ~f:(fun t -> + let hash1 = T.hash t in + let hash2 = [%hash: T.t] t in + require here (hash1 = hash2) ~cr:CR_soon + ~if_false_then_print_s: + (lazy [%message "" ~value:(t : T.t) (hash1 : int) (hash2 : int)])); +;; + +module type Int_hash = sig + include Hash + val of_int_exn : int -> t + val min_value : t + val max_value : t +end + +let check_int_hash_coherence (type t) here (module I : Int_hash with type t = t) = (* Runs [check_hash_coherence] on [I.min_value], 0, 37 and [I.max_value]. *) + check_hash_coherence here (module I) + [ I.min_value + ; I.of_int_exn 0 + ; I.of_int_exn 37 + ; I.max_value ]; +;; diff --git a/test/int_math_tests.ml b/test/int_math_tests.ml new file mode 100644 index 0000000..6d84011 --- /dev/null +++ b/test/int_math_tests.ml @@ -0,0 +1,39 @@ +let%test_module "overflow_bounds" = + (module struct + module Pow_overflow_bounds = Base.Not_exposed_properly.Pow_overflow_bounds + let%test _ = Pow_overflow_bounds.overflow_bound_max_int_value = Caml.max_int + let%test _ = Pow_overflow_bounds.overflow_bound_max_int64_value = Int64.max_int + + module Big_int = struct (* [Big_int] extended with operator names for comparison, power and addition, used through local opens below. *) + include Big_int + let (>) = gt_big_int + let (=) = eq_big_int + let (^) = power_big_int_positive_int + let (+) = add_big_int + let one = unit_big_int + let to_string = string_of_big_int + end + + let test_overflow_table tbl conv max_val = (* Checks a 64-entry table: entry 0 must equal [max_val]; for i > 0, entry i raised to the power i must not exceed [max_val] while entry i plus one raised to the power i must exceed it. *) + assert (Array.length tbl = 64); + let max_val = conv max_val in + StdLabels.Array.iteri tbl ~f:(fun i max_base -> + let max_base = conv max_base in + let overflows b = Big_int.((b ^ i) > max_val) in + let is_ok = + if i = 0 then Big_int.(max_base = max_val) + else + not (overflows max_base) && overflows Big_int.(max_base + 
one) + in + if not is_ok then + Base.Printf.failwithf + "overflow table check failed for %s (index %d)" + (Big_int.to_string max_base) i ()) + ;; + + let%test_unit _ = test_overflow_table Pow_overflow_bounds.int_positive_overflow_bounds + Big_int.big_int_of_int Caml.max_int + + let%test_unit _ = test_overflow_table Pow_overflow_bounds.int64_positive_overflow_bounds + Big_int.big_int_of_int64 Int64.max_int + end) diff --git a/test/interfaces_tests.ml b/test/interfaces_tests.ml new file mode 100644 index 0000000..0b07c4d --- /dev/null +++ b/test/interfaces_tests.ml @@ -0,0 +1,48 @@ +open Base + +let () = (* Type-level check: a module that includes [Set] is ascribed the [Accessors2] and [Creators_generic] signatures; the binding itself evaluates to unit. *) + let module M : sig + open Set + + type ('a, 'b) t + + include Accessors2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Set.Using_comparator.Tree.t + with type ('a, 'b) named := ('a, 'b) Set.Named.t + + include Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) With_first_class_module.t + with type ('a, 'b) set := ('a, 'b) t + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Set.Using_comparator.Tree.t + end = struct + type 'a elt = 'a + type _ cmp + include Set + let of_tree _ = assert false + let to_tree _ = assert false + end in () + +let () = (* Same kind of check for [Map], against [Accessors3] and [Creators_generic]. *) + let module M : sig + open Map + + type ('a, 'b, 'c) t + + include Accessors3 + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Map.Using_comparator.Tree.t + + include Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) With_first_class_module.t + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Map.Using_comparator.Tree.t + end = struct + type 'a key = 'a + include Map + let of_tree _ = assert false + let to_tree _ = assert false + end + in + () diff --git a/test/test_am_testing.ml b/test/test_am_testing.ml new file mode 100644 index 0000000..70fb5d1 --- /dev/null +++ b/test/test_am_testing.ml @@ -0,0 +1,8 @@ +open! Base +open! 
Import + +let%expect_test _ = + print_s [%sexp (Exported_for_specific_uses.am_testing : bool)]; + [%expect {| + true |}]; +;; diff --git a/test/test_am_testing.mli b/test/test_am_testing.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_am_testing.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_am_testing.mlt b/test/test_am_testing.mlt new file mode 100644 index 0000000..3eaebd2 --- /dev/null +++ b/test/test_am_testing.mlt @@ -0,0 +1,7 @@ +open! Base +open! Expect_test_helpers_base + +let () = print_s [%sexp (Exported_for_specific_uses.am_testing : bool)]; +[%%expect {| +true +|}];; diff --git a/test/test_applicative.ml b/test/test_applicative.ml new file mode 100644 index 0000000..06aaa6d --- /dev/null +++ b/test/test_applicative.ml @@ -0,0 +1,383 @@ +open! Import + +let%test_module "Make" = (* Exercises an applicative built by [Applicative.Make] from [return] and [apply] over [Or_error.t]. *) + (module struct + module A = + Applicative.Make (struct + type 'a t = 'a Or_error.t + let return = Or_error.return + let apply = Or_error.apply + let map = `Define_using_apply + end) + + let error = Or_error.error_string + + module Tests : module type of A = struct (* Constrained to the signature of [A], so each value of [A] has to be rebound in here. *) + let return = A.return + let%expect_test _ = + print_s [%sexp (return "okay" : string Or_error.t)]; + [%expect {| (Ok okay) |}]; + ;; + + let apply = A.apply + let%expect_test _ = + let test x y = print_s [%sexp (apply x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect {| (Error ("no fun" "no arg")) |}]; + ;; + + let ( <*> ) = A.( <*> ) + let%expect_test _ = + let test x y = print_s [%sexp (( <*> ) x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok 
String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect {| (Error ("no fun" "no arg")) |}]; + ;; + + let ( *> ) = A.( *> ) + let%expect_test _ = + let test x y = print_s [%sexp (( *> ) x y : string Or_error.t)] in + test (Ok ()) (Ok "kay"); + [%expect {| (Ok kay) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok ()) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let ( <* ) = A.( <* ) + let%expect_test _ = + let test x y = print_s [%sexp (( <* ) x y : string Or_error.t)] in + test (Ok "okay") (Ok ()); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok ()); + [%expect {| (Error "not okay") |}]; + test (Ok "okay") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let both = A.both + let%expect_test _ = + let test x y = print_s [%sexp (both x y : (string * string) Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok (o kay)) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map = A.map + let%expect_test _ = + let test x = print_s [%sexp (map x ~f:String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let ( >>| ) = A.( >>| ) + let%expect_test _ = + let test x = print_s [%sexp (x >>| String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let map2 = A.map2 + let%expect_test _ = + let test x y = print_s [%sexp 
(map2 x y ~f:(^) : string Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map3 = A.map3 + let%expect_test _ = + let test x y z = + print_s [%sexp (map3 x y z ~f:(fun a b c -> a ^ b ^ c) : string Or_error.t)] + in + test (Ok "o") (Ok "k") (Ok "ay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok "k") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (Ok "k") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no 1st") (error "no 2nd") (error "no 3rd"); + [%expect {| (Error ("no 1st" "no 2nd" "no 3rd")) |}]; + ;; + + let all = A.all + let%expect_test _ = + let test list = print_s [%sexp (all list : string list Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok "okay"]; + [%expect {| (Ok (okay)) |}]; + test [Ok "o"; Ok "kay"]; + [%expect {| (Ok (o kay)) |}]; + test [Ok "o"; Ok "k"; Ok "ay"]; + [%expect {| (Ok (o k ay)) |}]; + test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "okay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "okay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "o"; Ok "kay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; error "oh no!"; Ok "aay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; Ok "kay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + let all_unit = A.all_unit + let all_ignore = all_unit (* Alias of the local [all_unit]; the test below calls [all_unit]. *) + let%expect_test _ = + let test list = print_s [%sexp (all_unit list : unit Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok ()]; + [%expect 
{| (Ok ()) |}]; + test [Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + test [Ok (); Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok (); Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + module Applicative_infix = A.Applicative_infix + end + end) + +let%test_module "Make_using_map2" = (* Same checks for an applicative built by [Applicative.Make_using_map2] from [return] and [map2]. *) + (module struct + module A = + Applicative.Make_using_map2 (struct + type 'a t = 'a Or_error.t + let return = Or_error.return + let map2 = Or_error.map2 + let map = `Define_using_map2 + end) + + let error = Or_error.error_string + + module Tests : module type of A = struct + let return = A.return + let%expect_test _ = + print_s [%sexp (return "okay" : string Or_error.t)]; + [%expect {| (Ok okay) |}]; + ;; + + let apply = A.apply + let%expect_test _ = + let test x y = print_s [%sexp (apply x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect {| (Error ("no fun" "no arg")) |}]; + ;; + + let ( <*> ) = A.( <*> ) + let%expect_test _ = + let test x y = print_s [%sexp (( <*> ) x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect 
{| (Error ("no fun" "no arg")) |}]; + ;; + + let ( *> ) = A.( *> ) + let%expect_test _ = + let test x y = print_s [%sexp (( *> ) x y : string Or_error.t)] in + test (Ok ()) (Ok "kay"); + [%expect {| (Ok kay) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok ()) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let ( <* ) = A.( <* ) + let%expect_test _ = + let test x y = print_s [%sexp (( <* ) x y : string Or_error.t)] in + test (Ok "okay") (Ok ()); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok ()); + [%expect {| (Error "not okay") |}]; + test (Ok "okay") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let both = A.both + let%expect_test _ = + let test x y = print_s [%sexp (both x y : (string * string) Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok (o kay)) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map = A.map + let%expect_test _ = + let test x = print_s [%sexp (map x ~f:String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let ( >>| ) = A.( >>| ) + let%expect_test _ = + let test x = print_s [%sexp (x >>| String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let map2 = A.map2 + let%expect_test _ = + let test x y = print_s [%sexp (map2 x y ~f:(^) : string Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok 
"kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map3 = A.map3 + let%expect_test _ = + let test x y z = + print_s [%sexp (map3 x y z ~f:(fun a b c -> a ^ b ^ c) : string Or_error.t)] + in + test (Ok "o") (Ok "k") (Ok "ay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok "k") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (Ok "k") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no 1st") (error "no 2nd") (error "no 3rd"); + [%expect {| (Error ("no 1st" "no 2nd" "no 3rd")) |}]; + ;; + + let all = A.all + let%expect_test _ = + let test list = print_s [%sexp (all list : string list Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok "okay"]; + [%expect {| (Ok (okay)) |}]; + test [Ok "o"; Ok "kay"]; + [%expect {| (Ok (o kay)) |}]; + test [Ok "o"; Ok "k"; Ok "ay"]; + [%expect {| (Ok (o k ay)) |}]; + test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "okay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "okay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "o"; Ok "kay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; error "oh no!"; Ok "aay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; Ok "kay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + let all_unit = A.all_unit + let all_ignore = all_unit + let%expect_test _ = + let test list = print_s [%sexp (all_unit list : unit Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok ()]; + [%expect {| (Ok ()) |}]; + test [Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + test [Ok (); Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + 
test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok (); Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + module Applicative_infix = A.Applicative_infix + end + end) diff --git a/test/test_applicative.mli b/test/test_applicative.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_applicative.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_array.ml b/test/test_array.ml new file mode 100644 index 0000000..a6937b4 --- /dev/null +++ b/test/test_array.ml @@ -0,0 +1,288 @@ +open! Import +open! Array + +let%test_module "Binary_searchable" = + (module Test_binary_searchable.Test1 + (struct + include Array + module For_test = struct + let of_array = Fn.id + end + end)) + +let%test_module "Blit" = + (module Test_blit.Test1 + (struct + type 'a z = 'a + include Array + let create_bool ~len = create ~len false + end) + (Array)) + +let%test_module "Sort" = + (module struct + open Private.Sort + + let%test_module "Intro_sort.five_element_sort" = + (module struct + (* run [five_element_sort] on all permutations of an array of five elements *) + + let rec sprinkle x xs = (* Every list obtained by inserting [x] at one position of [xs]. *) + (x :: xs) :: begin + match xs with + | [] -> [] + | x' :: xs' -> + List.map (sprinkle x xs') ~f:(fun sprinkled -> x' :: sprinkled) + end + + let rec permutations = function (* Every ordering of the elements of the argument list, built with [sprinkle]. *) + | [] -> [[]] + | x :: xs -> + List.concat_map (permutations xs) ~f:(fun perms -> sprinkle x perms) + + let all_perms = permutations [1;2;3;4;5] + let%test _ = List.length all_perms = 120 + let%test _ = not (List.contains_dup ~compare:[%compare: int list] all_perms) + + let%test _ = 
+ List.for_all all_perms ~f:(fun l -> + let arr = Array.of_list l in + Intro_sort.five_element_sort arr ~compare:[%compare: int] 0 1 2 3 4; + [%compare.equal: int t] arr [|1;2;3;4;5|]) + end) + + module Test (M : Private.Sort.Sort) = struct + let random_data ~length ~range = (* Array of [length] ints, each drawn with [Random.int range]. *) + let arr = Array.create ~len:length 0 in + for i = 0 to length - 1 do + arr.(i) <- Random.int range; + done; + arr + ;; + + let assert_sorted arr = (* Runs [M.sort] over the whole of [arr], then returns whether [arr] is non-decreasing; it returns a bool rather than asserting. *) + M.sort arr ~left:0 ~right:(Array.length arr - 1) ~compare:[%compare: int]; + let len = Array.length arr in + let rec loop i prev = + if i = len then true + else if arr.(i) < prev then false + else loop (i + 1) arr.(i) + in + loop 0 (-1) + ;; + + let%test _ = assert_sorted (random_data ~length:0 ~range:100) + let%test _ = assert_sorted (random_data ~length:1 ~range:100) + let%test _ = assert_sorted (random_data ~length:100 ~range:1_000) + let%test _ = assert_sorted (random_data ~length:1_000 ~range:1) + let%test _ = assert_sorted (random_data ~length:1_000 ~range:10) + let%test _ = assert_sorted (random_data ~length:1_000 ~range:1_000_000) + end + + let%test_module _ = (module Test (Insertion_sort)) + let%test_module _ = (module Test (Heap_sort)) + let%test_module _ = (module Test (Intro_sort)) + + let%expect_test "Array.sort [||] only allocates when computing bounds" = + require_allocation_does_not_exceed (Minor_words 3) [%here] + (fun () -> Array.sort ~compare:Int.compare [||]); + [%expect {||}] + ;; + + let%expect_test "Array.sort [| 5; 2; 3; 4; 1 |] only allocates when computing bounds" = + let arr = [| 5; 2; 3; 4; 1 |] in + require_allocation_does_not_exceed (Minor_words 3) [%here] + (fun () -> Array.sort ~compare:Int.compare arr); + [%expect {||}] + ;; + end) + +let%test _ = is_sorted [||] ~compare:[%compare: int] +let%test _ = is_sorted [|0|] ~compare:[%compare: int] +let%test _ = is_sorted [|0;1;2;2;4|] ~compare:[%compare: int] +let%test _ = not (is_sorted [|0;1;2;3;2|] ~compare:[%compare: int]) + +let%test_unit _ = + List.iter + ~f:(fun 
(t, expect) -> + assert (Bool.equal expect (is_sorted_strictly (of_list t) ~compare:[%compare: int]))) + [ [] , true; + [ 1 ] , true; + [ 1; 2 ] , true; + [ 1; 1 ] , false; + [ 2; 1 ] , false; + [ 1; 2; 3 ], true; + [ 1; 1; 3 ], false; + [ 1; 2; 2 ], false; + ] +;; + +let%test _ = foldi [||] ~init:13 ~f:(fun _ _ _ -> failwith "bad") = 13 +let%test _ = foldi [| 13 |] ~init:17 ~f:(fun i ac x -> ac + i + x) = 30 +let%test _ = foldi [| 13; 17 |] ~init:19 ~f:(fun i ac x -> ac + i + x) = 50 + +let%test _ = counti [|0;1;2;3;4|] ~f:(fun idx x -> idx = x) = 5 +let%test _ = counti [|0;1;2;3;4|] ~f:(fun idx x -> idx = 4-x) = 1 + +let%test_unit _ = + for i = 0 to 5 do + let l1 = List.init i ~f:Fn.id in + let l2 = List.rev (to_list (of_list_rev l1)) in + assert ([%compare.equal: int list] l1 l2); + done +;; + +let%test_unit _ = + List.iter + ~f:(fun (t, len) -> + assert (Exn.does_raise (fun () -> unsafe_truncate t ~len))) (* Every pair listed below must make [unsafe_truncate] raise. *) + [ [| |] , -1 + ; [| |] , 0 + ; [| |] , 1 + ; [| 1 |], -1 + ; [| 1 |], 0 + ; [| 1 |], 2 + ] +;; + +let%test_unit _ = + for orig_len = 1 to 5 do + for new_len = 1 to orig_len do + let t = init orig_len ~f:Fn.id in + unsafe_truncate t ~len:new_len; + assert (length t = new_len); + for i = 0 to new_len - 1 do + assert (t.(i) = i); + done; + done; + done +;; + +let%test_unit _ = [%test_result: int array] (filter_opt [|Some 1; None; Some 2; None; Some 3|]) ~expect:[|1; 2; 3|] +let%test_unit _ = [%test_result: int array] (filter_opt [|Some 1; None; Some 2|]) ~expect:[|1; 2|] +let%test_unit _ = [%test_result: int array] (filter_opt [|Some 1|]) ~expect:[|1|] +let%test_unit _ = [%test_result: int array] (filter_opt [|None|]) ~expect:[||] +let%test_unit _ = [%test_result: int array] (filter_opt [||]) ~expect:[||] + +let%test_unit _ = + [%test_result: int] + (fold2_exn [||] [||] ~init:13 ~f:(fun _ -> failwith "fail")) + ~expect:13 +let%test_unit _ = + [%test_result: (int * string) list] + (fold2_exn [| 1 |] [| "1" |] ~init:[] ~f:(fun ac a b -> (a, b) :: ac)) + ~expect:[ 
1, "1" ] + +let%test_unit _ = [%test_result: int array] (filter [| 0; 1 |] ~f:(fun n -> n < 2)) ~expect:[| 0; 1 |] +let%test_unit _ = [%test_result: int array] (filter [| 0; 1 |] ~f:(fun n -> n < 1)) ~expect:[| 0 |] +let%test_unit _ = [%test_result: int array] (filter [| 0; 1 |] ~f:(fun n -> n < 0)) ~expect:[||] + +let%test_unit _ = [%test_result: bool] (exists [||] ~f:(fun _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists [|0;1;2;3|] ~f:(fun x -> 4 = x)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists [|0;1;2;3|] ~f:(fun x -> 2 = x)) ~expect:true + +let%test_unit _ = [%test_result: bool] (existsi [||] ~f:(fun _ _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [|0;1;2;3|] ~f:(fun i x -> i <> x)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [|0;1;3;3|] ~f:(fun i x -> i <> x)) ~expect:true + +let%test_unit _ = [%test_result: bool] (for_all [||] ~f:(fun _ -> false)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all [|1;2;3|] ~f:Int.is_positive ) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all [|0;1;3;3|] ~f:Int.is_positive ) ~expect:false + +let%test_unit _ = [%test_result: bool] (for_alli [||] ~f:(fun _ _ -> false)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [|0;1;2;3|] ~f:(fun i x -> i = x)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [|0;1;3;3|] ~f:(fun i x -> i = x)) ~expect:false + +let%test_unit _ = [%test_result: bool] (exists2_exn [||] [||] ~f:(fun _ _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists2_exn [|0;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x <> y)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists2_exn [|0;2;4;8|] [|0;2;4;6|] ~f:(fun x y -> x <> y)) ~expect:true +let%test_unit _ = [%test_result: bool] (exists2_exn [|2;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x <> y)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all2_exn [||] [||] ~f:(fun _ _ -> false)) ~expect:true 
+let%test_unit _ = [%test_result: bool] (for_all2_exn [|0;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x = y)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all2_exn [|0;2;4;8|] [|0;2;4;6|] ~f:(fun x y -> x = y)) ~expect:false +let%test_unit _ = [%test_result: bool] (for_all2_exn [|2;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x = y)) ~expect:false + +let%test_unit _ = [%test_result: bool] (equal (=) [||] [||]) ~expect:true +let%test_unit _ = [%test_result: bool] (equal (=) [| 1 |] [| 1 |]) ~expect:true +let%test_unit _ = [%test_result: bool] (equal (=) [| 1; 2 |] [| 1; 2 |]) ~expect:true +let%test_unit _ = [%test_result: bool] (equal (=) [||] [| 1 |]) ~expect:false +let%test_unit _ = [%test_result: bool] (equal (=) [| 1 |] [||]) ~expect:false +let%test_unit _ = [%test_result: bool] (equal (=) [| 1 |] [| 1; 2 |]) ~expect:false +let%test_unit _ = [%test_result: bool] (equal (=) [| 1; 2 |] [| 1; 3 |]) ~expect:false + +let%test_unit _ = + [%test_result: (int * int) option] + (findi [|1;2;3;4|] ~f:(fun i x -> i = 2*x)) + ~expect:None +let%test_unit _ = + [%test_result: (int * int) option] + (findi [|1;2;1;4|] ~f:(fun i x -> i = 2*x)) + ~expect:(Some (2, 1)) + +let%test_unit _ = [%test_result: int option] (find_mapi [|0;5;2;1;4|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (find_mapi [|3;5;2;1;4|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 4) +let%test_unit _ = [%test_result: int option] (find_mapi [|3;5;1;1;4|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 8) +let%test_unit _ = [%test_result: int option] (find_mapi [|3;5;1;1;2|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:None + +let%test_unit _ = + List.iter + ~f:(fun (l, expect) -> + let t = of_list l in + assert (Poly.equal expect (find_consecutive_duplicate t ~equal:Poly.equal))) + [ [] , None + ; [ 1 ] , None + ; [ 1; 1 ] , Some (1, 1) + ; [ 1; 2 ] , None + ; [ 1; 2; 1 ] , None + ; [ 1; 2; 2 ] 
, Some (2, 2) + ; [ 1; 1; 2; 2 ], Some (1, 1) + ] +;; + +let%test_unit _ = [%test_result: int option] (random_element [| |]) ~expect:None +let%test_unit _ = [%test_result: int option] (random_element [| 0 |]) ~expect:(Some 0) + +let%test_unit _ = + List.iter + [ [||] + ; [| 1 |] + ; [| 1; 2; 3; 4; 5 |] + ] + ~f:(fun t -> + [%test_result: int array] + (Sequence.to_array (to_sequence t)) + ~expect:t) +;; + +let test_fold_map array ~init ~f ~expect = (* Checks [folding_map] against the array component of [expect] and [fold_map] against the whole pair. *) + [%test_result: int array] (folding_map array ~init ~f) ~expect:(snd expect); + [%test_result: int * int array] (fold_map array ~init ~f) ~expect + +let test_fold_mapi array ~init ~f ~expect = (* Indexed counterpart: checks [folding_mapi] and [fold_mapi] the same way. *) + [%test_result: int array] (folding_mapi array ~init ~f) ~expect:(snd expect); + [%test_result: int * int array] (fold_mapi array ~init ~f) ~expect + +let%test_unit _ = test_fold_map [|1;2;3;4|] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(10, [|1;3;6;10|]) +let%test_unit _ = test_fold_map [||] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(0, [||]) +let%test_unit _ = test_fold_mapi [|1;2;3;4|] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(20, [|0;2;8;20|]) +let%test_unit _ = test_fold_mapi [||] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(0, [||]) + +let%test "equal does not allocate" = + let arr1 = [|1;2;3;4|] in + let arr2 = [|1;2;4;3|] in + require_no_allocation [%here] (fun () -> + not (equal Int.equal arr1 arr2)) + +let%test "foldi does not allocate" = + let arr = [|1;2;3;4|] in + let f = fun i x y -> i + x + y in + require_no_allocation [%here] (fun () -> + 16 = (foldi ~init:0 ~f arr)) diff --git a/test/test_array.mli b/test/test_array.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_array.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_backtrace.ml b/test/test_backtrace.ml new file mode 100644 index 0000000..9f7860d --- /dev/null +++ b/test/test_backtrace.ml @@ -0,0 +1,14 @@ +open! 
Import +open! Backtrace + +let%test_unit _ [@tags "no-js"] = + let t = get () in + assert (String.length (to_string t) > 0) +;; + +let%expect_test _ = + Stdio.Out_channel.(output_string stdout) + (Sexp.to_string (sexp_of_t (Exn.with_recording false ~f:Exn.most_recent))); + [%expect {| + ("") |}]; +;; diff --git a/test/test_backtrace.mli b/test/test_backtrace.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_backtrace.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_base.ml b/test/test_base.ml new file mode 100644 index 0000000..eecf934 --- /dev/null +++ b/test/test_base.ml @@ -0,0 +1,17 @@ +open! Import + +let%expect_test _ = + let f x = x * 2 in + let g x = x + 3 in + print_s [%sexp (f @@ 5 : int)]; + [%expect {| 10 |}]; + print_s [%sexp (g @@ f @@ 5 : int)]; + [%expect {| 13 |}]; + print_s [%sexp (f @@ g @@ 5 : int)]; + [%expect {| 16 |}]; +;; + +let%expect_test "exp is present at the toplevel" = + print_s [%sexp (2 ** 8 : int)]; + [%expect {| 256 |}] +;; diff --git a/test/test_base.mli b/test/test_base.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_base.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_blit.ml b/test/test_blit.ml new file mode 100644 index 0000000..b62c7fb --- /dev/null +++ b/test/test_blit.ml @@ -0,0 +1,81 @@ +open! Import +open! Blit + +(* This unit test checks that when [blit] calls [unsafe_blit], the slices are valid. + It also checks that [blit] doesn't call [unsafe_blit] when there is a range error. 
*) +let%test_module _ = + (module struct + + let blit_was_called = ref false + + let slices_are_valid = ref (Ok ()) + + module B = + Make + (struct + type t = bool array + let create ~len = Array.create false ~len + let length = Array.length + let unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len = + blit_was_called := true; + slices_are_valid := + Or_error.try_with (fun () -> + assert (len >= 0); + assert (src_pos >= 0); + assert (src_pos + len <= Array.length src); + assert (dst_pos >= 0); + assert (dst_pos + len <= Array.length dst)); + Array.blit ~src ~src_pos ~dst ~dst_pos ~len; + ;; + end) + ;; + + let%test_module "Bool" = + (module Test_blit.Test + (struct + type t = bool + let equal = Bool.equal + let of_bool = Fn.id + end) + (struct + type t = bool array [@@deriving sexp_of] + let create ~len = Array.create false ~len + let length = Array.length + let get = Array.get + let set = Array.set + end) + (B)) + ;; + + let%test_unit _ = + let opts = [ None; Some (-1); Some 0; Some 1; Some 2 ] in + List.iter [ 0; 1; 2 ] ~f:(fun src -> + List.iter [ 0; 1; 2 ] ~f:(fun dst -> + List.iter opts ~f:(fun src_pos -> + List.iter opts ~f:(fun src_len -> + List.iter opts ~f:(fun dst_pos -> + try begin + let check f = + blit_was_called := false; + slices_are_valid := Ok (); + match Or_error.try_with f with + | Error _ -> assert (not !blit_was_called); + | Ok () -> ok_exn !slices_are_valid + in + check (fun () -> + B.blito + ~src:(Array.create ~len:src false) ?src_pos ?src_len + ~dst:(Array.create ~len:dst false) ?dst_pos + ()); + check (fun () -> + ignore (B.subo (Array.create ~len:src false) ?pos:src_pos ?len:src_len + : bool array)); + end + with exn -> + raise_s [%message + "failure" + (exn : exn) + (src : int) (src_pos : int option) (src_len : int option) + (dst : int) (dst_pos : int option)]))))) + ;; + end) diff --git a/test/test_blit.mli b/test/test_blit.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_blit.mli @@ -0,0 +1 @@ +(*_ This signature 
is deliberately empty. *) diff --git a/test/test_bool.ml b/test/test_bool.ml new file mode 100644 index 0000000..3c67b70 --- /dev/null +++ b/test/test_bool.ml @@ -0,0 +1,32 @@ +open! Import + +let%expect_test "hash coherence" = + check_hash_coherence [%here] (module Bool) [ false; true ]; + [%expect {| |}] +;; + +let%expect_test "Bool.Non_short_circuiting.(||)" = + let (||) = Bool.Non_short_circuiting.(||) in + assert (true || true); + assert (true || false); + assert (false || true); + assert (not (false || false)); + + assert (true || (print_endline "rhs"; true)); + [%expect {|rhs|}]; + assert (false || (print_endline "rhs"; true)); + [%expect {|rhs|}]; +;; + +let%expect_test "Bool.Non_short_circuiting.(&&)" = + let (&&) = Bool.Non_short_circuiting.(&&) in + assert (true && true); + assert (not (true && false)); + assert (not (false && true)); + assert (not (false && false)); + + assert (true && (print_endline "rhs"; true)); + [%expect {|rhs|}]; + assert (not (false && (print_endline "rhs"; true))); + [%expect {|rhs|}]; +;; diff --git a/test/test_bool.mli b/test/test_bool.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_bool.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_bytes.ml b/test/test_bytes.ml new file mode 100644 index 0000000..a2fd53b --- /dev/null +++ b/test/test_bytes.ml @@ -0,0 +1,14 @@ +open! Import +open! Bytes + +let%test_module "Blit" = + (module Test_blit.Test + (struct + include Char + let of_bool b = if b then 'a' else 'b' + end) + (struct + include Bytes + let create ~len = create len + end) + (Bytes)) diff --git a/test/test_char.ml b/test/test_char.ml new file mode 100644 index 0000000..a80b90c --- /dev/null +++ b/test/test_char.ml @@ -0,0 +1,534 @@ +open! Import +open! 
Char + +let%test _ = not (is_whitespace '\008') (* backspace *) +let%test _ = is_whitespace '\009' (* '\t': horizontal tab *) +let%test _ = is_whitespace '\010' (* '\n': line feed *) +let%test _ = is_whitespace '\011' (* '\v': vertical tab *) +let%test _ = is_whitespace '\012' (* '\f': form feed *) +let%test _ = is_whitespace '\013' (* '\r': carriage return *) +let%test _ = not (is_whitespace '\014') (* shift out *) +let%test _ = is_whitespace '\032' (* space *) + +let%expect_test "hash coherence" = + check_hash_coherence [%here] (module Char) [ min_value; 'a'; max_value ]; + [%expect {| |}]; +;; + +let%test_module "int to char conversion" = + (module struct + + let%test_unit "of_int bounds" = + let bounds_check i = + [%test_result: t option] + (of_int i) + ~expect:None + ~message:(Int.to_string i) + in + for i = 1 to 100 do + bounds_check (-i); + bounds_check (255 + i); + done + + let%test_unit "of_int_exn vs of_int" = + for i = -100 to 300 do + [%test_eq: t option] + (of_int i) + (Option.try_with (fun () -> of_int_exn i)) + ~message:(Int.to_string i) + done + + let%test_unit "unsafe_of_int vs of_int_exn" = + for i = 0 to 255 do + [%test_eq: t] + (unsafe_of_int i) + (of_int_exn i) + ~message:(Int.to_string i) + done + + end) + +let%expect_test "all" = + print_s [%sexp (all : t list)]; + [%expect {| + ("\000" + "\001" + "\002" + "\003" + "\004" + "\005" + "\006" + "\007" + "\b" + "\t" + "\n" + "\011" + "\012" + "\r" + "\014" + "\015" + "\016" + "\017" + "\018" + "\019" + "\020" + "\021" + "\022" + "\023" + "\024" + "\025" + "\026" + "\027" + "\028" + "\029" + "\030" + "\031" + " " + ! + "\"" + # + $ + % + & + ' + "(" + ")" + * + + + , + - + . + / + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + : + ";" + < + = + > + ? 
+ @ + A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + Z + [ + "\\" + ] + ^ + _ + ` + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + { + | + } + ~ + "\127" + "\128" + "\129" + "\130" + "\131" + "\132" + "\133" + "\134" + "\135" + "\136" + "\137" + "\138" + "\139" + "\140" + "\141" + "\142" + "\143" + "\144" + "\145" + "\146" + "\147" + "\148" + "\149" + "\150" + "\151" + "\152" + "\153" + "\154" + "\155" + "\156" + "\157" + "\158" + "\159" + "\160" + "\161" + "\162" + "\163" + "\164" + "\165" + "\166" + "\167" + "\168" + "\169" + "\170" + "\171" + "\172" + "\173" + "\174" + "\175" + "\176" + "\177" + "\178" + "\179" + "\180" + "\181" + "\182" + "\183" + "\184" + "\185" + "\186" + "\187" + "\188" + "\189" + "\190" + "\191" + "\192" + "\193" + "\194" + "\195" + "\196" + "\197" + "\198" + "\199" + "\200" + "\201" + "\202" + "\203" + "\204" + "\205" + "\206" + "\207" + "\208" + "\209" + "\210" + "\211" + "\212" + "\213" + "\214" + "\215" + "\216" + "\217" + "\218" + "\219" + "\220" + "\221" + "\222" + "\223" + "\224" + "\225" + "\226" + "\227" + "\228" + "\229" + "\230" + "\231" + "\232" + "\233" + "\234" + "\235" + "\236" + "\237" + "\238" + "\239" + "\240" + "\241" + "\242" + "\243" + "\244" + "\245" + "\246" + "\247" + "\248" + "\249" + "\250" + "\251" + "\252" + "\253" + "\254" + "\255") |}] + +let%expect_test "predicates" = + print_s [%sexp (List.filter all ~f:is_digit : t list)]; + [%expect {| (0 1 2 3 4 5 6 7 8 9) |}]; + print_s [%sexp (List.filter all ~f:is_lowercase : t list)]; + [%expect {| (a b c d e f g h i j k l m n o p q r s t u v w x y z) |}]; + print_s [%sexp (List.filter all ~f:is_uppercase : t list)]; + [%expect {| (A B C D E F G H I J K L M N O P Q R S T U V W X Y Z) |}]; + print_s [%sexp (List.filter all ~f:is_alpha : t list)]; + [%expect {| + (A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + 
Z + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z) |}]; + print_s [%sexp (List.filter all ~f:is_alphanum : t list)]; + [%expect {| + (0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + Z + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z) |}]; + print_s [%sexp (List.filter all ~f:is_print : t list)]; + [%expect {| + (" " + ! + "\"" + # + $ + % + & + ' + "(" + ")" + * + + + , + - + . + / + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + : + ";" + < + = + > + ? + @ + A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + Z + [ + "\\" + ] + ^ + _ + ` + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + { + | + } + ~) |}]; + print_s [%sexp (List.filter all ~f:is_whitespace : t list)]; + [%expect {| ("\t" "\n" "\011" "\012" "\r" " ") |}]; diff --git a/test/test_char.mli b/test/test_char.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_char.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_compare.ml b/test/test_compare.ml new file mode 100644 index 0000000..f52f612 --- /dev/null +++ b/test/test_compare.ml @@ -0,0 +1,97 @@ +open! Base +open Expect_test_helpers_kernel + +module type S = sig + type t [@@deriving sexp_of] + include Comparable.Polymorphic_compare with type t := t +end + +(* Test the consistency of derived comparison operators with [compare] because many of + them are hand-optimized in [Base]. 
*) +let test (type a) here (module T : S with type t = a) list = + let op (type b) (module Result : S with type t = b) operator ~actual ~expect = + With_return.with_return (fun failed -> + List.iter list ~f:(fun arg1 -> + List.iter list ~f:(fun arg2 -> + let actual = actual arg1 arg2 in + let expect = expect arg1 arg2 in + if not (Result.compare actual expect = 0) then begin + print_cr here [%message + "comparison failed" + (operator : string) + (arg1 : T.t) + (arg2 : T.t) + (actual : Result.t) + (expect : Result.t)]; + failed.return () + end))) + in + let module C = Comparable.Make (T) in + op (module Bool) "equal" ~actual:T.equal ~expect:C.equal; + op (module T) "min" ~actual:T.min ~expect:C.min; + op (module T) "max" ~actual:T.max ~expect:C.max; + op (module Bool) "(=)" ~actual:T.(=) ~expect:C.(=); + op (module Bool) "(<)" ~actual:T.(<) ~expect:C.(<); + op (module Bool) "(>)" ~actual:T.(>) ~expect:C.(>); + op (module Bool) "(<>)" ~actual:T.(<>) ~expect:C.(<>); + op (module Bool) "(<=)" ~actual:T.(<=) ~expect:C.(<=); + op (module Bool) "(>=)" ~actual:T.(>=) ~expect:C.(>=); +;; + +let%expect_test "Base" = + test [%here] + (module struct include Base type t = int [@@deriving sexp_of] end) + Int.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Unit" = + test [%here] (module Unit) Unit.all; + [%expect {||}]; +;; + +let%expect_test "Bool" = + test [%here] (module Bool) Bool.all; + [%expect {||}]; +;; + +let%expect_test "Char" = + test [%here] (module Char) Char.all; + [%expect {||}]; +;; + +let%expect_test "Float" = + test [%here] (module Float) + Float.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int" = + test [%here] (module Int) + Int.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int32" = + test [%here] (module Int32) + Int32.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int64" = + test [%here] 
(module Int64) + Int64.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Nativeint" = + test [%here] (module Nativeint) + Nativeint.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int63" = + test [%here] (module Int63) + Int63.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; diff --git a/test/test_compare.mli b/test/test_compare.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_compare.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_container.ml b/test/test_container.ml new file mode 100644 index 0000000..2e791f5 --- /dev/null +++ b/test/test_container.ml @@ -0,0 +1,174 @@ +open! Import +open! Container + +module Test_generic + (Elt : sig + type 'a t + val of_int : int -> int t + val to_int : int t -> int + end) + (Container : sig + type 'a t [@@deriving sexp] + include Generic + with type 'a t := 'a t + with type 'a elt := 'a Elt.t + val mem : 'a t -> 'a Elt.t -> bool + val of_list : 'a Elt.t list -> [`Ok of 'a t | `Skip_test ] + end) + (* This signature constraint reminds us to add unit tests when functions are added to + [Generic]. 
*) + : sig + type 'a t [@@deriving sexp] + include Generic with type 'a t := 'a t + val mem : 'a t -> 'a Elt.t -> bool + end + with type 'a t := 'a Container.t + with type 'a elt := 'a Elt.t += struct + + open Container + + let find = find + let find_map = find_map + let fold = fold + let is_empty = is_empty + let iter = iter + let length = length + let mem = mem + let sexp_of_t = sexp_of_t + let t_of_sexp = t_of_sexp + let to_array = to_array + let to_list = to_list + let fold_result = fold_result + let fold_until = fold_until + + let%test_unit _ = + let ( = ) = Poly.equal in + let compare = Poly.compare in + List.iter [ 0; 1; 2; 3; 4; 8; 1024 ] ~f:(fun n -> + let list = List.init n ~f:Elt.of_int in + match Container.of_list list with + | `Skip_test -> () + | `Ok c -> + let sort l = List.sort l ~compare in + let sorts_are_equal l1 l2 = sort l1 = sort l2 in + assert (n = Container.length c); + assert ((n = 0) = Container.is_empty c); + assert (sorts_are_equal list + (Container.fold c ~init:[] ~f:(fun ac e -> e :: ac))); + assert (sorts_are_equal list (Container.to_list c)); + assert (sorts_are_equal list (Array.to_list (Container.to_array c))); + assert (n > 0 = is_some (Container.find c ~f:(fun e -> Elt.to_int e = 0))); + assert (n > 0 = is_some (Container.find c ~f:(fun e -> Elt.to_int e = n - 1))); + assert (is_none (Container.find c ~f:(fun e -> Elt.to_int e = n))); + assert (n > 0 = Container.mem c (Elt.of_int 0)); + assert (n > 0 = Container.mem c (Elt.of_int (n - 1))); + assert (not (Container.mem c (Elt.of_int n))); + assert (n > 0 = is_some (Container.find_map c ~f:(fun e -> + if Elt.to_int e = 0 then Some () else None))); + assert (n > 0 = is_some (Container.find_map c ~f:(fun e -> + if Elt.to_int e = n - 1 then Some () else None))); + assert (is_none (Container.find_map c ~f:(fun e -> + if Elt.to_int e = n then Some () else None))); + let r = ref 0 in + Container.iter c ~f:(fun e -> r := !r + Elt.to_int e); + assert (!r = List.fold list ~init:0 ~f:(fun n 
e -> n + Elt.to_int e)); + assert (!r = sum (module Int) c ~f:Elt.to_int); + let c2 = [%of_sexp: int Container.t] ([%sexp_of: int Container.t] c) in + assert (sorts_are_equal list (Container.to_list c2)); + let compare_elt a b = Int.compare (Elt.to_int a) (Elt.to_int b) in + if n = 0 then begin + assert (!r = 0); + assert (min_elt ~compare:compare_elt c = None); + assert (max_elt ~compare:compare_elt c = None); + end else begin + assert (!r = (n * (n-1) / 2)); + assert (Option.map ~f:Elt.to_int (min_elt ~compare:compare_elt c) = Some 0); + assert (Option.map ~f:Elt.to_int (max_elt ~compare:compare_elt c) = Some (Int.pred n)); + end; + let mid = Container.length c / 2 in + match + Container.fold_result c + ~init:0 + ~f:(fun count _elt -> if count = mid then Error count else Ok (count + 1)) + with + | Ok 0 -> assert (Container.length c = 0) + | Ok _ -> failwith "Expected fold to stop early" + | Error x -> assert (mid = x) + ) + ;; + + let min_elt = min_elt + let max_elt = max_elt + + let count = count + let sum = sum + let exists = exists + let for_all = for_all + + let%test_unit _ = + List.iter [ []; + [true]; + [false]; + [false; false]; + [true; false]; + [false; true]; + [true; true]; + ] + ~f:(fun bools -> + let count_should_be = + List.fold bools ~init:0 ~f:(fun n b -> if b then n + 1 else n) + in + let forall_should_be = List.fold bools ~init:true ~f:(fun ac b -> b && ac) in + let exists_should_be = List.fold bools ~init:false ~f:(fun ac b -> b || ac) in + match + Container.of_list + (List.map bools ~f:(fun b -> Elt.of_int (if b then 1 else 0))) + with + | `Skip_test -> () + | `Ok container -> + let is_one e = Elt.to_int e = 1 in + let ( = ) = Poly.equal in + assert (forall_should_be = Container.for_all container ~f:is_one); + assert (exists_should_be = Container.exists container ~f:is_one); + assert (count_should_be = Container.count container ~f:is_one); + ) + ;; + +end + +module Test_S1_allow_skipping_tests + (Container : sig + type 'a t [@@deriving sexp] + 
include Container.S1 with type 'a t := 'a t + val of_list : 'a list -> [`Ok of 'a t | `Skip_test] + end) = struct + + include + Test_generic + (struct + type 'a t = 'a + let of_int = Fn.id + let to_int = Fn.id + end) + (struct + include Container + let mem t a = mem t a ~equal:Poly.equal + end) + + let mem = Container.mem +end + +module Test_S1 + (Container : sig + type 'a t [@@deriving sexp] + include Container.S1 with type 'a t := 'a t + val of_list : 'a list -> 'a t + end) = Test_S1_allow_skipping_tests (struct + include Container + let of_list l = `Ok (of_list l) + end) + +include (Test_S1 (Array) : sig end) +include (Test_S1 (List) : sig end) +include (Test_S1 (Queue) : sig end) diff --git a/test/test_error.ml b/test/test_error.ml new file mode 100644 index 0000000..a2056c0 --- /dev/null +++ b/test/test_error.ml @@ -0,0 +1,23 @@ +open! Base +open! Import + +let errors = + [ Error.of_string "ABC" + ; Error.tag ~tag:"DEF" (Error.of_thunk (fun () -> "GHI")) + ; Error.create_s ([%message "foo" ~bar:(31:int)]) + ] + +let%expect_test _ = + List.iter errors ~f:(fun error -> show_raise (fun () -> Error.raise error)); + [%expect {| + (raised ABC) + (raised (DEF GHI)) + (raised (foo (bar 31))) |}] + +let%expect_test _ = + List.iter errors ~f:(fun error -> + show_raise (fun () -> Error.raise_s [%sexp (error : Error.t)])); + [%expect {| + (raised ABC) + (raised (DEF GHI)) + (raised (foo (bar 31))) |}] diff --git a/test/test_error.mli b/test/test_error.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_error.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_exn.ml b/test/test_exn.ml new file mode 100644 index 0000000..286861e --- /dev/null +++ b/test/test_exn.ml @@ -0,0 +1,18 @@ +open! Import +open! 
Exn + +let%expect_test "[create_s]" = + print_s [%sexp (create_s [%message "foo"] : t)]; + [%expect {| + foo |}]; + print_s [%sexp (create_s [%message "foo" "bar"] : t)]; + [%expect {| + (foo bar) |}]; + let sexp = [%message "foo"] in + print_s [%sexp (phys_equal sexp (sexp_of_t (create_s sexp)) : bool)]; + [%expect {| + true |}]; +;; + +let%test _ = not (does_raise Fn.ignore) +let%test _ = does_raise (fun () -> failwith "foo") diff --git a/test/test_exn.mli b/test/test_exn.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_exn.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_exported_int_conversions.ml b/test/test_exported_int_conversions.ml new file mode 100644 index 0000000..65fa77a --- /dev/null +++ b/test/test_exported_int_conversions.ml @@ -0,0 +1,235 @@ +open! Import + +module type S = sig + type t [@@deriving compare, sexp_of] + + val num_bits : int + + val min_value : t + val minus_one : t + val zero : t + val one : t + val max_value : t + + val to_int64 : t -> int64 + + val shift_right : t -> int -> t + + val random : Random.State.t -> t -> t -> t +end + +module I : S with type t = int = struct + include Int + let random = Random.State.int_incl +end + +module Native : S with type t = nativeint = struct + include Nativeint + let random = Random.State.nativeint_incl +end + +module I32 : S with type t = int32 = struct + include Int32 + let random = Random.State.int32_incl +end + +module I64 : S with type t = int64 = struct + include Int64 + let random = Random.State.int64_incl +end + +module I63 : S with type t = Int63.t = struct + include Int63 + let random state lo hi = Int63.random_incl ~state lo hi +end + +let iter (type a) (module M : S with type t = a) ~f = + let state = Random.State.make [| 0; 1; 2; 3; 4; 5 |] in + List.iter ~f [M.min_value; M.minus_one; M.zero; M.one; M.max_value]; + for _ = 1 to 10_000 do + (* skew toward low numbers of bits so that, e.g., choosing a random int64 does 
+ frequently find a value that can be converted to int32. *) + let strip_bits = Random.State.int_incl state 0 (M.num_bits - 1) in + let lo = M.shift_right M.min_value strip_bits in + let hi = M.shift_right M.max_value strip_bits in + f (M.random state lo hi) + done + +let try_with f x = Option.try_with (fun () -> f x) + +(* Checks that a conversion from [A.t] to [B.t] is total using [of] and [to]. *) +let test_total (type a) (type b) + (module A : S with type t = a) + (module B : S with type t = b) + ~of_:b_of_a ~to_:a_to_b + = + iter (module A) ~f:(fun a -> + require_compare_equal [%here] (module B) (b_of_a a) (a_to_b a); + require_compare_equal [%here] (module Int64) (A.to_int64 a) (B.to_int64 (b_of_a a))) + +let truncate int64 ~num_bits = + Int64.shift_right (Int64.shift_left int64 (64 - num_bits)) (64 - num_bits) + +(* Checks that a conversion from [A.t] to [B.t] is partial using [of] and [to], and the + [_exn] equivalents. In the case where the conversion fails, ensure that the value, + converted to an [Int64.t] is outside the representable range of [B.t] converted to an + [Int64.t] as well. 
*) +let test_partial (type a) (type b) + (module A : S with type t = a) + (module B : S with type t = b) + ~of_:b_of_a ~of_exn:b_of_a_exn ~of_trunc:b_of_a_trunc + ~to_:a_to_b ~to_exn:a_to_b_exn ~to_trunc:a_to_b_trunc + = + let module B_option = struct type t = B.t option [@@deriving compare, sexp_of] end in + let convertible_count = ref 0 in + iter (module A) ~f:(fun a -> + require_compare_equal [%here] (module B_option) (b_of_a a) (a_to_b a); + require_compare_equal [%here] (module B_option) (b_of_a a) (try_with b_of_a_exn a); + require_compare_equal [%here] (module B_option) (a_to_b a) (try_with a_to_b_exn a); + match b_of_a a with + | Some b -> + Int.incr convertible_count; + require_compare_equal [%here] (module B) b (b_of_a_trunc a); + require_compare_equal [%here] (module B) b (a_to_b_trunc a); + require_compare_equal [%here] (module Int64) (A.to_int64 a) (B.to_int64 b) + | None -> + let trunc = truncate (A.to_int64 a) ~num_bits:B.num_bits in + require_compare_equal [%here] (module Int64) trunc (B.to_int64 (b_of_a_trunc a)); + require_compare_equal [%here] (module Int64) trunc (B.to_int64 (a_to_b_trunc a)); + require [%here] (Int64.( > ) (A.to_int64 a) (B.to_int64 B.max_value) || + Int64.( < ) (A.to_int64 a) (B.to_int64 B.min_value)) + ~if_false_then_print_s:(lazy [%message "failed to convert" ~_:(a : A.t)])); + (* Make sure we stress the conversion a nontrivial number of times. This makes sure the + random generation is useful and we aren't just testing the hard-coded examples. 
*) + require [%here] (!convertible_count > 100) + ~if_false_then_print_s: + (lazy [%message + "did not test successful conversion often enough" + (convertible_count : int ref)]) + +let%expect_test "int <-> nativeint" = + test_total (module I) (module Native) + ~of_:Nativeint.of_int + ~to_:Int.to_nativeint; + [%expect {| |}]; + test_partial (module Native) (module I) + ~of_:Int.of_nativeint ~of_exn:Int.of_nativeint_exn ~of_trunc:Int.of_nativeint_trunc + ~to_:Nativeint.to_int ~to_exn:Nativeint.to_int_exn ~to_trunc:Nativeint.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "int <-> int32" = + test_partial (module I) (module I32) + ~of_:Int32.of_int ~of_exn:Int32.of_int_exn ~of_trunc:Int32.of_int_trunc + ~to_:Int.to_int32 ~to_exn:Int.to_int32_exn ~to_trunc:Int.to_int32_trunc; + [%expect {| |}]; + test_partial (module I32) (module I) + ~of_:Int.of_int32 ~of_exn:Int.of_int32_exn ~of_trunc:Int.of_int32_trunc + ~to_:Int32.to_int ~to_exn:Int32.to_int_exn ~to_trunc:Int32.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "nativeint <-> int32" = + test_partial (module Native) (module I32) + ~of_: Int32.of_nativeint + ~of_exn: Int32.of_nativeint_exn + ~of_trunc: Int32.of_nativeint_trunc + ~to_: Nativeint.to_int32 + ~to_exn: Nativeint.to_int32_exn + ~to_trunc: Nativeint.to_int32_trunc; + [%expect {| |}]; + test_total (module I32) (module Native) + ~of_:Nativeint.of_int32 + ~to_:Int32.to_nativeint; + [%expect {| |}]; +;; + +let%expect_test "int <-> int64" = + test_total (module I) (module I64) + ~of_:Int64.of_int + ~to_:Int.to_int64; + [%expect {| |}]; + test_partial (module I64) (module I) + ~of_:Int.of_int64 ~of_exn:Int.of_int64_exn ~of_trunc:Int.of_int64_trunc + ~to_:Int64.to_int ~to_exn:Int64.to_int_exn ~to_trunc:Int64.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "nativeint <-> int64" = + test_total (module Native) (module I64) + ~of_:Int64.of_nativeint + ~to_:Nativeint.to_int64; + [%expect {| |}]; + test_partial (module I64) (module Native) + ~of_: 
Nativeint.of_int64 + ~of_exn: Nativeint.of_int64_exn + ~of_trunc: Nativeint.of_int64_trunc + ~to_: Int64.to_nativeint + ~to_exn: Int64.to_nativeint_exn + ~to_trunc: Int64.to_nativeint_trunc; + [%expect {| |}]; +;; + +let%expect_test "int32 <-> int64" = + test_total (module I32) (module I64) + ~of_:Int64.of_int32 + ~to_:Int32.to_int64; + [%expect {| |}]; + test_partial (module I64) (module I32) + ~of_:Int32.of_int64 ~of_exn:Int32.of_int64_exn ~of_trunc:Int32.of_int64_trunc + ~to_:Int64.to_int32 ~to_exn:Int64.to_int32_exn ~to_trunc:Int64.to_int32_trunc; + [%expect {| |}]; +;; + + +let%expect_test "int <-> int63" = + test_total (module I) (module I63) + ~of_:Int63.of_int + ~to_:Int63.of_int; + [%expect {| |}]; + test_partial (module I63) (module I) + ~of_:Int63.to_int ~of_exn:Int63.to_int_exn ~of_trunc:Int63.to_int_trunc + ~to_:Int63.to_int ~to_exn:Int63.to_int_exn ~to_trunc:Int63.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "nativeint <-> int63" = + test_partial (module Native) (module I63) + ~of_: Int63.of_nativeint + ~of_exn: Int63.of_nativeint_exn + ~of_trunc: Int63.of_nativeint_trunc + ~to_: Int63.of_nativeint + ~to_exn: Int63.of_nativeint_exn + ~to_trunc: Int63.of_nativeint_trunc; + [%expect {| |}]; + test_partial (module I63) (module Native) + ~of_: Int63.to_nativeint + ~of_exn: Int63.to_nativeint_exn + ~of_trunc: Int63.to_nativeint_trunc + ~to_: Int63.to_nativeint + ~to_exn: Int63.to_nativeint_exn + ~to_trunc: Int63.to_nativeint_trunc; + [%expect {| |}]; +;; + +let%expect_test "int32 <-> int63" = + test_total (module I32) (module I63) + ~of_:Int63.of_int32 + ~to_:Int63.of_int32; + [%expect {| |}]; + test_partial (module I63) (module I32) + ~of_:Int63.to_int32 ~of_exn:Int63.to_int32_exn ~of_trunc:Int63.to_int32_trunc + ~to_:Int63.to_int32 ~to_exn:Int63.to_int32_exn ~to_trunc:Int63.to_int32_trunc; + [%expect {| |}]; +;; + +let%expect_test "int64 <-> int63" = + test_partial (module I64) (module I63) + ~of_:Int63.of_int64 ~of_exn:Int63.of_int64_exn 
~of_trunc:Int63.of_int64_trunc + ~to_:Int63.of_int64 ~to_exn:Int63.of_int64_exn ~to_trunc:Int63.of_int64_trunc; + [%expect {| |}]; + test_total (module I63) (module I64) + ~of_:Int63.to_int64 + ~to_:Int63.to_int64; + [%expect {| |}]; +;; diff --git a/test/test_exported_int_conversions.mli b/test/test_exported_int_conversions.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_exported_int_conversions.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_float.ml b/test/test_float.ml new file mode 100644 index 0000000..74baca9 --- /dev/null +++ b/test/test_float.ml @@ -0,0 +1,1025 @@ +open! Import +open! Float +open! Float.Private + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module Float) [ min_value; 0.; 37.; max_value ]; + [%expect {| |}]; +;; + +let exponent_bits = 11 +let mantissa_bits = 52 + +let exponent_mask64 = Int64.((shift_left one exponent_bits) - one) +let exponent_mask = Int64.to_int_exn exponent_mask64 +let mantissa_mask = Int63.((shift_left one mantissa_bits) - one) +let _mantissa_mask64 = Int63.to_int64 mantissa_mask + +let%test_unit "upper/lower_bound_for_int" = + assert ( + [%compare.equal: (int * t * t) list] + ([8; 16; 31; 32; 52; 53; 54; 62; 63; 64] + |> List.map ~f:(fun x -> (x, lower_bound_for_int x, upper_bound_for_int x))) + [( 8 + , -128.99999999999997 + , 127.99999999999999) + ; (16 + , -32768.999999999993 + , 32767.999999999996) + ; (31 + , -1073741824.9999998 + , 1073741823.9999999) + ; (32 + , -2147483648.9999995 + , 2147483647.9999998) + ; (52 + , -2251799813685248.5 + , 2251799813685247.8) + ; (53 + , -4503599627370496. + , 4503599627370495.5) + ; (54 + , -9007199254740992. + , 9007199254740991.) + ; (62 + , -2305843009213693952. + , 2305843009213693696.) + ; (63 + , -4611686018427387904. + , 4611686018427387392.) + ; (64 + , -9223372036854775808. + , 9223372036854774784.) 
+ ]) +;; + +let%test_unit _ = + (* on 64-bit platform ppx_hash hashes floats exactly the same as polymorphic hash *) + match Word_size.word_size with + | W32 -> () + | W64 -> + List.iter ~f:(fun float -> + let hash1 = Caml.Hashtbl.hash float in + let hash2 = [%hash: float] float in + let hash3 = specialized_hash float in + if (not Int.(hash1 = hash2 && hash1 = hash3)) + then raise_s [%message "bad" + (hash1 : Int.Hex.t) + (hash2 : Int.Hex.t) + (hash3 : Int.Hex.t)] + ) + [ 0.926038888360971146 + ; 34.1638588598232076 + ] +;; + +let test_both_ways (a : t) (b : int64) = + Int64.(=) (to_int64_preserve_order_exn a) b && Float.(=) (of_int64_preserve_order b) a +;; + +let%test _ = test_both_ways 0. 0L +let%test _ = test_both_ways (-0.) 0L +let%test _ = test_both_ways 1. Int64.(shift_left 1023L 52) +let%test _ = test_both_ways (-2.) Int64.(neg (shift_left 1024L 52)) +let%test _ = test_both_ways infinity Int64.(shift_left 2047L 52) +let%test _ = test_both_ways neg_infinity Int64.(neg (shift_left 2047L 52)) + +let%test _ = one_ulp `Down infinity = max_finite_value +let%test _ = is_nan (one_ulp `Up infinity) +let%test _ = is_nan (one_ulp `Down neg_infinity) +let%test _ = one_ulp `Up neg_infinity = ~-. max_finite_value + +(* Some tests to make sure that the compiler is generating code for handling subnormal + numbers at runtime accurately. *) +let x () = min_positive_subnormal_value +let y () = min_positive_normal_value + +let%test _ = test_both_ways (x ()) 1L +let%test _ = test_both_ways (y ()) Int64.(shift_left 1L 52) + +let%test _ = x () > 0. +let%test_unit _ = [%test_result: float] (x () /. 2.) ~expect:0. + +let%test _ = one_ulp `Up 0. = x () +let%test _ = one_ulp `Down 0. = ~-. (x ()) + +let are_one_ulp_apart a b = one_ulp `Up a = b + +let%test _ = are_one_ulp_apart (x ()) (2. *. x ()) +let%test _ = are_one_ulp_apart (2. *. x ()) (3. *. x ()) + +let one_ulp_below_y () = y () -. x () +let%test _ = one_ulp_below_y () < y () +let%test _ = y () -. 
one_ulp_below_y () = x () +let%test _ = are_one_ulp_apart (one_ulp_below_y ()) (y ()) + +let one_ulp_above_y () = y () +. x () +let%test _ = y () < one_ulp_above_y () +let%test _ = one_ulp_above_y () -. y () = x () +let%test _ = are_one_ulp_apart (y ()) (one_ulp_above_y ()) + +let%test _ = not (are_one_ulp_apart (one_ulp_below_y ()) (one_ulp_above_y ())) + +(* [2 * min_positive_normal_value] is where the ulp increases for the first time. *) +let z () = 2. *. y () +let one_ulp_below_z () = z () -. x () +let%test _ = one_ulp_below_z () < z () +let%test _ = z () -. one_ulp_below_z () = x () +let%test _ = are_one_ulp_apart (one_ulp_below_z ()) (z ()) + +let one_ulp_above_z () = z () +. 2. *. x () +let%test _ = z () < one_ulp_above_z () +let%test _ = one_ulp_above_z () -. z () = 2. *. x () +let%test _ = are_one_ulp_apart (z ()) (one_ulp_above_z ()) + +let%test_module "clamp" = + (module struct + let%test _ = clamp_exn 1.0 ~min:2. ~max:3. = 2. + let%test _ = clamp_exn 2.5 ~min:2. ~max:3. = 2.5 + let%test _ = clamp_exn 3.5 ~min:2. ~max:3. = 3. + + let%test_unit "clamp" = + [%test_result: float Or_error.t] (clamp 3.5 ~min:2. ~max:3.) ~expect:(Ok 3.) + + let%test_unit "clamp nan" = + [%test_result: float Or_error.t] (clamp nan ~min:2. ~max:3.) ~expect:(Ok nan) + + let%test "clamp bad" = Or_error.is_error (clamp 2.5 ~min:3. ~max:2.) + end) + +let%test_unit _ = + [%test_result: Int64.t] + (Int64.bits_of_float 1.1235582092889474E+307) ~expect:0x7fb0000000000000L + +let%test_module "IEEE" = + (module struct + (* Note: IEEE 754 defines NaN values to be those where the exponent is all 1s and the + mantissa is nonzero. test_result sees nan values as equal because it is based + on [compare] rather than [=]. (If [x] and [x'] are nan, [compare x x'] returns 0, + whereas [x = x'] returns [false]. This is the case regardless of whether or not + [x] and [x'] are bit-identical values of nan.) 
*) + let f (t : t) (negative : bool) (exponent : int) (mantissa : Int63.t) : unit = + let str = to_string t in + let is_nan = is_nan t in + (* the sign doesn't matter when nan *) + if not is_nan then + [%test_result: bool] ~message:("ieee_negative " ^ str) + (ieee_negative t) ~expect:negative; + [%test_result: int] ~message:("ieee_exponent " ^ str) + (ieee_exponent t) ~expect:exponent; + if is_nan + then assert (Int63.(zero <> ieee_mantissa t)) + else [%test_result: Int63.t] ~message:("ieee_mantissa " ^ str) + (ieee_mantissa t) ~expect:mantissa; + [%test_result: t] + ~message:(Printf.sprintf !"create_ieee ~negative:%B ~exponent:%d ~mantissa:%{Int63}" + negative exponent mantissa) + (create_ieee_exn ~negative ~exponent ~mantissa) + ~expect:t + + let%test_unit _ = + let (!!) x = Int63.of_int x in + f zero false 0 (!! 0); + f min_positive_subnormal_value false 0 (!! 1); + f min_positive_normal_value false 1 (!! 0); + f epsilon_float false Int.(1023 - mantissa_bits) (!! 0); + f one false 1023 (!! 0); + f minus_one true 1023 (!! 0); + f max_finite_value false Int.(exponent_mask - 1) mantissa_mask; + f infinity false exponent_mask (!! 0); + f neg_infinity true exponent_mask (!! 0); + f nan false exponent_mask (!! 1) + + (* test the normalized case, that is, 1 <= exponent <= 2046 *) + let%test_unit _ = + let g ~negative ~exponent ~mantissa = + assert (create_ieee_exn ~negative ~exponent + ~mantissa:(Int63.of_int64_exn mantissa) + = + (if negative then -1. else 1.) + * 2. **. (Float.of_int exponent - 1023.) + * (1. + (2. **. -52.) 
* Int64.to_float mantissa)) + in + g ~negative:false ~exponent:1 ~mantissa:147L; + g ~negative:true ~exponent:137 ~mantissa:13L; + g ~negative:false ~exponent:1015 ~mantissa:1370001L; + g ~negative:true ~exponent:2046 ~mantissa:137000100945L + end) + +let%test_module _ = + (module struct + let test f expect = + let actual = to_padded_compact_string f in + if String.(actual <> expect) + then raise_s + [%message "failure" + (f : t ) + (expect : string) + (actual : string)] + + let both f expect = + assert (f > 0.); + test f expect; + test (~-.f) ("-"^expect); + ;; + + let decr = one_ulp `Down + let incr = one_ulp `Up + + let boundary f ~closer_to_zero ~at = + assert (f > 0.); + (* If [f] looks like an odd multiple of 0.05, it might be slightly under-represented + as a float, e.g. + + 1. -. 0.95 = 0.0500000000000000444 + + In such case, sadly, the right way for [to_padded_compact_string], just as for + [sprintf "%.1f"], is to round it down. However, the next representable number + should be rounded up: + + # let x = 0.95 in sprintf "%.0f / %.1f / %.2f / %.3f / %.20f" x x x x x;; + - : string = "1 / 0.9 / 0.95 / 0.950 / 0.94999999999999995559" + + # let x = incr 0.95 in sprintf "%.0f / %.1f / %.2f / %.3f / %.20f" x x x x x ;; + - : string = "1 / 1.0 / 0.95 / 0.950 / 0.95000000000000006661" + + *) + let f = + if f >= 1000. then + f + else + let x = Printf.sprintf "%.20f" f in + let spot = String.index_exn x '.' 
in + (* the following condition is only meant to work for small multiples of 0.05 *) + let (+) = Int.(+) in + let (=) = Char.(=) in + if x.[spot + 2] = '4' && x.[spot + 3] = '9' && x.[spot + 4] = '9' then + (* something like 0.94999999999999995559 *) + incr f + else + f + in + both (decr f) closer_to_zero; + both f at; + ;; + + let%test_unit _ = test nan "nan " + let%test_unit _ = test 0.0 "0 " + let%test_unit _ = both min_positive_subnormal_value "0 " + let%test_unit _ = both infinity "inf " + + let%test_unit _ = boundary 0.05 ~closer_to_zero: "0 " ~at: "0.1" + let%test_unit _ = boundary 0.15 ~closer_to_zero: "0.1" ~at: "0.2" + (* glibc printf resolves ties to even, cf. + http://www.exploringbinary.com/inconsistent-rounding-of-printed-floating-point-numbers/ + Ties are resolved differently in JavaScript - mark some tests as not running with JavaScript. + *) + let%test_unit _ [@tags "no-js"] = + boundary (* tie *) 0.25 ~closer_to_zero: "0.2" ~at: "0.2" + let%test_unit _ [@tags "no-js"] = + boundary (incr 0.25)~closer_to_zero: "0.2" ~at: "0.3" + let%test_unit _ = boundary 0.35 ~closer_to_zero: "0.3" ~at: "0.4" + let%test_unit _ = boundary 0.45 ~closer_to_zero: "0.4" ~at: "0.5" + let%test_unit _ = both 0.50 "0.5" + let%test_unit _ = boundary 0.55 ~closer_to_zero: "0.5" ~at: "0.6" + let%test_unit _ = boundary 0.65 ~closer_to_zero: "0.6" ~at: "0.7" + (* this time tie-to-even means round away from 0 *) + let%test_unit _ = boundary (* tie *) 0.75 ~closer_to_zero: "0.7" ~at: "0.8" + let%test_unit _ = boundary 0.85 ~closer_to_zero: "0.8" ~at: "0.9" + let%test_unit _ = boundary 0.95 ~closer_to_zero: "0.9" ~at: "1 " + let%test_unit _ = boundary 1.05 ~closer_to_zero: "1 " ~at: "1.1" + let%test_unit _ [@tags "no-js"] = + boundary 3.25 ~closer_to_zero: "3.2" ~at: "3.2" + let%test_unit _ [@tags "no-js"] = + boundary (incr 3.25)~closer_to_zero: "3.2" ~at: "3.3" + let%test_unit _ = boundary 3.75 ~closer_to_zero: "3.7" ~at: "3.8" + let%test_unit _ = boundary 9.95 ~closer_to_zero: 
"9.9" ~at: "10 " + let%test_unit _ = boundary 10.05 ~closer_to_zero: "10 " ~at: "10.1" + let%test_unit _ = boundary 100.05 ~closer_to_zero:"100 " ~at: "100.1" + let%test_unit _ [@tags "no-js"] = + boundary (* tie *) 999.25 ~closer_to_zero:"999.2" ~at: "999.2" + let%test_unit _ [@tags "no-js"] = + boundary (incr 999.25)~closer_to_zero:"999.2" ~at: "999.3" + let%test_unit _ = boundary 999.75 ~closer_to_zero:"999.7" ~at: "999.8" + let%test_unit _ = boundary 999.95 ~closer_to_zero:"999.9" ~at: "1k " + let%test_unit _ = both 1000. "1k " + + (* some ties which we resolve manually in [iround_ratio_exn] *) + let%test_unit _ = boundary 1050. ~closer_to_zero: "1k " ~at: "1k " + let%test_unit _ = boundary (incr 1050.) ~closer_to_zero: "1k " ~at: "1k1" + let%test_unit _ = boundary 1950. ~closer_to_zero: "1k9" ~at: "2k " + let%test_unit _ = boundary 3250. ~closer_to_zero: "3k2" ~at: "3k2" + let%test_unit _ = boundary (incr 3250.) ~closer_to_zero: "3k2" ~at: "3k3" + let%test_unit _ = boundary 9950. ~closer_to_zero: "9k9" ~at: "10k " + let%test_unit _ = boundary 33_250. ~closer_to_zero: "33k2" ~at: "33k2" + let%test_unit _ = boundary (incr 33_250.) ~closer_to_zero: "33k2" ~at: "33k3" + let%test_unit _ = boundary 33_350. ~closer_to_zero: "33k3" ~at: "33k4" + let%test_unit _ = boundary 33_750. ~closer_to_zero: "33k7" ~at: "33k8" + let%test_unit _ = boundary 333_250. ~closer_to_zero:"333k2" ~at: "333k2" + let%test_unit _ = boundary (incr 333_250.) ~closer_to_zero:"333k2" ~at: "333k3" + let%test_unit _ = boundary 333_750. ~closer_to_zero:"333k7" ~at: "333k8" + let%test_unit _ = boundary 999_850. ~closer_to_zero:"999k8" ~at: "999k8" + let%test_unit _ = boundary (incr 999_850.) ~closer_to_zero:"999k8" ~at: "999k9" + let%test_unit _ = boundary 999_950. ~closer_to_zero:"999k9" ~at: "1m " + let%test_unit _ = boundary 1_050_000. ~closer_to_zero: "1m " ~at: "1m " + let%test_unit _ = boundary (incr 1_050_000.) ~closer_to_zero: "1m " ~at: "1m1" + + let%test_unit _ = boundary 999_950_000. 
~closer_to_zero:"999m9" ~at: "1g " + let%test_unit _ = boundary 999_950_000_000. ~closer_to_zero:"999g9" ~at: "1t " + let%test_unit _ = boundary 999_950_000_000_000. ~closer_to_zero:"999t9" ~at: "1p " + let%test_unit _ = boundary 999_950_000_000_000_000. ~closer_to_zero:"999p9" ~at:"1.0e+18" + + (* Test the boundary between the subnormals and the normals. *) + let%test_unit _ = boundary min_positive_normal_value ~closer_to_zero:"0 " ~at:"0 " + end) + +let%test "int_pow" = + let tol = 1e-15 in + let test (x, n) = + let reference_value = x **. of_int n in + let relative_error = (int_pow x n -. reference_value) /. reference_value in + abs relative_error < tol + in + List.for_all ~f:test + [(1.5, 17); (1.5, 42); (0.99, 64); (2., -5); (2., -1) + ; (-1.3, 2); (-1.3, -1); (-1.3, -2); (5., 0) + ; (nan, 0); (0., 0); (infinity, 0) + ] + +let%test "int_pow misc" = + int_pow 0. (-1) = infinity + && int_pow (-0.) (-1) = neg_infinity + && int_pow (-0.) (-2) = infinity + && int_pow 1.5 5000 = infinity + && int_pow 1.5 (-5000) = 0. + && int_pow (-1.) Int.max_value = -1. + && int_pow (-1.) Int.min_value = 1. + +(* some ugly corner cases with extremely large exponents and some serious precision loss *) +let%test "int_pow bad cases" [@tags "64-bits-only"] = + let a = one_ulp `Down 1. in + let b = one_ulp `Up 1. in + let large = 1 lsl 61 in + let small = Int.neg large in + (* this huge discrepancy comes from the fact that [1 / a = b] but this is a very poor + approximation, and in particular [1 / b = one_ulp `Down a = a * a]. *) + a **. of_int small = 1.5114276650041252e+111 + && int_pow a small = 2.2844048619719663e+222 + && int_pow b large = 2.2844048619719663e+222 + && b **. 
of_int large = 2.2844135865396268e+222 + +let%test_unit "sign_exn" = + List.iter ~f:(fun (input,expect) -> assert (Sign.equal (sign_exn input) expect)) + [ (1e-30, Sign.Pos) + ; (-0., Zero) + ; (0., Zero) + ; (neg_infinity, Neg) + ] + +let%test _ = + match sign_exn nan with + | Neg | Zero | Pos -> false + | exception _ -> true + +let%test_unit "sign_or_nan" = + List.iter ~f:(fun (input,expect) -> assert (Sign_or_nan.equal (sign_or_nan input) expect)) + [ (1e-30, Sign_or_nan.Pos) + ; (-0., Zero) + ; (0., Zero) + ; (neg_infinity, Neg) + ; (nan, Nan) + ] + +let%test_module _ = + (module struct + let check v expect = + match Validate.result v, expect with + | Ok (), `Ok | Error _, `Error -> () + | r, expect -> + raise_s + [%message "mismatch" (r : unit Or_error.t) (expect : [ `Ok | `Error ])] + ;; + + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) nan) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) infinity) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) neg_infinity) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) (-1.)) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) 0.) `Ok + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) 1.) `Ok + + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) nan) `Error + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) infinity) `Error + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) neg_infinity) `Error + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) (-1.)) `Ok + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) 0.) `Ok + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) 1.) `Error + + (* Some of the following tests used to live in lib_test/core_float_test.ml. *) + + let () = Random.init 137 + + (* round: + ... <-)[-><-)[-><-)[-><-)[-><-)[-><-)[-> ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so round x -. 
x should be in (-0.5,0.5] + *) + let round_test x = + let y = round x in + -0.5 < y -. x && y -. x <= 0.5 + + let iround_up_vs_down_test x = + let expected_difference = + if Parts.fractional (modf x) = 0. then + 0 + else + 1 + in + match iround_up x, iround_down x with + | Some x, Some y -> Int.(x - y = expected_difference) + | _, _ -> true + + let test_all_six x + ~specialized_iround ~specialized_iround_exn ~float_rounding + ~dir ~validate = + let result1 = iround x ~dir in + let result2 = Option.try_with (fun () -> iround_exn x ~dir) in + let result3 = specialized_iround x in + let result4 = Option.try_with (fun () -> specialized_iround_exn x) in + let result5 = Option.try_with (fun () -> Int.of_float (float_rounding x)) in + let result6 = Option.try_with (fun () -> Int.of_float (round ~dir x)) in + let (=) = Caml.(=) in + if result1 = result2 && result2 = result3 && result3 = result4 + && result4 = result5 && result5 = result6 then + validate result1 + else + false + + (* iround ~dir:`Nearest built so this should always be true *) + let iround_nearest_test x = + test_all_six x + ~specialized_iround:iround_nearest + ~specialized_iround_exn:iround_nearest_exn + ~float_rounding:round_nearest + ~dir:`Nearest + ~validate:(function + | None -> true + | Some y -> + let y = of_int y in + -0.5 < y -. x && y -. x <= 0.5) + + (* iround_down: + ... )[<---)[<---)[<---)[<---)[<---)[<---)[ ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so x -. iround_down x should be in [0,1) + *) + let iround_down_test x = + test_all_six x + ~specialized_iround:iround_down + ~specialized_iround_exn:iround_down_exn + ~float_rounding:round_down + ~dir:`Down + ~validate:(function + | None -> true + | Some y -> + let y = of_int y in + 0. <= x -. y && x -. y < 1.) + + (* iround_up: + ... ](--->](--->](--->](--->](--->](--->]( ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so iround_up x -. 
x should be in [0,1) + *) + let iround_up_test x = + test_all_six x + ~specialized_iround:iround_up + ~specialized_iround_exn:iround_up_exn + ~float_rounding:round_up + ~dir:`Up + ~validate:(function + | None -> true + | Some y -> + let y = of_int y in + 0. <= y -. x && y -. x < 1.) + + (* iround_towards_zero: + ... ](--->](--->](---><--->)[<---)[<---)[ ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so abs x -. abs (iround_towards_zero x) should be in [0,1) + *) + let iround_towards_zero_test x = + test_all_six x + ~specialized_iround:iround_towards_zero + ~specialized_iround_exn:iround_towards_zero_exn + ~float_rounding:round_towards_zero + ~dir:`Zero + ~validate:(function + | None -> true + | Some y -> + let x = abs x in + let y = abs (of_int y) in + 0. <= x -. y && x -. y < 1. && (Sign.(sign_exn x = sign_exn y) || y = 0.0)) + + (* Easy cases that used to live inline with the code above. *) + let%test_unit _ = [%test_result: int option] (iround_up (-3.4)) ~expect:(Some (-3)) + let%test_unit _ = [%test_result: int option] (iround_up 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_up 3.4) ~expect:(Some 4) + + let%test_unit _ = [%test_result: int] (iround_up_exn (-3.4)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_up_exn 0.0) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_up_exn 3.4) ~expect:4 + + let%test_unit _ = [%test_result: int option] (iround_down (-3.4)) ~expect:(Some (-4)) + let%test_unit _ = [%test_result: int option] (iround_down 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_down 3.4) ~expect:(Some 3) + + let%test_unit _ = [%test_result: int] (iround_down_exn (-3.4)) ~expect:(-4) + let%test_unit _ = [%test_result: int] (iround_down_exn 0.0) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_down_exn 3.4) ~expect:3 + + let%test_unit _ = [%test_result: int option] (iround_towards_zero (-3.4)) ~expect:(Some (-3)) + let%test_unit _ 
= [%test_result: int option] (iround_towards_zero 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_towards_zero 3.4) ~expect:(Some 3) + + let%test_unit _ = [%test_result: int] (iround_towards_zero_exn (-3.4)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_towards_zero_exn 0.0) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_towards_zero_exn 3.4) ~expect:3 + + let%test_unit _ = [%test_result: int option] (iround_nearest (-3.6)) ~expect:(Some (-4)) + let%test_unit _ = [%test_result: int option] (iround_nearest (-3.5)) ~expect:(Some (-3)) + let%test_unit _ = [%test_result: int option] (iround_nearest (-3.4)) ~expect:(Some (-3)) + let%test_unit _ = [%test_result: int option] (iround_nearest 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_nearest 3.4) ~expect:(Some 3) + let%test_unit _ = [%test_result: int option] (iround_nearest 3.5) ~expect:(Some 4) + let%test_unit _ = [%test_result: int option] (iround_nearest 3.6) ~expect:(Some 4) + + let%test_unit _ = [%test_result: int] (iround_nearest_exn (-3.6)) ~expect:(-4) + let%test_unit _ = [%test_result: int] (iround_nearest_exn (-3.5)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_nearest_exn (-3.4)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_nearest_exn 0.0 ) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_nearest_exn 3.4 ) ~expect:3 + let%test_unit _ = [%test_result: int] (iround_nearest_exn 3.5 ) ~expect:4 + let%test_unit _ = [%test_result: int] (iround_nearest_exn 3.6 ) ~expect:4 + + let special_values_test () = + [%test_result: float] (round (-.1.50001)) ~expect:(-.2.); + [%test_result: float] (round (-.1.5)) ~expect:(-.1.); + [%test_result: float] (round (-.0.50001)) ~expect:(-.1.); + [%test_result: float] (round (-.0.5)) ~expect:0.; + [%test_result: float] (round 0.49999) ~expect:0.; + [%test_result: float] (round 0.5) ~expect:1.; + [%test_result: float] (round 1.49999) ~expect:1.; + 
[%test_result: float] (round 1.5) ~expect:2.; + [%test_result: int] (iround_exn ~dir:`Up (-.2.)) ~expect:(-2); + [%test_result: int] (iround_exn ~dir:`Up (-.1.9999)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Up (-.1.)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Up (-.0.9999)) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Up 0.) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Up 0.00001) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Up 1.) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Up 1.00001) ~expect:2; + [%test_result: int] (iround_up_exn (-.2.)) ~expect:(-2); + [%test_result: int] (iround_up_exn (-.1.9999)) ~expect:(-1); + [%test_result: int] (iround_up_exn (-.1.)) ~expect:(-1); + [%test_result: int] (iround_up_exn (-.0.9999)) ~expect:0; + [%test_result: int] (iround_up_exn 0.) ~expect:0; + [%test_result: int] (iround_up_exn 0.00001) ~expect:1; + [%test_result: int] (iround_up_exn 1.) ~expect:1; + [%test_result: int] (iround_up_exn 1.00001) ~expect:2; + [%test_result: int] (iround_exn ~dir:`Down (-.1.00001)) ~expect:(-2); + [%test_result: int] (iround_exn ~dir:`Down (-.1.)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Down (-.0.00001)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Down 0.) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Down 0.99999) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Down 1.) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Down 1.99999) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Down 2.) ~expect:2; + [%test_result: int] (iround_down_exn (-.1.00001)) ~expect:(-2); + [%test_result: int] (iround_down_exn (-.1.)) ~expect:(-1); + [%test_result: int] (iround_down_exn (-.0.00001)) ~expect:(-1); + [%test_result: int] (iround_down_exn 0.) ~expect:0; + [%test_result: int] (iround_down_exn 0.99999) ~expect:0; + [%test_result: int] (iround_down_exn 1.) ~expect:1; + [%test_result: int] (iround_down_exn 1.99999) ~expect:1; + [%test_result: int] (iround_down_exn 2.) 
~expect:2; + [%test_result: int] (iround_exn ~dir:`Zero (-.2.)) ~expect:(-2); + [%test_result: int] (iround_exn ~dir:`Zero (-.1.99999)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Zero (-.1.)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Zero (-.0.99999)) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Zero 0.99999) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Zero 1.) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Zero 1.99999) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Zero 2.) ~expect:2 + + let is_64_bit_platform = of_int Int.max_value >= 2. **. 60. + + (* Tests for values close to [iround_lbound] and [iround_ubound]. *) + let extremities_test ~round = + let (+) = Int.(+) in + let (-) = Int.(-) in + if is_64_bit_platform then ( + (* 64 bits *) + [%test_result: int option] (round (2.0 **. 62. -. 512.)) ~expect:(Some (Int.max_value - 511)); + [%test_result: int option] (round (2.0 **. 62. -. 1024.)) ~expect:(Some (Int.max_value - 1023)); + [%test_result: int option] (round (-. (2.0 **. 62.))) ~expect:(Some Int.min_value); + [%test_result: int option] (round (-. (2.0 **. 62. -. 512.))) ~expect:(Some (Int.min_value + 512)); + [%test_result: int option] (round (2.0 **. 62.)) ~expect:None; + [%test_result: int option] (round (-. (2.0 **. 62. +. 1024.))) ~expect:None) + else ( + let int_size_minus_one = of_int (Int.num_bits - 1) in + (* 32 bits *) + [%test_result: int option] (round (2.0 **. int_size_minus_one -. 1.)) ~expect:(Some Int.max_value); + [%test_result: int option] (round (2.0 **. int_size_minus_one -. 2.)) ~expect:(Some (Int.max_value - 1)); + [%test_result: int option] (round (-. (2.0 **. int_size_minus_one))) ~expect:(Some Int.min_value); + [%test_result: int option] (round (-. (2.0 **. int_size_minus_one -. 1.))) ~expect:(Some (Int.min_value + 1)); + [%test_result: int option] (round (2.0 **. int_size_minus_one)) ~expect:None; + [%test_result: int option] (round (-. (2.0 **. int_size_minus_one +. 
1.))) ~expect:None) + + let%test_unit _ = extremities_test ~round:iround_down + let%test_unit _ = extremities_test ~round:iround_up + let%test_unit _ = extremities_test ~round:iround_nearest + let%test_unit _ = extremities_test ~round:iround_towards_zero + + (* test values beyond the integers range *) + let large_value_test x = + [%test_result: int option] (iround_down x) ~expect:None; + [%test_result: int option] (iround ~dir:`Down x) ~expect:None; + [%test_result: int option] (iround_up x) ~expect:None; + [%test_result: int option] (iround ~dir:`Up x) ~expect:None; + [%test_result: int option] (iround_towards_zero x) ~expect:None; + [%test_result: int option] (iround ~dir:`Zero x) ~expect:None; + [%test_result: int option] (iround_nearest x) ~expect:None; + [%test_result: int option] (iround ~dir:`Nearest x) ~expect:None; + + assert (Exn.does_raise (fun () -> iround_down_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Down x)); + assert (Exn.does_raise (fun () -> iround_up_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Up x)); + assert (Exn.does_raise (fun () -> iround_towards_zero_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Zero x)); + assert (Exn.does_raise (fun () -> iround_nearest_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Nearest x)); + + [%test_result: float] (round_down x) ~expect:x; + [%test_result: float] (round ~dir:`Down x) ~expect:x; + [%test_result: float] (round_up x) ~expect:x; + [%test_result: float] (round ~dir:`Up x) ~expect:x; + [%test_result: float] (round_towards_zero x) ~expect:x; + [%test_result: float] (round ~dir:`Zero x) ~expect:x; + [%test_result: float] (round_nearest x) ~expect:x; + [%test_result: float] (round ~dir:`Nearest x) ~expect:x + + let large_numbers = + let (+) = Int.(+) in + let (-) = Int.(-) in + List.concat ( + List.init (1024 - 64) ~f:(fun x -> + let x = of_int (x + 64) in + let y = + [2. **. x; + 2. **. x -. 2. **. (x -. 53.); (* one ulp down *) + 2. **. 
x +. 2. **. (x -. 52.)] (* one ulp up *) + in + y @ (List.map y ~f:neg))) + @ + [infinity; + neg_infinity] + + let%test_unit _ = List.iter large_numbers ~f:large_value_test + + let numbers_near_powers_of_two = + List.concat ( + List.init 64 ~f:(fun i -> + let pow2 = 2. **. of_int i in + let x = + [ pow2; + one_ulp `Down (pow2 +. 0.5); + pow2 +. 0.5; + one_ulp `Down (pow2 +. 1.0); + pow2 +. 1.0; + one_ulp `Down (pow2 +. 1.5); + pow2 +. 1.5; + one_ulp `Down (pow2 +. 2.0); + pow2 +. 2.0; + one_ulp `Down (pow2 *. 2.0 -. 1.0); + one_ulp `Down pow2; + one_ulp `Up pow2 + ] + in + x @ (List.map x ~f:neg) + )) + + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_up_vs_down_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_nearest_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_down_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_up_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_towards_zero_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:round_test + + (* code for generating random floats on which to test functions *) + let rec absirand () = + let open Int.O in + let rec aux acc cnt = + if cnt = 0 then + acc + else + let bit = if Random.bool () then 1 else 0 in + aux (2 * acc + bit) (cnt - 1) + in + let result = aux 0 (if is_64_bit_platform then 62 else 30) in + if result >= Int.max_value - 255 then + (* On a 64-bit box, [float x > Int.max_value] when [x >= Int.max_value - 255], so + [iround (float x)] would be out of bounds. So we try again. This branch of code + runs with probability 6e-17 :-) As such, we have some fixed tests in + [extremities_test] above, to ensure that we do always check some examples in + that range. *) + absirand () + else + result + + (* -Int.max_value <= frand () <= Int.max_value *) + let frand () = + let x = (of_int (absirand ())) +. Random.float 1.0 in + if Random.bool () then + -1.0 *. 
x + else + x + + let randoms = List.init ~f:(fun _ -> frand ()) 10_000 + + let%test _ = List.for_all randoms ~f:iround_up_vs_down_test + let%test _ = List.for_all randoms ~f:iround_nearest_test + let%test _ = List.for_all randoms ~f:iround_down_test + let%test _ = List.for_all randoms ~f:iround_up_test + let%test _ = List.for_all randoms ~f:iround_towards_zero_test + let%test _ = List.for_all randoms ~f:round_test + let%test_unit _ = special_values_test () + let%test _ = iround_nearest_test (of_int Int.max_value) + let%test _ = iround_nearest_test (of_int Int.min_value) + end) + +module Test_bounds ( + I : sig + type t + val num_bits : int + val of_float : float -> t + val to_int64 : t -> Int64.t + val max_value : t + val min_value : t + end + ) = struct + open I + + let float_lower_bound = lower_bound_for_int num_bits + let float_upper_bound = upper_bound_for_int num_bits + + let%test_unit "lower bound is valid" = ignore (of_float float_lower_bound : t) + let%test_unit "upper bound is valid" = ignore (of_float float_upper_bound : t) + + let%test "smaller than lower bound is not valid" = + Exn.does_raise (fun () -> of_float (one_ulp `Down float_lower_bound)) + let%test "bigger than upper bound is not valid" = + Exn.does_raise (fun () -> of_float (one_ulp `Up float_upper_bound)) + + (* We use [Caml.Int64.of_float] in the next two tests because [Int64.of_float] rejects + out-of-range inputs, whereas [Caml.Int64.of_float] simply overflows (returns + [Int64.min_int]). *) + + let%test "smaller than lower bound overflows" = + let lower_bound = Int64.of_float float_lower_bound in + let lower_bound_minus_epsilon = Caml.Int64.of_float (one_ulp `Down float_lower_bound) in + let min_value = to_int64 min_value in + if Int.(=) num_bits 64 + (* We cannot detect overflow because on Intel overflow results in min_value. 
*) + then true + else begin + assert (Int64.(<=) lower_bound_minus_epsilon lower_bound); + (* a value smaller than min_value would overflow if converted to [t] *) + Int64.(<) lower_bound_minus_epsilon min_value + end + + let%test "bigger than upper bound overflows" = + let upper_bound = Int64.of_float float_upper_bound in + let upper_bound_plus_epsilon = Caml.Int64.of_float (one_ulp `Up float_upper_bound) in + let max_value = to_int64 max_value in + if Int.(=) num_bits 64 + (* upper_bound_plus_epsilon is not representable as a Int64.t, it has overflowed *) + then Int64.(<) upper_bound_plus_epsilon upper_bound + else begin + assert (Int64.(>=) upper_bound_plus_epsilon upper_bound); + (* a value greater than max_value would overflow if converted to [t] *) + Int64.(>) upper_bound_plus_epsilon max_value + end +end + +let%test_module "Int" = (module Test_bounds(Int)) +let%test_module "Int32" = (module Test_bounds(Int32)) +let%test_module "Int63" = (module Test_bounds(Int63)) +let%test_module "Int63_emul" = (module Test_bounds(Base.Not_exposed_properly.Int63_emul)) +let%test_module "Int64" = (module Test_bounds(Int64)) +let%test_module "Nativeint" = (module Test_bounds(Nativeint)) + +let%test_unit _ = [%test_result: string] (to_string 3.14) ~expect:"3.14" +let%test_unit _ = [%test_result: string] (to_string 3.1400000000000001) ~expect:"3.14" +let%test_unit _ = [%test_result: string] (to_string 3.1400000000000004) ~expect:"3.1400000000000006" +let%test_unit _ = [%test_result: string] (to_string 8.000000000000002) ~expect:"8.0000000000000018" +let%test_unit _ = [%test_result: string] (to_string 9.992) ~expect:"9.992" +let%test_unit _ = [%test_result: string] (to_string (2.**.63. *. (1. +. 2.**. (-52.)))) ~expect:"9.2233720368547779e+18" +let%test_unit _ = [%test_result: string] (to_string (-3.)) ~expect:"-3." 
+let%test_unit _ = [%test_result: string] (to_string nan) ~expect:"nan" +let%test_unit _ = [%test_result: string] (to_string infinity) ~expect:"inf" +let%test_unit _ = [%test_result: string] (to_string neg_infinity) ~expect:"-inf" +let%test_unit _ = [%test_result: string] (to_string 3e100) ~expect:"3e+100" +let%test_unit _ = [%test_result: string] (to_string max_finite_value) ~expect:"1.7976931348623157e+308" +let%test_unit _ = [%test_result: string] (to_string min_positive_subnormal_value) ~expect:"4.94065645841247e-324" +(* [epsilon_float] equals the distance from 1. to the next float above it. *) +let%test _ = epsilon_float = (one_ulp `Up 1.) -. 1. + +let%test _ = one_ulp_less_than_half = 0.49999999999999994 +(* Directed rounding, each checked on one positive and one negative non-integer. *) +let%test _ = + round_down 3.6 = 3. + && round_down (-3.6) = -4. + +let%test _ = + round_up 3.6 = 4. + && round_up (-3.6) = -3. + +let%test _ = + round_towards_zero 3.6 = 3. + && round_towards_zero (-3.6) = -3. +(* Exact ties go to the even integer (0.5 -> 0., 3.5 -> 4., 4.5 -> 4.); inputs one ulp off a tie go to the nearer integer. The last two tests contrast this with [round_nearest] on the same input just above -(2. **. 52.). *) +let%test _ = round_nearest_half_to_even 0. = 0. +let%test _ = round_nearest_half_to_even 0.5 = 0. +let%test _ = round_nearest_half_to_even (-0.5) = 0. +let%test _ = round_nearest_half_to_even (one_ulp `Up 0.5) = 1. +let%test _ = round_nearest_half_to_even (one_ulp `Down 0.5) = 0. +let%test _ = round_nearest_half_to_even (one_ulp `Up (-0.5)) = 0. +let%test _ = round_nearest_half_to_even (one_ulp `Down (-0.5)) = -1. +let%test _ = round_nearest_half_to_even 3.5 = 4. +let%test _ = round_nearest_half_to_even 4.5 = 4. +let%test _ = round_nearest_half_to_even (one_ulp `Up (-5.5)) = -5. +let%test _ = round_nearest_half_to_even 5.5 = 6. +let%test _ = round_nearest_half_to_even 6.5 = 6. +let%test _ = round_nearest_half_to_even (one_ulp `Up (-. (2. **. 52.))) = -. (2. **. 52.) +let%test _ = round_nearest (one_ulp `Up (-. (2. **. 52.))) = 1. -. (2. **. 52.)
+ +let%test_module _ = + (module struct + (* check we raise on invalid input: [must_fail f x] holds when [f x] raises, [must_succeed f x] holds when [f x] returns *) + let must_fail f x = Exn.does_raise (fun () -> f x) + let must_succeed f x = ignore (f x); true + let%test _ = must_fail int63_round_nearest_portable_alloc_exn nan + let%test _ = must_fail int63_round_nearest_portable_alloc_exn max_value + let%test _ = must_fail int63_round_nearest_portable_alloc_exn min_value + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (2. **. 63.) + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (~-. (2. **. 63.)) + let%test _ = must_succeed int63_round_nearest_portable_alloc_exn (2. **. 62. -. 512.) + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (2. **. 62.) + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (~-. (2. **. 62.) -. 1024.) + let%test _ = must_succeed int63_round_nearest_portable_alloc_exn (~-. (2. **. 62.)) + end) + +let%test _ = + round_nearest 3.6 = 4. + && round_nearest (-3.6) = -4. + + +(* The redefinition of [sexp_of_t] in float.ml assumes sexp conversion uses E rather than + e.
*) +let%test_unit "e vs E" = [%test_result: Sexp.t] [%sexp (1.4e100 : t)] ~expect:(Atom "1.4E+100") +(* [test ?delimiter ~decimals f s s_strip_zero] requires [to_string_hum f] to equal [s] with [~strip_zero:false] and [s_strip_zero] with [~strip_zero:true]. On a mismatch it raises a message carrying the actual string as [got] and the wanted string as [expected]. *) +let%test_module _ = + (module struct + let test ?delimiter ~decimals f s s_strip_zero = + let s' = to_string_hum ?delimiter ~decimals ~strip_zero:false f in + if String.(s' <> s) then + raise_s + [%message + "to_string_hum ~strip_zero:false" + ~input:(f : float) + (decimals : int) + ~got:(s' : string) + ~expected:(s : string) + ]; + let s_strip_zero' = to_string_hum ?delimiter ~decimals ~strip_zero:true f in + if String.(s_strip_zero' <> s_strip_zero) then + raise_s + [%message + "to_string_hum ~strip_zero:true" + ~input:(f : float) + (decimals : int) + ~got:(s_strip_zero' : string) + ~expected:(s_strip_zero : string) + ]; + ;; + + let%test_unit _ = test ~decimals:3 0.99999 "1.000" "1" + let%test_unit _ = test ~decimals:3 0.00001 "0.000" "0" + let%test_unit _ = test ~decimals:3 ~-.12345.1 "-12_345.100" "-12_345.1" + let%test_unit _ = test ~delimiter:',' ~decimals:3 ~-.12345.1 "-12,345.100" "-12,345.1" + let%test_unit _ = test ~decimals:0 0.99999 "1" "1" + let%test_unit _ = test ~decimals:0 0.00001 "0" "0" + let%test_unit _ = test ~decimals:0 ~-.12345.1 "-12_345" "-12_345" + let%test_unit _ = test ~decimals:0 (5.0 /. 0.0) "inf" "inf" + let%test_unit _ = test ~decimals:0 (-5.0 /. 0.0) "-inf" "-inf" + let%test_unit _ = test ~decimals:0 (0.0 /. 0.0) "nan" "nan" + let%test_unit _ = test ~decimals:2 (5.0 /. 0.0) "inf" "inf" + let%test_unit _ = test ~decimals:2 (-5.0 /. 0.0) "-inf" "-inf" + let%test_unit _ = test ~decimals:2 (0.0 /. 0.0) "nan" "nan" + let%test_unit _ = test ~decimals:5 (10_000.0 /. 3.0) "3_333.33333" "3_333.33333" + let%test_unit _ = test ~decimals:2 ~-.0.00001 "-0.00" "-0" + (* [rand_test n] runs [n] rounds: format a random float with [to_string_hum ~decimals:3 ~strip_zero:false], remove any comma, parse the text back and format it again; both strings must be equal. Returns [false] after printing the exception if any round fails. *) + let rand_test n = + let go () = + let f = Random.float 1_000_000.0 -.
500_000.0 in + let repeatable to_str = + let s = to_str f in + if String.(<>) (String.split s ~on:',' |> String.concat |> of_string |> to_str) s + then raise_s [%message "failed" (f : t)] + in + repeatable (to_string_hum ~decimals:3 ~strip_zero:false); + in + try + for _ = 0 to Int.(-) n 1 do go () done; + true + with e -> + eprintf "%s\n%!" (Exn.to_string e); + false + ;; + + let%test _ = rand_test 10_000 + ;; + end) +;; +(* [Caml.float_of_string] on hexadecimal literals: malformed ones must raise, well-formed ones must parse to the given float. *) +let%test_module "Hexadecimal syntax" = + (module struct + + let should_fail str = Exn.does_raise (fun () -> Caml.float_of_string str) + + let test_equal str g = Caml.float_of_string str = g + let%test _ = should_fail "0x" + let%test _ = should_fail "0x.p0" + let%test _ = test_equal "0x0" 0. + let%test _ = test_equal "0x1.b7p-1" 0.857421875 + let%test _ = test_equal "0x1.999999999999ap-4" 0.1 + + end) +;; + +let%expect_test "square" = + printf "%f\n" (square 1.5); + printf "%f\n" (square (-2.5)); + [%expect {| + 2.250000 + 6.250000 |}] +;; + +let%expect_test "mathematical constants" = + (* Compare to the from-string conversion of numbers from Wolfram Alpha *) + let eq x s = assert (x = of_string s) in + eq pi "3.141592653589793238462643383279502884197169399375105820974"; + eq sqrt_pi "1.772453850905516027298167483341145182797549456122387128213"; + eq sqrt_2pi "2.506628274631000502415765284811045253006986740609938316629"; + eq euler "0.577215664901532860606512090082402431042159335939923598805"; + (* Check size of diff from ordinary computation. *) + printf "sqrt pi diff : %.20f\n" (sqrt_pi - sqrt pi); + printf "sqrt 2pi diff : %.20f\n" (sqrt_2pi - sqrt (2. * pi)); + [%expect {| + sqrt pi diff : 0.00000000000000022204 + sqrt 2pi diff : 0.00000000000000044409 |}] +(* nan is neither negative nor non-positive; negative zero counts as non-negative. *) +let%test _ = not (is_negative Float.nan) +let%test _ = not (is_non_positive Float.nan) +let%test _ = is_non_negative (-0.)
+ +(* Conversions through [Float.of_int], [Float.of_int63] and [Float.of_int64] must agree, on fixed edge cases and on pseudo-random ints drawn from a fixed seed. *) +let%test_unit "int to float conversion consistency" = + let test_int63 x = + [%test_result: float] (Float.of_int63 x) ~expect:(Float.of_int64 (Int63.to_int64 x)) + in + let test_int x = + [%test_result: float] (Float.of_int x) ~expect:(Float.of_int63 (Int63.of_int x)); + test_int63 (Int63.of_int x) + in + test_int 0; + test_int 35; + test_int (-1); + test_int Int.max_value; + test_int Int.min_value; + + test_int63 Int63.zero; + test_int63 Int63.min_value; + test_int63 Int63.max_value; + + let rand = Random.State.make [| Hashtbl.hash "int to float conversion consistency" |] in + for _i = 0 to 100 do + let x = Random.State.int rand Int.max_value in + test_int x; + done; + () +;; diff --git a/test/test_float.mli b/test/test_float.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_float.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_fn.ml b/test/test_fn.ml new file mode 100644 index 0000000..7e866ed --- /dev/null +++ b/test/test_fn.ml @@ -0,0 +1,12 @@ +open! Import +open! Fn + +(* enforce that we're testing [Fn.(|>)] and not ppx_pipebang. *) +let (_ : 'a -> ('a -> 'b) -> 'b) = (|>) + +let%test _ = 1 |> fun x -> x = 1 +let%test _ = 1 |> fun x -> x + 1 |> fun y -> y = 2 +(* With [n] equal to 0 or negative the function argument is never called (it would hit [assert false]) and the initial value comes back unchanged. *) +let%test _ = 0 = apply_n_times ~n:0 (fun _ -> assert false) 0 +let%test _ = 0 = apply_n_times ~n:(-3) (fun _ -> assert false) 0 +let%test _ = 10 = apply_n_times ~n:10 ((+) 1) 0 diff --git a/test/test_fn.mli b/test/test_fn.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_fn.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_hash_set.ml b/test/test_hash_set.ml new file mode 100644 index 0000000..d6c2731 --- /dev/null +++ b/test/test_hash_set.ml @@ -0,0 +1,64 @@ +open! Import +open!
Hash_set +(* [run_test a b ~expect] builds string hash sets from the three lists, intersects the first two with [inter], and asserts that the result and [expect] contain each other, list out equally and have equal length. *) +let%test_module "Set Intersection" = + (module struct + let run_test first_contents second_contents ~expect = + let of_list lst = + let s = create (module String) in + List.iter lst ~f:(add s); + s + in + let s1 = of_list first_contents in + let s2 = of_list second_contents in + let expect = of_list expect in + let result = inter s1 s2 in + iter result ~f:(fun x -> assert (mem expect x)); + iter expect ~f:(fun x -> assert (mem result x)); + let equal x y = 0 = String.compare x y in + assert (List.equal equal (to_list result) (to_list expect)); + assert ((length result) = (length expect)); + (* Make sure the input sets are unmodified by [inter] (only their sizes are compared) *) + assert ((List.length first_contents) = length s1); + assert ((List.length second_contents) = length s2) + ;; + + let%test_unit "First smaller" = + run_test ["0"; "3"; "99"] + ["0";"1";"2";"3"] + ~expect:["0"; "3"] + + let%test_unit "Second smaller" = + run_test ["a";"b";"c";"d"] + [ "b"; "d"] + ~expect:[ "b"; "d"] + + let%test_unit "No intersection" = + run_test ~expect:[] ["a";"b";"c";"d"] ["1";"2";"3";"4"] + end) +(* The pinned sexps list the elements in ascending order: numeric for the ints, lexicographic for the strings. *) +let%expect_test "sexp" = + let ints = List.init 20 ~f:(fun x -> x * x) in + let int_hash_set = Hash_set.of_list (module Int) ints in + print_s [%sexp (int_hash_set : int Hash_set.t)]; + [%expect {| (0 1 4 9 16 25 36 49 64 81 100 121 144 169 196 225 256 289 324 361) |}]; + let strs = List.init 20 ~f:(fun x -> Int.to_string x) in + let str_hash_set = Hash_set.of_list (module String) strs in + print_s [%sexp (str_hash_set : string Hash_set.t)]; + [%expect {| (0 1 10 11 12 13 14 15 16 17 18 19 2 3 4 5 6 7 8 9) |}]; +;; +(* [to_array] and [to_list] followed by [Array.of_list] are pinned to the same element order. *) +let%expect_test "to_array" = + let empty_array = to_array (Hash_set.of_list (module Int) []) in + print_s [%sexp (empty_array : int Array.t)]; + [%expect {| () |}]; + let array_from_to_array = to_array (Hash_set.of_list (module Int) [1; 2; 3; 4; 5;]) in + print_s [%sexp (array_from_to_array : int Array.t)]; + [%expect {| (1 3 2 4 5) |}]; + let array_via_to_list = + to_list
(Hash_set.of_list (module Int) [1; 2; 3; 4; 5;]) |> Array.of_list + in + print_s [%sexp (array_via_to_list : int Array.t)]; + [%expect {| (1 3 2 4 5) |}]; +;; + diff --git a/test/test_hash_set.mli b/test/test_hash_set.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_hash_set.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_hashtbl.ml b/test/test_hashtbl.ml new file mode 100644 index 0000000..3e1eda7 --- /dev/null +++ b/test/test_hashtbl.ml @@ -0,0 +1,24 @@ +open! Base + +type int_hashtbl = int Hashtbl.M(Int).t [@@deriving sexp] +(* Merging two empty tables created from a first-class module must yield an empty association list. *) +let%test "Hashtbl.merge succeeds with first-class-module interface" = + let t1 = Hashtbl.create (module Int) in + let t2 = Hashtbl.create (module Int) in + let result = + Hashtbl.merge t1 t2 ~f:(fun ~key:_ -> function + | `Left x -> x + | `Right x -> x + | `Both _ -> assert false) + |> Hashtbl.to_alist + in + List.equal Poly.equal result [] +(* Applies [Hashtbl_tests.Make] to [Hashtbl], with the polymorphic constructors taken from its [Poly] submodule. *) +let%test_module _ = (module Hashtbl_tests.Make(struct + include Hashtbl + + let create_poly ?size () = Poly.create ?size () + + let of_alist_poly_exn l = Poly.of_alist_exn l + let of_alist_poly_or_error l = Poly.of_alist_or_error l + end)) diff --git a/test/test_hashtbl.mli b/test/test_hashtbl.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_hashtbl.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_identifiable.ml b/test/test_identifiable.ml new file mode 100644 index 0000000..2045baa --- /dev/null +++ b/test/test_identifiable.ml @@ -0,0 +1,17 @@ +open! Import +open!
Identifiable +(* [T] is [string] extended through [Make] with [String] as the base module; the test below runs [check_hash_coherence] on a few strings. *) +module T = struct + type t = string + include + Make (struct + let module_name = "test" + include String + end) +end + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module T) + ([ ""; "a"; "foo" ] |> List.map ~f:T.of_string); + [%expect {| |}]; +;; diff --git a/test/test_identifiable.mli b/test/test_identifiable.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_identifiable.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_indexed_container.ml b/test/test_indexed_container.ml new file mode 100644 index 0000000..a695c04 --- /dev/null +++ b/test/test_indexed_container.ml @@ -0,0 +1,181 @@ +open! Import + +module type S = Indexed_container.S1 with type 'a t = 'a list +(* [This_list] takes its indexed operations from [Indexed_container.Make], with [foldi] and [iteri] set to [`Define_using_fold]; [That_list] is [List] itself. The tests below run both on the same inputs. *) +module This_list : S = struct + include List + include (Indexed_container.Make (struct + type 'a t = 'a list + let fold = List.fold + let iter = `Custom List.iter + let length = `Custom List.length + let foldi = `Define_using_fold + let iteri = `Define_using_fold + end)) +end + +module That_list : S = List + +let examples = + [ [] + ; [1] + ; [2; 3] + ; [4; 5; 1] + ; List.init 8 ~f:(fun i -> i*i) + ] +;; + +module type Output = sig + type t [@@deriving compare, sexp_of] +end + +module Int_list = struct + type t = int list [@@deriving compare, sexp_of] +end + +module Int_pair_option = struct + type t = (int * int) option [@@deriving compare, sexp_of] +end + +module Int_option = struct + type t = int option [@@deriving compare, sexp_of] +end +(* For every example: compute [actual] and [expect], require that they compare equal (printing [expect] otherwise), then print [actual] so that the expect block pins it. *) +let check (type a) + here + examples + ~actual + ~expect + (module Output : Output with type t = a) = + List.iter examples ~f:(fun example -> + let actual = actual example in + let expect = expect example in + require here (Output.compare actual expect = 0) + ~if_false_then_print_s:(lazy [%message (expect : Output.t)]); + print_s [%sexp (actual : Output.t)]); +;; + +let%expect_test "foldi" = + let f i acc elt = + if i % 2 = 0 then elt ::
acc else acc + in + check [%here] examples (module Int_list) + ~actual:(fun list -> This_list.foldi list ~init:[] ~f) + ~expect:(fun list -> That_list.foldi list ~init:[] ~f); + [%expect {| + () + (1) + (2) + (1 4) + (36 16 4 0) |}] +;; + +let%expect_test "findi" = + let check f = + check [%here] examples (module Int_pair_option) + ~actual:(fun list -> This_list.findi list ~f) + ~expect:(fun list -> That_list.findi list ~f); + in + check (fun i _elt -> i = 0); + [%expect {| + () + ((0 1)) + ((0 2)) + ((0 4)) + ((0 0)) |}]; + check (fun _i elt -> elt = 1); + [%expect {| + () + ((0 1)) + () + ((2 1)) + ((1 1)) |}]; +;; + +let%expect_test "find_mapi" = + let f i elt = + if elt = 1 then Some (i * 100 + elt) else None + in + check [%here] examples (module Int_option) + ~actual:(fun list -> This_list.find_mapi list ~f) + ~expect:(fun list -> That_list.find_mapi list ~f); + [%expect {| + () + (1) + () + (201) + (101) |}]; +;; + +let%expect_test "iteri" = + let go iteri = + let acc = ref [] in + iteri ~f:(fun i elt -> acc := i :: elt :: !acc); + !acc + in + check [%here] examples (module Int_list) + ~actual:(fun list -> go (This_list.iteri list)) + ~expect:(fun list -> go (That_list.iteri list)); + [%expect {| + () + (0 1) + (1 3 0 2) + (2 1 1 5 0 4) + (7 49 6 36 5 25 4 16 3 9 2 4 1 1 0 0) |}]; +;; +(* Boolean lists covering the empty, all-true, all-false and mixed cases. *) +let bool_examples = + [ []; + [true]; + [false]; + [false; false]; + [true; false]; + [false; true]; + [true; true]; + ] +;; + +let%expect_test "for_alli" = + let f _i elt = elt in + check [%here] bool_examples (module Bool) + ~actual:(fun list -> This_list.for_alli list ~f) + ~expect:(fun list -> That_list.for_alli list ~f); + [%expect {| + true + true + false + false + false + false + true |}]; +;; + +let%expect_test "existsi" = + let f _i elt = elt in + check [%here] bool_examples (module Bool) + ~actual:(fun list -> This_list.existsi list ~f) + ~expect:(fun list -> That_list.existsi list ~f); + [%expect {| + false + true + false + false + true + true + true |}]; +;; +
+let%expect_test "counti" = + let f _i elt = elt in + check [%here] bool_examples (module Int) + ~actual:(fun list -> This_list.counti list ~f) + ~expect:(fun list -> That_list.counti list ~f); + [%expect {| + 0 + 1 + 0 + 0 + 1 + 1 + 2 |}]; +;; diff --git a/test/test_indexed_container.mli b/test/test_indexed_container.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_indexed_container.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_info.ml b/test/test_info.ml new file mode 100644 index 0000000..81c9501 --- /dev/null +++ b/test/test_info.ml @@ -0,0 +1,62 @@ +open! Import +open! Info + +let%test_unit _ = + [%test_result: string] (to_string_hum (of_exn (Failure "foo"))) + ~expect:"(Failure foo)" +;; + +let%test_unit _ = + [%test_result: string] (to_string_hum (tag (of_string "b") ~tag:"a")) + ~expect:"(a b)" +;; + +let%test_unit _ = + [%test_result: string] + (to_string_hum (of_list (List.map ~f:of_string [ "a"; "b"; "c" ]))) + ~expect:"(a b c)" +;; + +let of_strings strings = of_list (List.map ~f:of_string strings) +(* A list of three lists; the two tests below expect it to print, and to convert to a sexp, exactly like the flat nine-element list. *) +let nested = + of_list + (List.map ~f:of_strings + [ [ "a"; "b"; "c" ] + ; [ "d"; "e"; "f" ] + ; [ "g"; "h"; "i" ] + ]) +;; + +let%test_unit _ = + [%test_result: string] (to_string_hum nested) ~expect:"(a b c d e f g h i)" +;; + +let%test_unit _ = + [%test_result: Sexp.t] (sexp_of_t nested) + ~expect:(sexp_of_t (of_strings [ "a"; "b"; "c" + ; "d"; "e"; "f" + ; "g"; "h"; "i" ])) +;; + +let%test_unit _ = + match to_exn (of_exn (Failure "foo")) with + | Failure "foo" -> () + | exn -> raise_s [%sexp { got : exn = exn; expected = Failure "foo" }] +;; +(* [round t] holds when the sexp of [t] is unchanged by a [t_of_sexp] then [sexp_of_t] round trip. *) +let round t = + let sexp = sexp_of_t t in + Sexp.(=) sexp (sexp_of_t (t_of_sexp sexp)) +;; + +let%test _ = round (of_string "hello") +let%test _ = round (of_thunk (fun () -> "hello")) +let%test _ = round (create "tag" 13 [%sexp_of: int]) +let%test _ = round (tag (of_string "hello") ~tag:"tag") +let%test _ = round (tag_arg (of_string "hello")
"tag" 13 + [%sexp_of: int]) +let%test _ = round (of_list [ of_string "hello"; of_string "goodbye" ]) +let%test _ = round (t_of_sexp (Sexplib.Sexp.of_string "((random sexp 1)(b 2)((c (1 2 3))))")) + +let%test _ = String.equal (to_string_hum (of_string "a\nb")) "a\nb" diff --git a/test/test_info.mli b/test/test_info.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_info.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int.ml b/test/test_int.ml new file mode 100644 index 0000000..e380bfe --- /dev/null +++ b/test/test_int.ml @@ -0,0 +1,117 @@ +open! Import +open! Int + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_int_hash_coherence [%here] (module Int); + [%expect {| |}]; +;; + +let%expect_test "[max_value_30_bits]" = + print_s [%sexp (max_value_30_bits : t)]; + [%expect {| + 1_073_741_823 |}]; +;; + +let%expect_test "hex" = + let test x = + let n = Or_error.try_with (fun () -> Int.Hex.of_string x) in + print_s [%message (n : int Or_error.t)] + in + test "0x1c5f"; + [%expect {| + (n (Ok 7_263)) + |}]; + test "0x1c5f NON-HEX-GARBAGE"; + [%expect {| + (n ( + Error ( + Failure + "Base.Int.Hex.of_string: invalid input \"0x1c5f NON-HEX-GARBAGE\""))) + |}] + +let%test_module "Hex" = + (module struct + (* [f (i, s_hum)] checks every [Hex] conversion between the int [i], its human-readable form [s_hum] and the same text with underscores removed. *) + let f (i,s_hum) = + let s = String.filter s_hum ~f:(fun c -> not (Char.equal c '_')) in + let sexp_hum = Sexp.Atom s_hum in + let sexp = Sexp.Atom s in + [%test_result: Sexp.t] ~message:"sexp_of_t" ~expect:sexp (Hex.sexp_of_t i); + [%test_result: int] ~message:"t_of_sexp" ~expect:i (Hex.t_of_sexp sexp); + [%test_result: int] ~message:"t_of_sexp[human]" ~expect:i (Hex.t_of_sexp sexp_hum); + [%test_result: string] ~message:"to_string" ~expect:s (Hex.to_string i); + [%test_result: string] ~message:"to_string_hum" ~expect:s_hum (Hex.to_string_hum i); + [%test_result: int] ~message:"of_string" ~expect:i (Hex.of_string s); + [%test_result: int] ~message:"of_string[human]" ~expect:i
(Hex.of_string s_hum); + ;; + + let%test_unit _ = + List.iter ~f + [ 0, "0x0" + ; 1, "0x1" + ; 2, "0x2" + ; 5, "0x5" + ; 10, "0xa" + ; 16, "0x10" + ; 254, "0xfe" + ; 65_535, "0xffff" + ; 65_536, "0x1_0000" + ; 1_000_000, "0xf_4240" + ; -1, "-0x1" + ; -2, "-0x2" + ; -1_000_000, "-0xf_4240" + ; max_value, + (match num_bits with + | 31 -> "0x3fff_ffff" + | 32 -> "0x7fff_ffff" + | 63 -> "0x3fff_ffff_ffff_ffff" + | _ -> assert false) + ; min_value, + (match num_bits with + | 31 -> "-0x4000_0000" + | 32 -> "-0x8000_0000" + | 63 -> "-0x4000_0000_0000_0000" + | _ -> assert false) + ] + + let%test_unit _ = + [%test_result: int] (Hex.of_string "0XA") ~expect:10 + + let%test_unit _ = + match Option.try_with (fun () -> Hex.of_string "0") with + | None -> () + | Some _ -> failwith "Hex must always have a 0x prefix." + + let%test_unit _ = + match Option.try_with (fun () -> Hex.of_string "0x_0") with + | None -> () + | Some _ -> failwith "Hex may not have '_' before the first digit." + + end) + +let%test _ = (neg 5 + 5 = 0) + +let%test _ = pow min_value 1 = min_value +let%test _ = pow max_value 1 = max_value +(* [compare] must agree with [Caml.compare] on pairs of extreme values, both in sign and in exact result. *) +let%test "comparisons" = + let original_compare (x : int) y = Caml.compare x y in + let valid_compare x y = + let result = compare x y in + let expect = original_compare x y in + assert (Bool.(=) (result < 0) (expect < 0)); + assert (Bool.(=) (result > 0) (expect > 0)); + assert (Bool.(=) (result = 0) (expect = 0)); + assert (result = expect); + in + (valid_compare min_value min_value); + (valid_compare min_value (-1)); + (valid_compare (-1) min_value); + (valid_compare min_value 0); + (valid_compare 0 min_value); + (valid_compare max_value (-1)); + (valid_compare (-1) max_value); + (valid_compare max_value min_value); + (valid_compare max_value max_value); + true +;; diff --git a/test/test_int.mli b/test/test_int.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty.
*) diff --git a/test/test_int32.ml b/test/test_int32.ml new file mode 100644 index 0000000..2be6461 --- /dev/null +++ b/test/test_int32.ml @@ -0,0 +1,8 @@ + +open! Import +open! Int32 + +let%expect_test "hash coherence" = + check_int_hash_coherence [%here] (module Int32); + [%expect {| |}]; +;; diff --git a/test/test_int32.mli b/test/test_int32.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int32.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int32_pow2.ml b/test/test_int32_pow2.ml new file mode 100644 index 0000000..90ad146 --- /dev/null +++ b/test/test_int32_pow2.ml @@ -0,0 +1,114 @@ +open! Import +open! Int32 + +let of_ints = List.map ~f:(of_int_exn);; + +let examples = + of_ints + [ -1 + ; 0 + ; 1 + ; 2 + ; 3 + ; 4 + ; 5 + ; 7 + ; 8 + ; 9 + ; 63 + ; 64 + ; 65] +;; +(* Extreme [Int32] values; despite the name these are the 32-bit bounds, and the tests that use them carry the 64-bits-only tag. *) +let examples_64_bit = + [ min_value + ; succ min_value + ; pred max_value + ; max_value ] +;; +(* Prints each input next to the result of [f], or next to the error when [f] raises. *) +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : int32) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Int32.floor_log2] got invalid input" -1))) + (0 (Error ("[Int32.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-2_147_483_648 (Error ("[Int32.floor_log2] got invalid input" -2147483648))) + (-2_147_483_647 (Error ("[Int32.floor_log2] got invalid input" -2147483647))) + (2_147_483_646 (Ok 30)) + (2_147_483_647 (Ok 30)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Int32.ceil_log2] got invalid input" -1))) + (0 (Error ("[Int32.ceil_log2] got invalid input" 0))) + (1
(Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-2_147_483_648 (Error ("[Int32.ceil_log2] got invalid input" -2147483648))) + (-2_147_483_647 (Error ("[Int32.ceil_log2] got invalid input" -2147483647))) + (2_147_483_646 (Ok 31)) + (2_147_483_647 (Ok 31)) |}]; +;; + +let%test_module "int_math" = + (module struct + + let test_cases () = + of_ints [ 0b10101010; 0b1010101010101010; 0b101010101010101010101010; + 0b10000000; 0b1000000000001000; 0b100000000000000000001000; ] + ;; + (* [ceil_pow2 x] must be a power of two [p2] with p2 / 2 <= x <= p2; [floor_pow2 x] must be a power of two [p2] with p2 <= x <= 2 * p2. *) + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / of_int_exn 2)) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && ((of_int_exn 2 * p2) >= x && x >= p2) ) + ) + ;; + + end) diff --git a/test/test_int32_pow2.mli b/test/test_int32_pow2.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int32_pow2.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int63.ml b/test/test_int63.ml new file mode 100644 index 0000000..6c90ae7 --- /dev/null +++ b/test/test_int63.ml @@ -0,0 +1,76 @@ +open! Import +open!
Int63 + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_int_hash_coherence [%here] (module Int63); + [%expect {| |}]; +;; + +let%test_unit _ = [%test_result: t] max_value ~expect:(of_int64_exn 4611686018427387903L) +let%test_unit _ = [%test_result: t] min_value ~expect:(of_int64_exn (-4611686018427387904L)) + +let%test_unit _ = + [%test_result: t] (of_int32_exn Int32.min_value) ~expect:(of_int32 Int32.min_value) +let%test_unit _ = + [%test_result: t] (of_int32_exn Int32.max_value) ~expect:(of_int32 Int32.max_value) + +let%test "typical random 0" = Exn.does_raise (fun () -> random zero) +(* With [Overflow_exn] opened, addition and subtraction must raise for every operand pair listed below. *) +let%test_module "Overflow_exn" = + (module struct + open Overflow_exn + + let%test_module "( + )" = + (module struct + let test t = Exn.does_raise (fun () -> t + t) + let%test "max_value / 2 + 1" = test (succ (max_value / of_int 2)) + let%test "min_value / 2 - 1" = test (pred (min_value / of_int 2)) + let%test "min_value + min_value" = test min_value + let%test "max_value + max_value" = test max_value + end) + ;; + + let%test_module "( - )" = + (module struct + let%test "min_value - 1" = Exn.does_raise (fun () -> min_value - one) + let%test "max_value - -1" = Exn.does_raise (fun () -> max_value - neg one) + let%test "min_value / 2 - max_value / 2 - 2" = + Exn.does_raise (fun () -> min_value / of_int 2 - max_value / of_int 2 - of_int 2) + let%test "min_value - max_value" = Exn.does_raise (fun () -> min_value - max_value) + let%test "max_value - min_value" = Exn.does_raise (fun () -> max_value - min_value) + let%test "max_value - -max_value" = + Exn.does_raise (fun () -> max_value - neg max_value) + end) + ;; + end) +(* The local [floor_log2] shadows the library function and prints its result; zero must raise. *) +let%expect_test "[floor_log2]" = + let floor_log2 t = print_s [%sexp (floor_log2 t : int)] in + show_raise (fun () -> floor_log2 zero); + [%expect {| + (raised ("[Int.floor_log2] got invalid input" 0)) |}]; + floor_log2 one; + [%expect {| + 0 |}]; + for i = 1 to 8 do + floor_log2 (i |> of_int); + done; + [%expect {| + 0 + 1 + 1 + 2 + 2 + 2 + 2 + 3
|}]; + floor_log2 (one lsl 61 - one); + [%expect {| + 60 |}]; + floor_log2 (one lsl 61); + [%expect {| + 61 |}]; + floor_log2 max_value; + [%expect {| + 61 |}]; +;; diff --git a/test/test_int63.mli b/test/test_int63.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int63.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int63_emul.ml b/test/test_int63_emul.ml new file mode 100644 index 0000000..16b7031 --- /dev/null +++ b/test/test_int63_emul.ml @@ -0,0 +1,14 @@ +open! Core_kernel +open! Expect_test_helpers_kernel + +module Int63_emul = Base.Not_exposed_properly.Int63_emul +(* The emulated Int63 must render [min_value] in hex exactly like [Int63] does. *) +let%expect_test _ = + let s63 = Int63.( Hex.to_string min_value) in + let s63_emul = Int63_emul.(Hex.to_string min_value) in + print_s [%message (s63 : string) (s63_emul : string)]; + require [%here] (String.equal s63 s63_emul); + [%expect {| + ((s63 -0x4000000000000000) + (s63_emul -0x4000000000000000)) |}]; +;; diff --git a/test/test_int63_emul.mli b/test/test_int63_emul.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int63_emul.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int64.ml b/test/test_int64.ml new file mode 100644 index 0000000..38b6521 --- /dev/null +++ b/test/test_int64.ml @@ -0,0 +1,7 @@ +open! Import +open! Int64 + +let%expect_test "hash coherence" = + check_int_hash_coherence [%here] (module Int64); + [%expect {| |}]; +;; diff --git a/test/test_int64.mli b/test/test_int64.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int64.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int64_pow2.ml b/test/test_int64_pow2.ml new file mode 100644 index 0000000..aa37f72 --- /dev/null +++ b/test/test_int64_pow2.ml @@ -0,0 +1,123 @@ +open! Import +open!
Int64 +(* Small inputs around powers of two, plus -1 and 0, for which the pinned output of both log2 functions is an error. *) +let examples = + [ -1L + ; 0L + ; 1L + ; 2L + ; 3L + ; 4L + ; 5L + ; 7L + ; 8L + ; 9L + ; 63L + ; 64L + ; 65L + ] +;; + +let examples_64_bit = + [ min_value + ; succ min_value + ; pred max_value + ; max_value ] +;; +(* Prints each input next to the result of [f], or next to the error when [f] raises. *) +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : int64) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Int64.floor_log2] got invalid input" -1))) + (0 (Error ("[Int64.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Int64.floor_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Int64.floor_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 62)) + (9_223_372_036_854_775_807 (Ok 62)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Int64.ceil_log2] got invalid input" -1))) + (0 (Error ("[Int64.ceil_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Int64.ceil_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Int64.ceil_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 63)) + (9_223_372_036_854_775_807 (Ok 63)) |}]; +;; + +let%test_module "int64_math" = + (module struct + (* [test_cases ()] returns six base bit patterns, two patterns assembled from 16-bit halves, and each of those eight shifted left by 16 bits. *) + let test_cases () = + let cases = + [
0b10101010L; 0b1010101010101010L; 0b101010101010101010101010L; + 0b10000000L; 0b1000000000001000L; 0b100000000000000000001000L; ] + in + let cases = + cases @ [ (0b1010101010101010L lsl 16) lor 0b1010101010101010L; + (0b1000000000000000L lsl 16) lor 0b0000000000001000L; ] + in + let added_cases = List.map cases ~f:(fun x -> x lsl 16) in + List.concat [ cases; added_cases ] + ;; + (* [ceil_pow2 x] must be a power of two [p2] with p2 / 2 <= x <= p2; [floor_pow2 x] must be a power of two [p2] with p2 <= x <= 2 * p2. *) + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / 2L)) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && ((2L * p2) >= x && x >= p2) ) + ) + ;; + end) diff --git a/test/test_int64_pow2.mli b/test/test_int64_pow2.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int64_pow2.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int_conversions.ml b/test/test_int_conversions.ml new file mode 100644 index 0000000..32b1af4 --- /dev/null +++ b/test/test_int_conversions.ml @@ -0,0 +1,170 @@ +open! Import +open!
Base.Not_exposed_properly.Int_conversions + +let%test_module "pretty" = + (module struct + + let check input output = + List.for_all [""; "+"; "-"] ~f:(fun prefix -> + let input = prefix ^ input in + let output = prefix ^ output in + [%compare.equal: string] output (insert_underscores input)) + + let%test _ = check "1" "1" + let%test _ = check "12" "12" + let%test _ = check "123" "123" + let%test _ = check "1234" "1_234" + let%test _ = check "12345" "12_345" + let%test _ = check "123456" "123_456" + let%test _ = check "1234567" "1_234_567" + let%test _ = check "12345678" "12_345_678" + let%test _ = check "123456789" "123_456_789" + let%test _ = check "1234567890" "1_234_567_890" + + end) + +let%test_module "conversions" = + (module struct + module type S = sig + include Int.S + val module_name : string + end + + let test_conversion (type a) (type b) loc ma mb + a_to_b_or_error + a_to_b_trunc + b_to_a_trunc + = + let (module A : S with type t = a) = ma in + let (module B : S with type t = b) = mb in + let examples = + [ A.min_value + ; A.minus_one + ; A.zero + ; A.one + ; A.max_value + ; B.min_value |> b_to_a_trunc + ; B.max_value |> b_to_a_trunc + ] + |> List.concat_map ~f:(fun a -> [ A.pred a; a; A.succ a ]) + |> List.dedup_and_sort ~compare:A.compare + |> List.sort ~compare:A.compare + in + List.iter examples ~f:(fun a -> + let b' = a_to_b_trunc a in + let a' = b_to_a_trunc b' in + match a_to_b_or_error a with + | Ok b -> + require loc (B.equal b b') + ~if_false_then_print_s: + (lazy [%message + "conversion produced wrong value" + ~from: (A.module_name : string) + ~to_: (B.module_name : string) + ~input: (a : A.t) + ~output: (b : B.t) + ~expected: (b' : B.t)]); + require loc (A.equal a a') + ~if_false_then_print_s: + (lazy + [%message + "conversion does not round-trip" + ~from: (A.module_name : string) + ~to_: (B.module_name : string) + ~input: (a : A.t) + ~output: (b : B.t) + ~round_trip: (a' : A.t)]) + | Error error -> + require loc (not (A.equal a a')) + 
~if_false_then_print_s: + (lazy + [%message + "conversion failed" + ~from: (A.module_name : string) + ~to_: (B.module_name : string) + ~input: (a : A.t) + ~expected_output: (b' : B.t) + ~error: (error : Error.t)])) + + let test loc ma mb + (a_to_b_trunc, a_to_b_or_error) + (b_to_a_trunc, b_to_a_or_error) + = + test_conversion loc ma mb a_to_b_or_error a_to_b_trunc b_to_a_trunc; + test_conversion loc mb ma b_to_a_or_error b_to_a_trunc a_to_b_trunc + + module Int = struct include Int let module_name = "Int" end + module Int32 = struct include Int32 let module_name = "Int32" end + module Int64 = struct include Int64 let module_name = "Int64" end + module Nativeint = struct include Nativeint let module_name = "Nativeint" end + + let with_exn f x = Or_error.try_with (fun () -> f x) + let optional f x = Or_error.try_with (fun () -> Option.value_exn (f x)) + let alwaysok f x = Ok (f x) + + let%expect_test "int <-> int32" = + test [%here] (module Int) (module Int32) + (Caml.Int32.of_int, with_exn int_to_int32_exn) + (Caml.Int32.to_int, with_exn int32_to_int_exn); + [%expect {| |}]; + test [%here] (module Int) (module Int32) + (Caml.Int32.of_int, optional int_to_int32) + (Caml.Int32.to_int, optional int32_to_int); + [%expect {| |}]; + ;; + + let%expect_test "int <-> int64" = + test [%here] (module Int) (module Int64) + (Caml.Int64.of_int, alwaysok int_to_int64) + (Caml.Int64.to_int, with_exn int64_to_int_exn); + [%expect {| |}]; + test [%here] (module Int) (module Int64) + (Caml.Int64.of_int, alwaysok int_to_int64) + (Caml.Int64.to_int, optional int64_to_int); + [%expect {| |}]; + ;; + + let%expect_test "int <-> nativeint" = + test [%here] (module Int) (module Nativeint) + (Caml.Nativeint.of_int, alwaysok int_to_nativeint) + (Caml.Nativeint.to_int, with_exn nativeint_to_int_exn); + [%expect {| |}]; + test [%here] (module Int) (module Nativeint) + (Caml.Nativeint.of_int, alwaysok int_to_nativeint) + (Caml.Nativeint.to_int, optional nativeint_to_int); + [%expect {| |}]; + ;; 
+ + let%expect_test "int32 <-> int64" = + test [%here] (module Int32) (module Int64) + (Caml.Int64.of_int32, alwaysok int32_to_int64) + (Caml.Int64.to_int32, with_exn int64_to_int32_exn); + [%expect {| |}]; + test [%here] (module Int32) (module Int64) + (Caml.Int64.of_int32, alwaysok int32_to_int64) + (Caml.Int64.to_int32, optional int64_to_int32); + [%expect {| |}]; + ;; + + let%expect_test "int32 <-> nativeint" = + test [%here] (module Int32) (module Nativeint) + (Caml.Nativeint.of_int32, alwaysok int32_to_nativeint) + (Caml.Nativeint.to_int32, with_exn nativeint_to_int32_exn); + [%expect {| |}]; + test [%here] (module Int32) (module Nativeint) + (Caml.Nativeint.of_int32, alwaysok int32_to_nativeint) + (Caml.Nativeint.to_int32, optional nativeint_to_int32); + [%expect {| |}]; + ;; + + let%expect_test "int64 <-> nativeint" = + test [%here] (module Int64) (module Nativeint) + (Caml.Int64.to_nativeint, with_exn int64_to_nativeint_exn) + (Caml.Int64.of_nativeint, alwaysok nativeint_to_int64); + [%expect {| |}]; + test [%here] (module Int64) (module Nativeint) + (Caml.Int64.to_nativeint, optional int64_to_nativeint) + (Caml.Int64.of_nativeint, alwaysok nativeint_to_int64); + [%expect {| |}]; + ;; + end) diff --git a/test/test_int_conversions.mli b/test/test_int_conversions.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int_conversions.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int_hash.ml b/test/test_int_hash.ml new file mode 100644 index 0000000..6ad8fa0 --- /dev/null +++ b/test/test_int_hash.ml @@ -0,0 +1,8 @@ +open! Base +open! 
Import + +let%expect_test "int hash is not ident" [@tags "64-bits-only"] = + print_s [%message "hash of 10" + (Int.hash 10 : int) ]; + [%expect {| ("hash of 10" ("Int.hash 10" 1_579_120_067_278_557_813)) |}]; +;; diff --git a/test/test_int_hash.mli b/test/test_int_hash.mli new file mode 100644 index 0000000..1b56838 --- /dev/null +++ b/test/test_int_hash.mli @@ -0,0 +1 @@ +(* this interface is deliberately empty *) diff --git a/test/test_int_math.ml b/test/test_int_math.ml new file mode 100644 index 0000000..99b9b3e --- /dev/null +++ b/test/test_int_math.ml @@ -0,0 +1,192 @@ +open! Import +open! Base.Not_exposed_properly.Int_math + +let%test_unit _ = + let x = match Word_size.word_size with W32 -> 9 | W64 -> 10 in + for i = 0 to x do + for j = 0 to x do + assert (int_pow i j + = Caml.(int_of_float ((float_of_int i) ** (float_of_int j)))) + done + done + +module Make (X : T) : sig end = struct + open X + include Make(X) + + let%test_module "integer-rounding" = + (module struct + + let check dir ~range:(lower, upper) ~modulus expected = + let modulus = of_int_exn modulus in + let expected = of_int_exn expected in + for i = lower to upper do + let observed = round ~dir ~to_multiple_of:modulus (of_int_exn i) in + if observed <> expected then + raise_s [%message "invalid result" (i : int)] + done + ;; + + let%test_unit _ = check ~modulus:10 `Down ~range:( 10, 19) 10 + let%test_unit _ = check ~modulus:10 `Down ~range:( 0, 9) 0 + let%test_unit _ = check ~modulus:10 `Down ~range:(-10, -1) (-10) + let%test_unit _ = check ~modulus:10 `Down ~range:(-20, -11) (-20) + + let%test_unit _ = check ~modulus:10 `Up ~range:( 11, 20) 20 + let%test_unit _ = check ~modulus:10 `Up ~range:( 1, 10) 10 + let%test_unit _ = check ~modulus:10 `Up ~range:( -9, 0) 0 + let%test_unit _ = check ~modulus:10 `Up ~range:(-19, -10) (-10) + + let%test_unit _ = check ~modulus:10 `Zero ~range:( 10, 19) 10 + let%test_unit _ = check ~modulus:10 `Zero ~range:( -9, 9) 0 + let%test_unit _ = check ~modulus:10 
`Zero ~range:(-19, -10) (-10) + + let%test_unit _ = check ~modulus:10 `Nearest ~range:( 15, 24) 20 + let%test_unit _ = check ~modulus:10 `Nearest ~range:( 5, 14) 10 + let%test_unit _ = check ~modulus:10 `Nearest ~range:( -5, 4) 0 + let%test_unit _ = check ~modulus:10 `Nearest ~range:(-15, -6) (-10) + let%test_unit _ = check ~modulus:10 `Nearest ~range:(-25, -16) (-20) + + let%test_unit _ = check ~modulus:5 `Nearest ~range:( 8, 12) 10 + let%test_unit _ = check ~modulus:5 `Nearest ~range:( 3, 7) 5 + let%test_unit _ = check ~modulus:5 `Nearest ~range:( -2, 2) 0 + let%test_unit _ = check ~modulus:5 `Nearest ~range:( -7, -3) (-5) + let%test_unit _ = check ~modulus:5 `Nearest ~range:(-12, -8) (-10) + end) + + let%test_module "remainder-and-modulus" = + (module struct + + let one = of_int_exn 1 + + let check_integers x y = + let sexp_of_t t = sexp_of_string (to_string t) in + let check_raises f what = + match f () with + | exception _ -> () + | z -> + raise_s + [%message "produced result instead of raising" + (what : string) + (x : t) + (y : t) + (z : t)] + in + let check_true cond what = + if not cond + then raise_s + [%message "failed" + (what : string) + (x : t) + (y : t)] + in + if y = zero + then + begin + check_raises (fun () -> x / y) "division by zero"; + check_raises (fun () -> rem x y) "rem _ zero"; + check_raises (fun () -> x % y) "_ % zero"; + check_raises (fun () -> x /% y) "_ /% zero"; + end + else + begin + if x < zero + then check_true (rem x y <= zero) "non-positive remainder" + else check_true (rem x y >= zero) "non-negative remainder"; + check_true (abs (rem x y) <= abs y - one) "range of remainder"; + if y < zero then begin + check_raises (fun () -> x % y) "_ % negative"; + check_raises (fun () -> x /% y) "_ /% negative" + end + else begin + check_true (x = (x /% y) * y + (x % y)) "(/%) and (%) identity"; + check_true (x = (x / y) * y + (rem x y)) "(/) and rem identity"; + check_true (x % y >= zero) "non-negative (%)"; + check_true (x % y <= y - one) 
"range of (%)"; + if x > zero && y > zero + then begin + check_true (x /% y = x / y) "(/%) and (/) identity"; + check_true (x % y = rem x y) "(%) and rem identity" + end; + end + end + ;; + + let check_natural_numbers x y = + List.iter [ x ; -x ; x+one ; -(x + one) ] ~f:(fun x -> + List.iter [ y ; -y ; y+one ; -(y + one) ] ~f:(fun y -> + check_integers x y)) + + let%test_unit "deterministic" = + let big1 = of_int_exn 118_310_344 in + let big2 = of_int_exn 828_172_408 in + (* Important to test the case where one value is a multiple of the other. Note that + the [x + one] and [y + one] cases in [check_natural_numbers] ensure that we also + test non-multiple cases. *) + assert (big2 = big1 * of_int_exn 7); + let values = [ zero ; one ; big1 ; big2 ] in + List.iter values ~f:(fun x -> + List.iter values ~f:(fun y -> + check_natural_numbers x y)) + + let%test_unit "random" = + let rand = Random.State.make [| 8; 67; -5_309 |] in + for _ = 0 to 1_000 do + let max_value = 1_000_000_000 in + let x = of_int_exn (Random.State.int rand max_value) in + let y = of_int_exn (Random.State.int rand max_value) in + check_natural_numbers x y + done + end) +end + +include Make (Int) +include Make (Int32) +include Make (Int63) +include Make (Int64) +include Make (Nativeint) + +let%test_module "pow" = + (module struct + let%test _ = int_pow 0 0 = 1 + let%test _ = int_pow 0 1 = 0 + let%test _ = int_pow 10 1 = 10 + let%test _ = int_pow 10 2 = 100 + let%test _ = int_pow 10 3 = 1_000 + let%test _ = int_pow 10 4 = 10_000 + let%test _ = int_pow 10 5 = 100_000 + let%test _ = int_pow 2 10 = 1024 + + let%test _ = int_pow 0 1_000_000 = 0 + let%test _ = int_pow 1 1_000_000 = 1 + let%test _ = int_pow (-1) 1_000_000 = 1 + let%test _ = int_pow (-1) 1_000_001 = -1 + + let (=) = Int64.(=) + let%test _ = int64_pow 0L 0L = 1L + let%test _ = int64_pow 0L 1_000_000L = 0L + let%test _ = int64_pow 1L 1_000_000L = 1L + let%test _ = int64_pow (-1L) 1_000_000L = 1L + let%test _ = int64_pow (-1L) 1_000_001L = -1L 
+ + let%test _ = int64_pow 10L 1L = 10L + let%test _ = int64_pow 10L 2L = 100L + let%test _ = int64_pow 10L 3L = 1_000L + let%test _ = int64_pow 10L 4L = 10_000L + let%test _ = int64_pow 10L 5L = 100_000L + let%test _ = int64_pow 2L 10L = 1_024L + let%test _ = int64_pow 5L 27L = 7450580596923828125L + + let exception_thrown pow b e = Exn.does_raise (fun () -> pow b e) + + let%test _ = exception_thrown int_pow 10 60 + let%test _ = exception_thrown int64_pow 10L 60L + let%test _ = exception_thrown int_pow 10 (-1) + let%test _ = exception_thrown int64_pow 10L (-1L) + + let%test _ = exception_thrown int64_pow 2L 63L + let%test _ = not (exception_thrown int64_pow 2L 62L) + + let%test _ = exception_thrown int64_pow (-2L) 63L + let%test _ = not (exception_thrown int64_pow (-2L) 62L) + end) diff --git a/test/test_int_math.mli b/test/test_int_math.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int_math.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int_pow2.ml b/test/test_int_pow2.ml new file mode 100644 index 0000000..9f7d76d --- /dev/null +++ b/test/test_int_pow2.ml @@ -0,0 +1,126 @@ +open! Import +open! 
Int + +let examples = + [ -1 + ; 0 + ; 1 + ; 2 + ; 3 + ; 4 + ; 5 + ; 7 + ; 8 + ; 9 + ; 63 + ; 64 + ; 65 ] +;; + +let examples_64_bit = + [ Int.min_value + ; Int.min_value + 1 + ; Int.max_value - 1 + ; Int.max_value ] +;; + +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : int) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Int.floor_log2] got invalid input" -1))) + (0 (Error ("[Int.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-4_611_686_018_427_387_904 ( + Error ("[Int.floor_log2] got invalid input" -4611686018427387904))) + (-4_611_686_018_427_387_903 ( + Error ("[Int.floor_log2] got invalid input" -4611686018427387903))) + (4_611_686_018_427_387_902 (Ok 61)) + (4_611_686_018_427_387_903 (Ok 61)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Int.ceil_log2] got invalid input" -1))) + (0 (Error ("[Int.ceil_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-4_611_686_018_427_387_904 ( + Error ("[Int.ceil_log2] got invalid input" -4611686018427387904))) + (-4_611_686_018_427_387_903 ( + Error ("[Int.ceil_log2] got invalid input" -4611686018427387903))) + (4_611_686_018_427_387_902 (Ok 62)) + (4_611_686_018_427_387_903 (Ok 62)) |}]; +;; + +let%test_module "int_math" = + (module struct + + let test_cases () = + let cases = + [ 0b10101010; 0b1010101010101010; 
0b101010101010101010101010; + 0b10000000; 0b1000000000001000; 0b100000000000000000001000; ] + in + match Word_size.word_size with + | W64 -> (* create some >32 bit values... *) + (* We can't use literals directly because the compiler complains on 32 bits. *) + let cases = + cases @ [ (0b1010101010101010 lsl 16) lor 0b1010101010101010; + (0b1000000000000000 lsl 16) lor 0b0000000000001000; ] + in + let added_cases = List.map cases ~f:(fun x -> x lsl 16) in + List.concat [ cases; added_cases ] + | W32 -> cases + ;; + + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / 2)) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && ((2 * p2) >= x && x >= p2) ) + ) + ;; + end) diff --git a/test/test_int_pow2.mli b/test/test_int_pow2.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int_pow2.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_lazy.ml b/test/test_lazy.ml new file mode 100644 index 0000000..55603a0 --- /dev/null +++ b/test/test_lazy.ml @@ -0,0 +1,55 @@ +open! Import +open! 
Lazy + +let%test_unit _ = + let r = ref 0 in + let t = return () >>= fun () -> Int.incr r; return () in + assert (!r = 0); + force t; + assert (!r = 1); + force t; + assert (!r = 1) +;; + +let%test_unit _ = + let r = ref 0 in + let t = return () >>= fun () -> lazy (Int.incr r) in + assert (!r = 0); + force t; + assert (!r = 1); + force t; + assert (!r = 1) +;; + +let%test_module _ = + (module struct + + module M1 = struct + type nonrec t = { x : int t } [@@deriving sexp_of] + end + + module M2 = struct + type t = { x : int T_unforcing.t } [@@deriving sexp_of] + end + + let%test_unit _ = + let v = lazy 42 in + let (_ : int) = + (* not needed, but the purpose of this test is not to test this compiler + optimization *) + force v + in + assert (is_val v); + let t1 = { M1. x = v } in + let t2 = { M2. x = v } in + assert (Sexp.equal (M1.sexp_of_t t1) (M2.sexp_of_t t2)) + ;; + + let%test_unit _ = + let t1 = { M1. x = lazy (40 + 2) } in + let t2 = { M2. x = lazy (40 + 2) } in + assert (not (Sexp.equal (M1.sexp_of_t t1) (M2.sexp_of_t t2))); + assert (is_val t1.x); + assert (not (is_val t2.x)) + ;; + end) diff --git a/test/test_lazy.mli b/test/test_lazy.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_lazy.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_list.ml b/test/test_list.ml new file mode 100644 index 0000000..b6d8181 --- /dev/null +++ b/test/test_list.ml @@ -0,0 +1,610 @@ +open! Import +open! 
List + +let%test_module "reduce_balanced" = + (module struct + let test expect list = + [%test_result: string option] ~expect + (reduce_balanced ~f:(fun a b -> "(" ^ a ^ "+" ^ b ^ ")") list) + + let%test_unit "length 0" = + test None [] + + let%test_unit "length 1" = + test (Some "a") ["a"] + + let%test_unit "length 2" = + test (Some "(a+b)") ["a"; "b"] + + let%test_unit "length 6" = + test (Some "(((a+b)+(c+d))+(e+f))") ["a";"b";"c";"d";"e";"f"] + + let%test_unit "longer" = + (* pairs (index, number of times f called on me) to check: + 1. f called on results in index order + 2. total number of calls on any element is low + called on 2^n + 1 to demonstrate lack of balance (most elements are distance 7 from + the tree root, but one is distance 1) *) + let data = map (range 0 65) ~f:(fun i -> [(i, 0)]) in + let f x y = map (x @ y) ~f:(fun (ix, cx) -> (ix, cx + 1)) in + match reduce_balanced data ~f with + | None -> failwith "None" + | Some l -> + [%test_result: int] ~expect:65 (List.length l); + iteri l ~f:(fun actual_index (computed_index, num_f) -> + let expected_num_f = if actual_index = 64 then 1 else 7 in + [%test_result: int * int] + ~expect:(actual_index, expected_num_f) (computed_index, num_f)) + end) + +let%test_module "range symmetries" = + (module struct + + let basic ~stride ~start ~stop ~start_n ~stop_n ~result = + [%compare.equal: int t] (range ~stride ~start ~stop start_n stop_n) result + + let test stride (start_n, start) (stop_n, stop) result = + basic ~stride ~start ~stop ~start_n ~stop_n ~result + && (* works for negative [start] and [stop] *) + basic ~stride:(-stride) + ~start_n:(-start_n) + ~stop_n:(-stop_n) + ~start + ~stop + ~result:(List.map result ~f:(fun x -> -x)) + + let%test _ = test 1 ( 3, `inclusive) ( 1, `exclusive) [] + let%test _ = test 1 ( 3, `inclusive) ( 3, `exclusive) [] + let%test _ = test 1 ( 3, `inclusive) ( 4, `exclusive) [3] + let%test _ = test 1 ( 3, `inclusive) ( 8, `exclusive) [3;4;5;6;7] + let%test _ = test 3 ( 4, 
`inclusive) (10, `exclusive) [4;7] + let%test _ = test 3 ( 4, `inclusive) (11, `exclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (12, `exclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (13, `exclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (14, `exclusive) [4;7;10;13] + + let%test _ = test (-1) ( 1, `inclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 3, `inclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 4, `inclusive) ( 3, `exclusive) [4] + let%test _ = test (-1) ( 8, `inclusive) ( 3, `exclusive) [8;7;6;5;4] + let%test _ = test (-3) (10, `inclusive) ( 4, `exclusive) [10;7] + let%test _ = test (-3) (10, `inclusive) ( 3, `exclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 2, `exclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 1, `exclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 0, `exclusive) [10;7;4;1] + + let%test _ = test 1 ( 3, `exclusive) ( 1, `exclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 3, `exclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 4, `exclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 8, `exclusive) [4;5;6;7] + let%test _ = test 3 ( 4, `exclusive) (10, `exclusive) [7] + let%test _ = test 3 ( 4, `exclusive) (11, `exclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (12, `exclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (13, `exclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (14, `exclusive) [7;10;13] + + let%test _ = test (-1) ( 1, `exclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 3, `exclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 4, `exclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 8, `exclusive) ( 3, `exclusive) [7;6;5;4] + let%test _ = test (-3) (10, `exclusive) ( 4, `exclusive) [7] + let%test _ = test (-3) (10, `exclusive) ( 3, `exclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 2, `exclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 1, `exclusive) [7;4] + let%test _ = test (-3) 
(10, `exclusive) ( 0, `exclusive) [7;4;1] + + let%test _ = test 1 ( 3, `inclusive) ( 1, `inclusive) [] + let%test _ = test 1 ( 3, `inclusive) ( 3, `inclusive) [3] + let%test _ = test 1 ( 3, `inclusive) ( 4, `inclusive) [3;4] + let%test _ = test 1 ( 3, `inclusive) ( 8, `inclusive) [3;4;5;6;7;8] + let%test _ = test 3 ( 4, `inclusive) (10, `inclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (11, `inclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (12, `inclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (13, `inclusive) [4;7;10;13] + let%test _ = test 3 ( 4, `inclusive) (14, `inclusive) [4;7;10;13] + + let%test _ = test (-1) ( 1, `inclusive) ( 3, `inclusive) [] + let%test _ = test (-1) ( 3, `inclusive) ( 3, `inclusive) [3] + let%test _ = test (-1) ( 4, `inclusive) ( 3, `inclusive) [4;3] + let%test _ = test (-1) ( 8, `inclusive) ( 3, `inclusive) [8;7;6;5;4;3] + let%test _ = test (-3) (10, `inclusive) ( 4, `inclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 3, `inclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 2, `inclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 1, `inclusive) [10;7;4;1] + let%test _ = test (-3) (10, `inclusive) ( 0, `inclusive) [10;7;4;1] + + let%test _ = test 1 ( 3, `exclusive) ( 1, `inclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 3, `inclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 4, `inclusive) [4] + let%test _ = test 1 ( 3, `exclusive) ( 8, `inclusive) [4;5;6;7;8] + let%test _ = test 3 ( 4, `exclusive) (10, `inclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (11, `inclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (12, `inclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (13, `inclusive) [7;10;13] + let%test _ = test 3 ( 4, `exclusive) (14, `inclusive) [7;10;13] + + let%test _ = test (-1) ( 1, `exclusive) ( 3, `inclusive) [] + let%test _ = test (-1) ( 3, `exclusive) ( 3, `inclusive) [] + let%test _ = test (-1) ( 4, `exclusive) ( 3, `inclusive) [3] + 
let%test _ = test (-1) ( 8, `exclusive) ( 3, `inclusive) [7;6;5;4;3] + let%test _ = test (-3) (10, `exclusive) ( 4, `inclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 3, `inclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 2, `inclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 1, `inclusive) [7;4;1] + let%test _ = test (-3) (10, `exclusive) ( 0, `inclusive) [7;4;1] + + let test_start_inc_exc stride start (stop, stop_inc_exc) result = + test stride (start, `inclusive) (stop, stop_inc_exc) result + && begin + match result with + | [] -> true + | head :: tail -> + head = start && test stride (start, `exclusive) (stop, stop_inc_exc) tail + end + + let test_inc_exc stride start stop result = + test_start_inc_exc stride start (stop, `inclusive) result + && begin + match List.rev result with + | [] -> true + | last :: all_but_last -> + let all_but_last = List.rev all_but_last in + if last = stop then + test_start_inc_exc stride start (stop, `exclusive) all_but_last + else + true + end + + let%test _ = test_inc_exc 1 4 10 [4;5;6;7;8;9;10] + let%test _ = test_inc_exc 3 4 10 [4;7;10] + let%test _ = test_inc_exc 3 4 11 [4;7;10] + let%test _ = test_inc_exc 3 4 12 [4;7;10] + let%test _ = test_inc_exc 3 4 13 [4;7;10;13] + let%test _ = test_inc_exc 3 4 14 [4;7;10;13] + + end) + +module Test_values = struct + let long1 = + let v = lazy (range 1 100_000) in + fun () -> Lazy.force v + + let long2 = + let v = lazy (range 2 100_001) in + fun () -> Lazy.force v + + let l1 = [1;2;3;4;5;6;7;8;9;10] +end + +let%test_unit _ = [%test_result: int list] (rev_append [1;2;3] [4;5;6]) ~expect:[3;2;1;4;5;6] +let%test_unit _ = [%test_result: int list] (rev_append [] [4;5;6]) ~expect:[4;5;6] +let%test_unit _ = [%test_result: int list] (rev_append [1;2;3] []) ~expect:[3;2;1] +let%test_unit _ = [%test_result: int list] (rev_append [1] [2;3]) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int list] (rev_append [1;2] [3]) ~expect:[2;1;3] +let%test_unit _ = + let long = 
Test_values.long1 () in + ignore (rev_append long long:int list) + +let%test_unit _ = + let long1 = Test_values.long1 () in + let long2 = Test_values.long2 () in + [%test_result: int list] (map long1 ~f:(fun x -> x + 1)) ~expect:long2 + +let test_ordering n = + let l = List.range 0 n in + let r = ref [] in + let _ : unit list = List.map l ~f:(fun x -> r := x :: !r) in + [%test_eq: int list] l (List.rev !r) + +let%test_unit _ = test_ordering 10 +let%test_unit _ = test_ordering 1000 +let%test_unit _ = test_ordering 1_000_000 + +let%test _ = for_all2_exn [] [] ~f:(fun _ _ -> assert false) + +let%test_unit _ = [%test_result: int option] (find_mapi [0;5;2;1;4] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (find_mapi [3;5;2;1;4] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 4) +let%test_unit _ = [%test_result: int option] (find_mapi [3;5;1;1;4] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 8) +let%test_unit _ = [%test_result: int option] (find_mapi [3;5;1;1;2] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:None + +let%test_unit _ = [%test_result: bool] (for_alli [] ~f:(fun _ _ -> false)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [0;1;2;3] ~f:(fun i x -> i = x)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [0;1;3;3] ~f:(fun i x -> i = x)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [] ~f:(fun _ _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [0;1;2;3] ~f:(fun i x -> i <> x)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [0;1;3;3] ~f:(fun i x -> i <> x)) ~expect:true + +let%test_unit _ = [%test_result: int list] (append [1;2;3] [4;5;6]) ~expect:[1;2;3;4;5;6] +let%test_unit _ = [%test_result: int list] (append [] [4;5;6]) ~expect:[4;5;6] +let%test_unit _ = [%test_result: int list] (append [1;2;3] []) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int 
list] (append [1] [2;3]) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int list] (append [1;2] [3]) ~expect:[1;2;3] +let%test_unit _ = + let long = Test_values.long1 () in + ignore (append long long:int list) + +let%test_unit _ = [%test_result: int list] (map ~f:Fn.id Test_values.l1) ~expect:Test_values.l1 +let%test_unit _ = [%test_result: int list] (map ~f:Fn.id []) ~expect:[] +let%test_unit _ = [%test_result: float list] (map ~f:(fun x -> x +. 5.) [1.;2.;3.]) ~expect:[6.;7.;8.] +let%test_unit _ = + ignore (map ~f:Fn.id (Test_values.long1 ()):int list) + +let%test_unit _ = + [%test_result: (int * char) list] + (map2_exn ~f:(fun a b -> a, b) [1;2;3] ['a';'b';'c']) + ~expect:[(1,'a'); (2,'b'); (3,'c')] +let%test_unit _ = [%test_result: _ list] (map2_exn ~f:(fun _ _ -> ()) [] []) ~expect:[] +let%test_unit _ = + let long = Test_values.long1 () in + ignore (map2_exn ~f:(fun _ _ -> ()) long long:unit list) + +let%test_unit _ = [%test_result: int list] (rev_map_append [1;2;3;4;5] [6] ~f:Fn.id) ~expect:[5;4;3;2;1;6] +let%test_unit _ = [%test_result: int list] (rev_map_append [1;2;3;4;5] [6] ~f:(fun x -> 2 * x)) ~expect:[10;8;6;4;2;6] +let%test_unit _ = [%test_result: int list] (rev_map_append [] [6] ~f:(fun _ -> failwith "bug!")) ~expect:[6] + +let%test_unit _ = + [%test_result: int list] + (fold_right ~f:(fun e acc -> e :: acc) Test_values.l1 ~init:[]) + ~expect:Test_values.l1 +let%test_unit _ = [%test_result: string] (fold_right ~f:(fun e acc -> e ^ acc) ["1";"2"] ~init:"3") ~expect:"123" +let%test_unit _ = [%test_result: unit] (fold_right ~f:(fun _ _ -> ()) [] ~init:()) ~expect:() +let%test_unit _ = + let long = Test_values.long1 () in + ignore (fold_right ~f:(fun e acc -> e :: acc) long ~init:[]) + +let%test_unit _ = + let l1 = Test_values.l1 in + [%test_result: int list * int list] (unzip (zip_exn l1 (List.rev l1))) ~expect:(l1, List.rev l1) +;; + +let%test_unit _ = + let long = Test_values.long1 () in + ignore (unzip (zip_exn long long)) +;; + +let%test_unit _ = 
[%test_result: int list * int list] (unzip [(1,2) ; (4,5) ]) ~expect:([1; 4], [2; 5] ) +let%test_unit _ = [%test_result: int list * int list * int list] (unzip3 [(1,2,3); (4,5,6)]) ~expect:([1; 4], [2; 5], [3; 6]) + +let%test_unit _ = [%test_result: (int * int) list Or_unequal_lengths.t] (zip [1;2;3] [4;5;6]) ~expect:(Ok [1,4;2,5;3,6]) +let%test_unit _ = [%test_result: (int * int) list Or_unequal_lengths.t] (zip [1] [4;5;6] ) ~expect:Unequal_lengths + +let%test_unit _ = [%test_result: (int * int) list] (zip_exn [1;2;3] [4;5;6]) ~expect:[1,4;2,5;3,6] + +let%expect_test _ = + show_raise (fun () -> zip_exn [1] [4;5;6]); + [%expect {| (raised (Invalid_argument "length mismatch in zip_exn: 1 <> 3 ")) |}] +;; + +let%test_unit _ = + [%test_result: (int * string) list] + (mapi ~f:(fun i x -> (i,x)) ["one";"two";"three";"four"]) + ~expect:[0,"one";1,"two";2,"three";3,"four"] +let%test_unit _ = [%test_result: (int * _) list] (mapi ~f:(fun i x -> (i,x)) []) ~expect:[] + +let%test_module "group" = + (module struct + let%test_unit _ = [%test_result: int list list] (group [1;2;3;4] ~break:(fun _ x -> x = 3)) ~expect:[[1;2];[3;4]] + + let%test_unit _ = [%test_result: int list list] (group [] ~break:(fun _ -> assert false)) ~expect:[] + + let mis = ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] + let equal_letters = + [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] + let single_letters = + [['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i']] + let every_three = + [['M'; 'i'; 's']; ['s'; 'i'; 's']; ['s'; 'i'; 'p']; ['p'; 'i' ]] + + let%test_unit _ = [%test_result: char list list] (group ~break:Char.(<>) mis) ~expect:equal_letters + let%test_unit _ = [%test_result: char list list] (group ~break:(fun _ _ -> false) mis) ~expect:single_letters + let%test_unit _ = [%test_result: char list list] (groupi ~break:(fun i _ _ -> i % 3 = 0) mis) ~expect:every_three + end) + +let%test_module "chunks_of" = + (module struct + + let test length break_every = + let l = List.init length 
~f:Fn.id in + let b = chunks_of l ~length:break_every in + [%test_eq: int list] (List.concat b) l; + List.iter b ~f:([%test_pred: int list] (fun batch -> + List.length batch <= break_every)); + ;; + + let expect_exn length break_every = + match test length break_every with + | exception _ -> () + | () -> raise_s [%message "Didn't raise." (length : int) (break_every : int)] + ;; + + let%test_unit _ = + for n = 0 to 10 do + for k = n + 2 downto 1 do + test n k + done + done; + expect_exn 1 0; + expect_exn 1 (-1); + ;; + + let%test_unit _ = [%test_result: _ list list] (chunks_of [] ~length:1) ~expect:[] + end) + +let%test _ = last_exn [1;2;3] = 3 +let%test _ = last_exn [1] = 1 +let%test _ = last_exn (Test_values.long1 ()) = 99_999 + +let%test _ = is_prefix [] ~prefix:[] ~equal:(=) +let%test _ = is_prefix [1] ~prefix:[] ~equal:(=) +let%test _ = is_prefix [1] ~prefix:[1] ~equal:(=) +let%test _ = not (is_prefix [1] ~prefix:[1;2] ~equal:(=)) +let%test _ = not (is_prefix [1;3] ~prefix:[1;2] ~equal:(=)) +let%test _ = is_prefix [1;2;3] ~prefix:[1;2] ~equal:(=) + +let%test_unit _ = + List.iter ~f:(fun (t, expect) -> + assert (Poly.equal expect (find_consecutive_duplicate t ~equal:Poly.equal))) + [ [] , None + ; [ 1 ] , None + ; [ 1; 1 ] , Some (1, 1) + ; [ 1; 2 ] , None + ; [ 1; 2; 1 ] , None + ; [ 1; 2; 2 ] , Some (2, 2) + ; [ 1; 1; 2; 2 ], Some (1, 1) + ] +;; + +let%test_unit _ = + [%test_result: ((int * char) * (int * char)) option] + (find_consecutive_duplicate [(0,'a');(1,'b');(2,'b')] + ~equal:(fun (_, a) (_, b) -> Char.(=) a b)) + ~expect:(Some ((1, 'b'), (2, 'b'))) +;; + +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [] + ~equal:Int.(=)) ~expect:[] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [5;5;5;5;5] + ~equal:Int.(=)) ~expect:[5] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [5;6;5;6;5;6] + ~equal:Int.(=)) 
~expect:[5;6;5;6;5;6] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [5;5;6;6;5;5;8;8] + ~equal:Int.(=)) ~expect:[5;6;5;8] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(0,2);(2,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[ (0,2);(2,2);(4,1)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[(0,1);(2,2);(0,2);(4,1)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(2,1);(0,2);(4,2)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[ (2,1); (4,2)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[(0,1); (0,2);(4,1)] + +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [] + ~equal:Int.(=)) ~expect:[] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [5;5;5;5;5] + ~equal:Int.(=)) ~expect:[5] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [5;6;5;6;5;6] + ~equal:Int.(=)) ~expect:[5;6;5;6;5;6] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [5;5;6;6;5;5;8;8] + ~equal:Int.(=)) ~expect:[5;6;5;8] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(0,2);(2,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[(0,1); (2,2);(4,1)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[(0,1);(2,2);(0,2);(4,1)] +let%test_unit 
_ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(2,1);(0,2);(4,2)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[(0,1); (0,2); ] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[(0,1);(2,2); (4,1)] + +let%test_unit _ = [%test_result: int list] (dedup_and_sort ~compare:Int.compare []) ~expect:[] +let%test_unit _ = [%test_result: int list] (dedup_and_sort ~compare:Int.compare [5;5;5;5;5]) ~expect:[5] +let%test_unit _ = [%test_result: int] (length (dedup_and_sort ~compare:Int.compare [2;1;5;3;4])) ~expect:5 +let%test_unit _ = [%test_result: int] (length (dedup_and_sort ~compare:Int.compare [2;3;5;3;4])) ~expect:4 +let%test_unit _ = [%test_result: int] (length (dedup_and_sort [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (a,_) (b,_) -> Int.compare a b))) ~expect:3 +let%test_unit _ = [%test_result: int] (length (dedup_and_sort [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (_,a) (_,b) -> Int.compare a b))) ~expect:2 + +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare []) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3]) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;4]) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;3]) ~expect:(Some 3) +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;5;4;6;12]) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;5;4;5;12]) ~expect:(Some 5) +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;5;12;5;12]) ~expect:(Some 5) +let%test_unit _ = + [%test_result: (int * int) option] + (find_a_dup ~compare:[%compare: int * int] [(0,1);(2,2);(0,2);(4,1)]) + ~expect:None +let%test _ = (find_a_dup 
[(0,1);(2,2);(0,2);(4,1)] + ~compare:(fun (_,a) (_,b) -> Int.compare a b)) + |> Option.is_some +let%test _ = let dup = find_a_dup [(0,1);(2,2);(0,2);(4,1)] + ~compare:(fun (a,_) (b,_) -> Int.compare a b) + in + match dup with + | Some (0, _) -> true + | _ -> false + +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare []) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;4]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;3]) ~expect:true +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;5;4;6;12]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;5;4;5;12]) ~expect:true +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;5;12;5;12]) ~expect:true +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:[%compare: int * int] [(0,1);(2,2);(0,2);(4,1)]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (_,a) (_,b) -> Int.compare a b)) ~expect:true + +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare []) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3]) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;4]) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;3]) ~expect:[3] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;5;4;6;12]) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;5;4;5;12]) ~expect:[5] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;5;12;5;12]) ~expect:[5;12] +let%test_unit _ = [%test_result: (int * int) list] (find_all_dups 
~compare:[%compare: int * int] [(0,1);(2,2);(0,2);(4,1)]) ~expect:[] +let%test_unit _ = [%test_result: int] (length (find_all_dups [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (_,a) (_,b) -> Int.compare a b))) ~expect:2 +let%test_unit _ = [%test_result: int] (length (find_all_dups [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (a,_) (b,_) -> Int.compare a b))) ~expect:1 + +let%test_unit _ = [%test_result: int] (counti [0;1;2;3;4] ~f:(fun idx x -> idx = x)) ~expect:5 +let%test_unit _ = [%test_result: int] (counti [0;1;2;3;4] ~f:(fun idx x -> idx = 4-x)) ~expect:1 + +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun x -> Some x) Test_values.l1) ~expect:Test_values.l1 +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun x -> Some x) []) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun _x -> None) [1.;2.;3.]) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun x -> if (x > 0) then Some x else None) [1;-1;3]) ~expect:[1;3] + +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i x -> Some x) Test_values.l1) ~expect:Test_values.l1 +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i x -> Some x) []) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i _x -> None) [1.;2.;3.]) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i x -> if (x > 0) then Some x else None) [1;-1;3]) ~expect:[1;3] +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun i x -> if (i % 2=0) then Some x else None) [1;-1;3]) ~expect:[1;3] + +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] 3) ~expect:([1;2;3],[4;5;6]) +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] 100) ~expect:([1;2;3;4;5;6],[]) +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] 0) ~expect:([],[1;2;3;4;5;6]) +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] (-5)) 
~expect:([],[1;2;3;4;5;6]) + +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] 3) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] 100) ~expect:[1;2;3;4;5;6] +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] 0) ~expect:[] +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] (-5)) ~expect:[] + +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] 3) ~expect:[4;5;6] +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] 100) ~expect:[] +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] 0) ~expect:[1;2;3;4;5;6] +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] (-5)) ~expect:[1;2;3;4;5;6] + +let%test_module "{take,drop,split}_while" = + (module struct + + let pred = function + | '0' .. '9' -> true + | _ -> false + + let test xs prefix suffix = + let (prefix1, suffix1) = split_while ~f:pred xs in + let prefix2 = take_while xs ~f:pred in + let suffix2 = drop_while xs ~f:pred in + [%test_eq: char list] xs (prefix @ suffix); + [%test_result: char list] ~expect:prefix prefix1; + [%test_result: char list] ~expect:prefix prefix2; + [%test_result: char list] ~expect:suffix suffix1; + [%test_result: char list] ~expect:suffix suffix2 + + let%test_unit _ = test ['1';'2';'3';'a';'b';'c'] ['1';'2';'3'] ['a';'b';'c'] + let%test_unit _ = test ['1';'2'; 'a';'b';'c'] ['1';'2' ] ['a';'b';'c'] + let%test_unit _ = test ['1'; 'a';'b';'c'] ['1' ] ['a';'b';'c'] + let%test_unit _ = test [ 'a';'b';'c'] [ ] ['a';'b';'c'] + let%test_unit _ = test ['1';'2';'3' ] ['1';'2';'3'] [ ] + let%test_unit _ = test [ ] [ ] [ ] + + end) + +let%test_unit _ = [%test_result: int list] (concat []) ~expect:[] +let%test_unit _ = [%test_result: int list] (concat [[]]) ~expect:[] +let%test_unit _ = [%test_result: int list] (concat [[3]]) ~expect:[3] +let%test_unit _ = [%test_result: int list] (concat [[1;2;3;4]]) ~expect:[1;2;3;4] +let%test_unit _ = [%test_result: int list] + (concat 
[[1;2;3;4];[5;6;7];[8;9;10];[];[11;12]]) + ~expect:[1;2;3;4;5;6;7;8;9;10;11;12] + +let%test_unit _ = [%test_result: bool] (is_sorted [] ~compare:Int.compare) ~expect:true +let%test_unit _ = [%test_result: bool] (is_sorted [1] ~compare:Int.compare) ~expect:true +let%test_unit _ = [%test_result: bool] (is_sorted [1; 2; 3; 4] ~compare:Int.compare) ~expect:true +let%test_unit _ = [%test_result: bool] (is_sorted [2; 1] ~compare:Int.compare) ~expect:false +let%test_unit _ = [%test_result: bool] (is_sorted [1; 3; 2] ~compare:Int.compare) ~expect:false + +let%test_unit _ = + List.iter + ~f:(fun (t, expect) -> [%test_result: bool] ~expect (is_sorted_strictly t ~compare:Int.compare)) + [ [] , true; + [ 1 ] , true; + [ 1; 2 ] , true; + [ 1; 1 ] , false; + [ 2; 1 ] , false; + [ 1; 2; 3 ], true; + [ 1; 1; 3 ], false; + [ 1; 2; 2 ], false; + ] +;; + +let%test_unit _ = [%test_result: int option] (random_element []) ~expect:None +let%test_unit _ = [%test_result: int option] (random_element [0]) ~expect:(Some 0) + +let%test_module "transpose" = + (module struct + + let round_trip a b = + [%test_result: int list list option] (transpose a) ~expect:(Some b); + [%test_result: int list list option] (transpose b) ~expect:(Some a) + + let%test_unit _ = round_trip [] [] + + let%test_unit _ = [%test_result: int list list option] (transpose [[]]) ~expect:(Some []) + let%test_unit _ = [%test_result: int list list option] (transpose [[]; []]) ~expect:(Some []) + let%test_unit _ = [%test_result: int list list option] (transpose [[]; []; []]) ~expect:(Some []) + + let%test_unit _ = round_trip [[1]] [[1]] + + let%test_unit _ = round_trip [[1]; + [2]] [[1; 2]] + + let%test_unit _ = round_trip [[1]; + [2]; + [3]] [[1; 2; 3]] + + let%test_unit _ = round_trip [[1; 2]; + [3; 4]] [[1; 3]; + [2; 4]] + + let%test_unit _ = round_trip [[1; 2; 3]; + [4; 5; 6]] [[1; 4]; + [2; 5]; + [3; 6]] + + let%test_unit _ = [%test_result: int list list option] (transpose [[]; [1]]) ~expect:None + + let%test_unit _ = 
[%test_result: int list list option] (transpose [[1;2];[3]]) ~expect:None + + end) + +let%test_unit _ = [%test_result: int list] (intersperse [1;2;3] ~sep:0) ~expect:[1;0;2;0;3] +let%test_unit _ = [%test_result: int list] (intersperse [1;2] ~sep:0) ~expect:[1;0;2] +let%test_unit _ = [%test_result: int list] (intersperse [1] ~sep:0) ~expect:[1] +let%test_unit _ = [%test_result: int list] (intersperse [] ~sep:0) ~expect:[] + +let test_fold_map list ~init ~f ~expect = + [%test_result: int list] (folding_map list ~init ~f) ~expect:(snd expect); + [%test_result: _ * int list] (fold_map list ~init ~f) ~expect + +let test_fold_mapi list ~init ~f ~expect = + [%test_result: int list] (folding_mapi list ~init ~f) ~expect:(snd expect); + [%test_result: _ * int list] (fold_mapi list ~init ~f) ~expect + +let%test_unit _ = test_fold_map [1;2;3;4] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(10, [1;3;6;10]) +let%test_unit _ = test_fold_map [] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(0, []) +let%test_unit _ = test_fold_mapi [1;2;3;4] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(20, [0;2;8;20]) +let%test_unit _ = test_fold_mapi [] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(0, []) diff --git a/test/test_list.mli b/test/test_list.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_list.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_map.ml b/test/test_map.ml new file mode 100644 index 0000000..d433795 --- /dev/null +++ b/test/test_map.ml @@ -0,0 +1,62 @@ +open! Import +open! 
Map + +let%test _ = + invariants (of_increasing_iterator_unchecked (module Int) ~len:20 ~f:(fun x -> x,x)) + +let%test _ = + invariants (Poly.of_increasing_iterator_unchecked ~len:20 ~f:(fun x -> x,x)) + +module M = M + +let add12 t = add_exn t ~key:1 ~data:2 + +type int_map = int Map.M(Int).t [@@deriving compare, hash, sexp] + +let%expect_test "[add_exn] success" = + print_s [%sexp (add12 (empty (module Int)) : int_map)]; + [%expect {| ((1 2)) |}] +;; + +let%expect_test "[add_exn] failure" = + show_raise (fun () -> add12 (add12 (empty (module Int)))); + [%expect {| (raised ("[Map.add_exn] got key already present" (key 1))) |}] +;; + +let%expect_test "[add] success" = + print_s [%sexp ( + add (empty (module Int)) ~key:1 ~data:2 : int_map Or_duplicate.t)]; + [%expect {| (Ok ((1 2))) |}] +;; + +let%expect_test "[add] duplicate" = + print_s [%sexp ( + add (add12 (empty (module Int))) ~key:1 ~data:2 : int_map Or_duplicate.t)]; + [%expect {| Duplicate |}] +;; + +let%expect_test "[Map.of_alist_multi] preserves value ordering" = + print_s [%sexp ( + Map.of_alist_multi (module String) ["a", 1; "a", 2; "b", 1; "b", 3] + : int list Map.M(String).t)]; + [%expect {| + ((a (1 2)) + (b (1 3))) |}] +;; + +module Poly = struct + let%test _ = + length Poly.empty = 0 + ;; + + let%test _ = + let a = Poly.of_alist_exn [] in + Poly.equal Base.Poly.equal a Poly.empty + ;; + + let%test _ = + let a = Poly.of_alist_exn [("a", 1)] in + let b = Poly.of_alist_exn [(1, "b")] in + length a = length b + ;; +end diff --git a/test/test_map.mlt b/test/test_map.mlt new file mode 100644 index 0000000..03ea12b --- /dev/null +++ b/test/test_map.mlt @@ -0,0 +1,6 @@ +open Base + +let _ = Map.add + +[%%expect {| +|}];; diff --git a/test/test_maybe_bound.ml b/test/test_maybe_bound.ml new file mode 100644 index 0000000..4b134d1 --- /dev/null +++ b/test/test_maybe_bound.ml @@ -0,0 +1,115 @@ +open! Import +open! 
Maybe_bound + +let%test_unit "bounds_crossed" = + let a, b, c, d = Incl 1, Excl 1, Incl 3, Excl 3 in + let cases = [ + a, a, false; + a, b, false; + a, c, false; + a, d, false; + b, a, false; + b, b, false; + b, c, false; + b, d, false; + c, a, true; + c, b, true; + c, c, false; + c, d, false; + d, a, true; + d, b, true; + d, c, false; + d, d, false; + ] in + List.iter cases ~f:(fun (lower, upper, expect) -> + let actual = bounds_crossed ~lower ~upper ~compare in + assert ([%compare.equal: bool] expect actual)); +;; + +let%test_module "is_lower_bound" = + (module struct + let compare = Int.compare + + let%test _ = is_lower_bound Unbounded ~of_:Int.min_value ~compare + + let%test _ = not (is_lower_bound (Incl 2) ~of_:1 ~compare) + let%test _ = is_lower_bound (Incl 2) ~of_:2 ~compare + let%test _ = is_lower_bound (Incl 2) ~of_:3 ~compare + + let%test _ = not (is_lower_bound (Excl 2) ~of_:1 ~compare) + let%test _ = not (is_lower_bound (Excl 2) ~of_:2 ~compare) + let%test _ = is_lower_bound (Excl 2) ~of_:3 ~compare + end) + +let%test_module "is_upper_bound" = + (module struct + let compare = Int.compare + + let%test _ = is_upper_bound Unbounded ~of_:Int.max_value ~compare + + let%test _ = is_upper_bound (Incl 2) ~of_:1 ~compare + let%test _ = is_upper_bound (Incl 2) ~of_:2 ~compare + let%test _ = not (is_upper_bound (Incl 2) ~of_:3 ~compare) + + let%test _ = is_upper_bound (Excl 2) ~of_:1 ~compare + let%test _ = not (is_upper_bound (Excl 2) ~of_:2 ~compare) + let%test _ = not (is_upper_bound (Excl 2) ~of_:3 ~compare) + end) + +let%test_module "check_range" = + (module struct + let compare = Int.compare + + let tests (lower, upper) cases = + List.iter cases ~f:(fun (n, comparison) -> + [%test_result: interval_comparison] + ~expect:comparison + (compare_to_interval_exn n ~lower ~upper ~compare); + [%test_result: bool] + ~expect:(match comparison with In_range -> true | _ -> false) + (interval_contains_exn n ~lower ~upper ~compare)) + + let%test_unit _ = + tests 
(Unbounded, Unbounded) + [ (Int.min_value, In_range) + ; (0, In_range) + ; (Int.max_value, In_range) + ] + + let%test_unit _ = + tests (Incl 2, Incl 4) + [ (1, Below_lower_bound) + ; (2, In_range) + ; (3, In_range) + ; (4, In_range) + ; (5, Above_upper_bound) + ] + + let%test_unit _ = + tests (Incl 2, Excl 4) + [ (1, Below_lower_bound) + ; (2, In_range) + ; (3, In_range) + ; (4, Above_upper_bound) + ; (5, Above_upper_bound) + ] + + let%test_unit _ = + tests (Excl 2, Incl 4) + [ (1, Below_lower_bound) + ; (2, Below_lower_bound) + ; (3, In_range) + ; (4, In_range) + ; (5, Above_upper_bound) + ] + + let%test_unit _ = + tests (Excl 2, Excl 4) + [ (1, Below_lower_bound) + ; (2, Below_lower_bound) + ; (3, In_range) + ; (4, Above_upper_bound) + ; (5, Above_upper_bound) + ] + end) + diff --git a/test/test_maybe_bound.mli b/test/test_maybe_bound.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_maybe_bound.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_nativeint.ml b/test/test_nativeint.ml new file mode 100644 index 0000000..9062421 --- /dev/null +++ b/test/test_nativeint.ml @@ -0,0 +1,7 @@ +open! Import +open! Nativeint + +let%expect_test "hash coherence" = + check_int_hash_coherence [%here] (module Nativeint); + [%expect {| |}]; +;; diff --git a/test/test_nativeint.mli b/test/test_nativeint.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_nativeint.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_nativeint_pow2.ml b/test/test_nativeint_pow2.ml new file mode 100644 index 0000000..7a9774d --- /dev/null +++ b/test/test_nativeint_pow2.ml @@ -0,0 +1,127 @@ +open! Import +open! 
Nativeint + +let examples = + [ -1n + ; 0n + ; 1n + ; 2n + ; 3n + ; 4n + ; 5n + ; 7n + ; 8n + ; 9n + ; 63n + ; 64n + ; 65n + ] +;; + +let examples_64_bit = + [ min_value + ; succ min_value + ; pred max_value + ; max_value ] +;; + +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : nativeint) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Nativeint.floor_log2] got invalid input" -1))) + (0 (Error ("[Nativeint.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Nativeint.floor_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Nativeint.floor_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 62)) + (9_223_372_036_854_775_807 (Ok 62)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Nativeint.ceil_log2] got invalid input" -1))) + (0 (Error ("[Nativeint.ceil_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Nativeint.ceil_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Nativeint.ceil_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 63)) + (9_223_372_036_854_775_807 (Ok 63)) |}]; +;; + +let%test_module "nativeint_math" = + (module struct + + 
let test_cases () = + let cases = + [ 0b10101010n; 0b1010101010101010n; 0b101010101010101010101010n; + 0b10000000n; 0b1000000000001000n; 0b100000000000000000001000n; ] + in + match Word_size.word_size with + | W64 -> (* create some >32 bit values... *) + (* We can't use literals directly because the compiler complains on 32 bits. *) + let cases = + cases @ [ (0b1010101010101010n lsl 16) lor 0b1010101010101010n; + (0b1000000000000000n lsl 16) lor 0b0000000000001000n; ] + in + let added_cases = List.map cases ~f:(fun x -> x lsl 16) in + List.concat [ cases; added_cases ] + | W32 -> cases + ;; + + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / (of_int 2))) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && (((of_int 2) * p2) >= x && x >= p2) ) + ) + ;; + end) diff --git a/test/test_not_found.mlt b/test/test_not_found.mlt new file mode 100644 index 0000000..dfeef17 --- /dev/null +++ b/test/test_not_found.mlt @@ -0,0 +1,18 @@ +open Base +open Expect_test_helpers_kernel +;; + +print_s [%sexp (Not_found_s [%message "foo"] : exn)];; +[%%expect {| (Not_found_s foo) |}];; + +Not_found;; +[%%expect {| +Line _, characters 0-9: +Error (Warning 3): deprecated: Not_found +[2016-09] this element comes from the stdlib distributed with OCaml. +Instead of raising [Not_found], consider using [raise_s] with an informative error +message. If code needs to distinguish [Not_found] from other exceptions, please change +it to handle both [Not_found] and [Not_found_s]. Then, instead of raising [Not_found], +raise [Not_found_s] with an informative error message. +|}];; + diff --git a/test/test_obj_array.ml b/test/test_obj_array.ml new file mode 100644 index 0000000..7a45b24 --- /dev/null +++ b/test/test_obj_array.ml @@ -0,0 +1,108 @@ +open! 
Import + +module Obj_array = Not_exposed_properly.Obj_array +open Obj_array + +let does_raise = Exn.does_raise + +let zero_obj = Caml.Obj.repr (0 : int) + +include + Test_blit.Test + (struct + type t = Caml.Obj.t + let equal = phys_equal + let of_bool b = Caml.Obj.repr (if b then 1 else 2 : int) + end) + (struct + type nonrec t = t [@@deriving sexp_of] + let create = create_zero + let get = get + let set = set + let length = length + end) + (Obj_array) +;; + +(* [create_zero] *) +let%test_unit _ = + let t = create_zero ~len:0 in + assert (length t = 0) +;; + +(* [create] *) +let%test_unit _ = + let str = Caml.Obj.repr "foo" in + let t = create ~len:2 str in + assert (phys_equal (get t 0) str); + assert (phys_equal (get t 1) str) +;; + +let%test_unit _ = + let float = Caml.Obj.repr 3.5 in + let t = create ~len:2 float in + assert (Caml.Obj.tag (Caml.Obj.repr t) = 0); (* not a double array *) + assert (phys_equal (get t 0) float); + assert (phys_equal (get t 1) float); + set t 1 (Caml.Obj.repr 4.); + assert (Float.(=) (Caml.Obj.obj (get t 1)) 4.); +;; + +(* [empty] *) +let%test _ = length empty = 0 + +let%test _ = does_raise (fun () -> get empty 0) + +(* [singleton] *) +let%test _ = length (singleton zero_obj) = 1 + +let%test _ = phys_equal (get (singleton zero_obj) 0) zero_obj + +let%test _ = does_raise (fun () -> get (singleton zero_obj) 1) + +let%test_unit _ = + let f = 13. 
in + let t = singleton (Caml.Obj.repr f) in + invariant t; + assert (Poly.equal (Caml.Obj.repr f) (get t 0)) +;; + +(* [get], [unsafe_get], [set], [unsafe_set], [unsafe_set_assuming_currently_int] *) +let%test_unit _ = + let t = create_zero ~len:1 in + assert (length t = 1); + assert (phys_equal (get t 0) zero_obj); + assert (phys_equal (unsafe_get t 0) zero_obj); + let one_obj = Caml.Obj.repr (1 : int) in + let check_get expect = + assert (phys_equal (get t 0) expect); + assert (phys_equal (unsafe_get t 0) expect); + in + set t 0 one_obj; + check_get one_obj; + unsafe_set t 0 zero_obj; + check_get zero_obj; + unsafe_set_assuming_currently_int t 0 one_obj; + check_get one_obj +;; + +(* [truncate] *) +let%test _ = does_raise (fun () -> truncate empty ~len:0) +let%test _ = does_raise (fun () -> truncate empty ~len:1) +let%test _ = does_raise (fun () -> truncate empty ~len:(-1)) +let%test _ = does_raise (fun () -> truncate (create_zero ~len:1) ~len:0) +let%test _ = does_raise (fun () -> truncate (create_zero ~len:1) ~len:2) + +let%test_unit _ = + let t = create_zero ~len:1 in + truncate t ~len:1; + assert (length t = 1) +;; + +let%test_unit _ = + let t = create_zero ~len:3 in + truncate t ~len:2; + assert (length t = 2); + truncate t ~len:1; + assert (length t = 1) +;; diff --git a/test/test_option.ml b/test/test_option.ml new file mode 100644 index 0000000..8a92649 --- /dev/null +++ b/test/test_option.ml @@ -0,0 +1,8 @@ +open! Import +open! Option + +let f = (+) +let%test _ = [%compare.equal: int t] (merge None None ~f) None +let%test _ = [%compare.equal: int t] (merge (Some 3) None ~f) (Some 3) +let%test _ = [%compare.equal: int t] (merge None (Some 3) ~f) (Some 3) +let%test _ = [%compare.equal: int t] (merge (Some 1) (Some 3) ~f) (Some 4) diff --git a/test/test_option.mli b/test/test_option.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_option.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. 
*) diff --git a/test/test_option_array.ml b/test/test_option_array.ml new file mode 100644 index 0000000..a39e732 --- /dev/null +++ b/test/test_option_array.ml @@ -0,0 +1,69 @@ +open! Import +open Option_array + + +let%test_module "Cheap_option" = ( + module struct + open For_testing.Unsafe_cheap_option + + let roundtrip_via_cheap_option (type a) (x : a) = + let opt : a t = some x in + assert (is_some opt); + assert (phys_equal (value_exn opt) x) + + let%test_unit _ = + roundtrip_via_cheap_option 0 + let%test_unit _ = + roundtrip_via_cheap_option 1 + let%test_unit _ = + roundtrip_via_cheap_option (ref 0) + let%test_unit _ = + roundtrip_via_cheap_option `x6e8ee3478e1d7449 + let%test_unit _ = + roundtrip_via_cheap_option 0.0 + let%test _ = + not (is_some none) + + let%test_unit "memory corruption" = + let make_list () = + List.init ~f:(fun i -> Some i) 5 + in + Caml.Gc.minor (); + let x = value_unsafe (some (make_list ())) in + Caml.Gc.minor (); + let _ = List.init ~f:(fun i -> Some (i*100)) 10000 in + [%test_result: Int.t Option.t List.t] + ~expect:(make_list ()) x + end) + +module Sequence = struct + let length = length + let get = get + let set = set +end + +include Base_for_tests.Test_blit.Test1_generic(struct + include Option + + let equal a b = Option.equal Bool.equal a b + let of_bool b = Some b + end) (struct + type nonrec 'a t = 'a t [@@deriving sexp] + type 'a z = 'a + include Sequence + + let create_bool ~len = init_some len ~f:(fun _ -> false) + end)(Option_array) + + +let%test_unit "floats are not re-boxed" = + let one = 1.0 in + let array = init_some 1 ~f:(fun _ -> one) in + assert (phys_equal one (get_some_exn array 0)) + +let%test_unit "segfault does not happen" = + (* if [Core_array] is used instead of [Uniform_array], this dies with a segfault *) + let _array = init 2 ~f:(fun i -> + if i = 0 then Some 1.0 else None) + in + () diff --git a/test/test_or_error.ml b/test/test_or_error.ml new file mode 100644 index 0000000..340bea4 --- /dev/null +++ 
b/test/test_or_error.ml @@ -0,0 +1,26 @@ +open! Import +open! Or_error + +let%test _ = [%compare.equal: string t] (errorf "foo %d" 13) (error_string "foo 13") + +let%test_unit _ = + for i = 0 to 10; do + assert ([%compare.equal: unit list t] + (combine_errors (List.init i ~f:(fun _ -> Ok ()))) + (Ok (List.init i ~f:(fun _ -> ())))); + done +let%test _ = Result.is_error (combine_errors [ error_string "" ]) +let%test _ = Result.is_error (combine_errors [ Ok (); error_string "" ]) + +let (=) = [%compare.equal: unit t] +let%test _ = combine_errors_unit [Ok (); Ok ()] = Ok () +let%test _ = combine_errors_unit [] = Ok () +let%test _ = + let a = Error.of_string "a" and b = Error.of_string "b" in + match combine_errors_unit [Ok (); Error a; Ok (); Error b] with + | Ok _ -> false + | Error e -> String.equal + (Error.to_string_hum e) + (Error.to_string_hum (Error.of_list [a;b])) +;; + diff --git a/test/test_or_error.mli b/test/test_or_error.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_or_error.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_ordered_collection_common.ml b/test/test_ordered_collection_common.ml new file mode 100644 index 0000000..c5607bb --- /dev/null +++ b/test/test_ordered_collection_common.ml @@ -0,0 +1,66 @@ +open! Import +open! 
Ordered_collection_common + +let%test_unit "fast check_pos_len_exn is correct" = + let n_vals = + [ 0 + ; 1 + ; 2 + ; 10 + ; 100 + ; Int.max_value / 2 - 2 + ; Int.max_value / 2 - 1 + ; Int.max_value / 2 + ; Int.max_value - 2 + ; Int.max_value - 1 + ; Int.max_value + ] + in + let z_vals = + [ Int.min_value + ; Int.min_value + 1 + ; Int.min_value + 2 + ; Int.min_value / 2 + ; Int.min_value / 2 + 1 + ; Int.min_value / 2 + 2 + ; -100 + ; -10 + ; -2 + ; -1 + ] @ n_vals + in + List.iter z_vals ~f:(fun pos -> + List.iter z_vals ~f:(fun len -> + List.iter n_vals ~f:(fun total_length -> + assert + (Bool.equal + (Exn.does_raise (fun () -> Private.slow_check_pos_len_exn ~pos ~len ~total_length)) + (Exn.does_raise (fun () -> check_pos_len_exn ~pos ~len ~total_length)))))) +;; + +let%test_unit _ = + let vals = [ -1; 0; 1; 2; 3 ] in + List.iter [ 0; 1; 2 ] ~f:(fun total_length -> + List.iter vals ~f:(fun pos -> + List.iter vals ~f:(fun len -> + let result = Result.try_with (fun () -> check_pos_len_exn ~pos ~len ~total_length) in + let valid = pos >= 0 && len >= 0 && len <= total_length - pos in + assert (Bool.equal valid (Result.is_ok result))))) +;; + +let%test_unit _ = + let opts = [ None; Some (-1); Some 0; Some 1; Some 2 ] in + List.iter [ 0; 1; 2 ] ~f:(fun total_length -> + List.iter opts ~f:(fun pos -> + List.iter opts ~f:(fun len -> + let result = Result.try_with (fun () -> get_pos_len_exn () ?pos ?len ~total_length) in + let pos = match pos with Some x -> x | None -> 0 in + let len = match len with Some x -> x | None -> total_length - pos in + let valid = pos >= 0 && len >= 0 && len <= total_length - pos in + match result with + | Error _ -> assert (not valid); + | Ok (pos', len') -> + assert (pos' = pos); + assert (len' = len); + assert valid))) +;; diff --git a/test/test_ordered_collection_common.mli b/test/test_ordered_collection_common.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_ordered_collection_common.mli @@ -0,0 +1 @@ +(*_ This 
signature is deliberately empty. *) diff --git a/test/test_ordering.ml b/test/test_ordering.ml new file mode 100644 index 0000000..b9dfce9 --- /dev/null +++ b/test/test_ordering.ml @@ -0,0 +1,15 @@ +open! Import +open! Ordering + +let%test _ = equal (of_int (-10)) Less +let%test _ = equal (of_int (-1) ) Less +let%test _ = equal (of_int 0 ) Equal +let%test _ = equal (of_int 1 ) Greater +let%test _ = equal (of_int 10 ) Greater + +let%test _ = equal (of_int (Int.compare 0 1)) Less +let%test _ = equal (of_int (Int.compare 1 1)) Equal +let%test _ = equal (of_int (Int.compare 1 0)) Greater + +let%test _ = List.for_all all ~f:(fun t -> equal t (t |> to_int |> of_int)) +let%test _ = List.for_all [ -1; 0; 1 ] ~f:(fun i -> i = (i |> of_int |> to_int)) diff --git a/test/test_popcount.ml b/test/test_popcount.ml new file mode 100644 index 0000000..8304384 --- /dev/null +++ b/test/test_popcount.ml @@ -0,0 +1,43 @@ +open! Import + +module type T = sig + type t [@@deriving compare, sexp_of] + + (* for implementing popcount_naive *) + val zero : t + val one : t + val (+) : t -> t -> t + val (lsr) : t -> int -> t + val (land) : t -> t -> t + + val quickcheck_generator : t Quickcheck.Generator.t + val to_int_exn : t -> int + + val popcount : t -> int +end + +module Make(Int : T) = struct + let popcount_naive (int : Int.t) : int = + let open Int in + let rec loop n count = + if Int.compare n zero <> 0 + then (loop (n lsr 1) (count + (n land one))) + else count + in + loop int zero + |> to_int_exn + ;; + + let%test_unit _ = + Quickcheck.test Int.quickcheck_generator + ~sexp_of:[%sexp_of: Int.t] + ~f:(fun int -> + let expect = popcount_naive int in + [%test_result: int] ~expect (Int.popcount int)) + ;; +end + +include Make(Quickcheck.Int) +include Make(Quickcheck.Int32) +include Make(Quickcheck.Int64) +include Make(Quickcheck.Nativeint) diff --git a/test/test_popcount.mli b/test/test_popcount.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_popcount.mli @@ 
-0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_ppx_compare_lib.ml b/test/test_ppx_compare_lib.ml new file mode 100644 index 0000000..d418d3a --- /dev/null +++ b/test/test_ppx_compare_lib.ml @@ -0,0 +1,104 @@ +open! Import +open! Ppx_compare_lib + +module Unit = struct + type t = unit [@@deriving compare, sexp_of] +end + +module type T = sig + type t [@@deriving compare, sexp_of] +end + +let test (type a) (module T : T with type t = a) ordered = + List.iteri ordered ~f:(fun i ti -> + List.iteri ordered ~f:(fun j tj -> + require [%here] + (Ordering.equal + (Ordering.of_int (T.compare ti tj)) + (Ordering.of_int (Int.compare i j))) + ~if_false_then_print_s:( + lazy [%message "" ~_:(ti : T.t) ~_:(tj : T.t)]))); +;; + +let%expect_test "bool, char, unit" = + test (module Bool) [ false; true ]; + test (module Char) [ '\000'; 'a'; 'b' ]; + [%expect {| |}]; + test (module Unit) [ () ]; + [%expect {| |}]; +;; + +module type Min_zero_max = sig + include T + val min_value : t + val max_value : t + val zero : t +end + +let test_min_zero_max (type a) (module T : Min_zero_max with type t = a)= + test (module T) + [ T.min_value + ; T.zero + ; T.max_value ] +;; + +let%expect_test _ = + test_min_zero_max (module Float); + test_min_zero_max (module Int); + test_min_zero_max (module Int32); + test_min_zero_max (module Int64); + test_min_zero_max (module Nativeint); +;; + +let%expect_test "option" = + test + (module struct + type t = int option [@@deriving compare, sexp_of] + end) + [ None + ; Some 0 + ; Some 1 ] +;; +let%expect_test "ref" = + test + (module struct + type t = int ref [@@deriving compare, sexp_of] + end) + ([ -1; 0; 1 ] |> List.map ~f:ref); +;; + +module type Sequence = sig + type 'a t [@@deriving compare, sexp_of] + val of_list : 'a list -> 'a t +end + +let test_sequence (module T : Sequence) ordered = + test + (module struct + type t = int T.t [@@deriving compare, sexp_of] + end) + (ordered |> List.map ~f:T.of_list) +;; + 
+let%expect_test "array, list" = + test_sequence (module Array) + [ [ ] + ; [ 1 ] + ; [ 2 ] + ; [ 1; 2 ] + ; [ 2; 1 ]]; + test_sequence (module List) + [ [ ] + ; [ 1 ] + ; [ 1; 2 ] + ; [ 2 ] + ; [ 2; 1 ]]; +;; + +let%expect_test "[compare_abstract]" = + show_raise (fun () -> compare_abstract ~type_name:"TY" () ()); + [%expect {| + (raised ( + Failure + "Compare called on the type TY, which is abstract in an implementation.")) |}]; +;; diff --git a/test/test_ppx_compare_lib.mli b/test/test_ppx_compare_lib.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_ppx_compare_lib.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_queue.ml b/test/test_queue.ml new file mode 100644 index 0000000..7345577 --- /dev/null +++ b/test/test_queue.ml @@ -0,0 +1,902 @@ +open! Core_kernel + +let%test_module _ = + (module (struct + + open Queue + + module type S = S + + let does_raise = Exn.does_raise + + type nonrec 'a t = 'a t [@@deriving bin_io, sexp] + + let capacity = capacity + let set_capacity = set_capacity + + let%test_unit _ = + let t = create () in + assert (capacity t = 1); + enqueue t 1; + assert (capacity t = 1); + enqueue t 2; + assert (capacity t = 2); + enqueue t 3; + assert (capacity t = 4); + set_capacity t 0; + assert (capacity t = 4); + set_capacity t 3; + assert (capacity t = 4); + set_capacity t 100; + assert (capacity t = 128); + enqueue t 4; + enqueue t 5; + set_capacity t 0; + assert (capacity t = 8); + set_capacity t (-1); + assert (capacity t = 8); + ;; + + + let round_trip_sexp t = + let sexp = sexp_of_t Int.sexp_of_t t in + let t' = t_of_sexp Int.t_of_sexp sexp in + assert (to_list t = to_list t') + ;; + let%test_unit _ = round_trip_sexp (of_list [1; 2; 3; 4]) + let%test_unit _ = round_trip_sexp (create ()) + let%test_unit _ = round_trip_sexp (of_list []) + + let invariant = invariant + + let create = create + let%test_unit _ = + let t = create () in + assert (length t = 0); + assert (capacity t 
= 1); + ;; + let%test_unit _ = + let t = create ~capacity:0 () in + assert (length t = 0); + assert (capacity t = 1); + ;; + let%test_unit _ = + let t = create ~capacity:6 () in + assert (length t = 0); + assert (capacity t = 8); + ;; + let%test_unit _ = + assert (does_raise (fun () -> (create ~capacity:(-1) () : _ Queue.t))) + ;; + + let singleton = singleton + let%test_unit _ = + let t = singleton 7 in + assert (length t = 1); + assert (capacity t = 1); + assert (dequeue t = Some 7); + assert (dequeue t = None) + ;; + + let init = init + let%test_unit _ = + let t = init 0 ~f:(fun _ -> assert false) in + assert (length t = 0); + assert (capacity t = 1); + assert (dequeue t = None); + ;; + let%test_unit _ = + let t = init 3 ~f:(fun i -> i * 2) in + assert (length t = 3); + assert (capacity t = 4); + assert (dequeue t = Some 0); + assert (dequeue t = Some 2); + assert (dequeue t = Some 4); + assert (dequeue t = None); + ;; + let%test_unit _ = + assert (does_raise (fun () -> (init (-1) ~f:(fun _ -> ()) : unit Queue.t))) + ;; + + let get = get + let set = set + let%test_unit _ = + let t = create () in + let get_opt t i = Option.try_with (fun () -> get t i) in + assert (get_opt t 0 = None); + assert (get_opt t (-1) = None); + assert (get_opt t 10 = None); + List.iter [ -1; 0; 1 ] ~f:(fun i -> assert (does_raise (fun () -> set t i 0))); + enqueue t 0; + enqueue t 1; + enqueue t 2; + assert (get_opt t 0 = Some 0); + assert (get_opt t 1 = Some 1); + assert (get_opt t 2 = Some 2); + assert (get_opt t 3 = None); + ignore (dequeue_exn t); + assert (get_opt t 0 = Some 1); + assert (get_opt t 1 = Some 2); + assert (get_opt t 2 = None); + set t 0 3; + assert (get_opt t 0 = Some 3); + assert (get_opt t 1 = Some 2); + List.iter [ -1; 2 ] ~f:(fun i -> assert (does_raise (fun () -> set t i 0))) + ;; + + let map = map + let%test_unit _ = + for i = 0 to 5 do + let l = List.init i ~f:Fn.id in + let t = of_list l in + let f x = x * 2 in + let t' = map t ~f in + assert (to_list t' = 
List.map l ~f); + done + ;; + + let%test_unit _ = + let t = create () in + let t' = map t ~f:(fun x -> x * 2) in + assert (length t' = length t); + assert (length t' = 0); + assert (to_list t' = []) + ;; + + let mapi = mapi + let%test_unit _ = + for i = 0 to 5 do + let l = List.init i ~f:Fn.id in + let t = of_list l in + let f i x = (i, x * 2) in + let t' = mapi t ~f in + assert (to_list t' = List.mapi l ~f); + done + + let%test_unit _ = + let t = create () in + let t' = mapi t ~f:(fun i x -> (i, x * 2)) in + assert (length t' = length t); + assert (length t' = 0); + assert (to_list t' = []) + ;; + + include Test_container.Test_S1 (Queue) + + let dequeue_exn = dequeue_exn + let enqueue = enqueue + let peek = peek + let peek_exn = peek_exn + let last = last + let last_exn = last_exn + let%test_unit _ = + let t = create () in + assert (is_none (peek t)); + assert (is_none (last t)); + enqueue t 1; + enqueue t 2; + assert (peek t = Some 1); + assert (peek_exn t = 1); + assert (last t = Some 2); + assert (last_exn t = 2); + assert (dequeue_exn t = 1); + assert (dequeue_exn t = 2); + assert (does_raise (fun () -> dequeue_exn t)); + assert (does_raise (fun () -> peek_exn t)); + assert (does_raise (fun () -> last_exn t)) + ;; + + let enqueue_all = enqueue_all + let%test_unit _ = + let t = create () in + enqueue_all t [1; 2; 3]; + assert (dequeue_exn t = 1); + assert (dequeue_exn t = 2); + assert (last t = Some 3); + enqueue_all t [4; 5]; + assert (last t = Some 5); + assert (dequeue_exn t = 3); + assert (dequeue_exn t = 4); + assert (dequeue_exn t = 5); + assert (does_raise (fun () -> dequeue_exn t)); + enqueue_all t []; + assert (does_raise (fun () -> dequeue_exn t)); + ;; + + let of_list = of_list + let to_list = to_list + + let%test_unit _ = + for i = 0 to 4 do + let list = List.init i ~f:Fn.id in + assert (Poly.equal (to_list (of_list list)) list); + done + ;; + + let%test _ = + let t = create () in + begin + for i = 1 to 5 do enqueue t i done; + to_list t = 
[1;2;3;4;5] + end + ;; + + let of_array = of_array + let to_array = to_array + + let%test_unit _ = + for len = 0 to 4 do + let array = Array.init len ~f:Fn.id in + assert (Poly.equal (to_array (of_array array)) array); + done + ;; + + let compare = compare + let equal = equal + + let%test_module "comparisons" = + (module struct + + let sign x = if x < 0 then ~-1 else if x > 0 then 1 else 0 + + let test t1 t2 = + [%test_result: bool] + (equal Int.equal t1 t2) + ~expect:(List.equal Int.equal (to_list t1) (to_list t2)); + [%test_result: int] + (sign (compare Int.compare t1 t2)) + ~expect:(sign (List.compare Int.compare (to_list t1) (to_list t2))) + ;; + + let lists = + [ [] + ; [ 1 ] + ; [ 2 ] + ; [ 1; 1 ] + ; [ 1; 2 ] + ; [ 2; 1 ] + ; [ 1; 1; 1 ] + ; [ 1; 2; 3 ] + ; [ 1; 2; 4 ] + ; [ 1; 2; 4; 8 ] + ; [ 1; 2; 3; 4; 5 ] + ] + ;; + + let%test_unit _ = (* [phys_equal] inputs *) + List.iter lists ~f:(fun list -> + let t = of_list list in + test t t) + ;; + + let%test_unit _ = + List.iter lists ~f:(fun list1 -> + List.iter lists ~f:(fun list2 -> + test (of_list list1) (of_list list2))) + ;; + end) + + let clear = clear + + let blit_transfer = blit_transfer + + let%test_unit _ = + let q_list = [1; 2; 3; 4] in + let q = of_list q_list in + let q' = create () in + blit_transfer ~src:q ~dst:q' (); + assert (to_list q' = q_list); + assert (to_list q = []) + ;; + + let%test_unit _ = + let q = of_list [1; 2; 3; 4] in + let q' = create () in + blit_transfer ~src:q ~dst:q' ~len:2 (); + assert (to_list q' = [1; 2]); + assert (to_list q = [3; 4]) + ;; + + let%test_unit "blit_transfer on wrapped queues" = + let list = [1; 2; 3; 4] in + let q = of_list list in + let q' = copy q in + ignore (dequeue_exn q); + ignore (dequeue_exn q); + ignore (dequeue_exn q'); + ignore (dequeue_exn q'); + ignore (dequeue_exn q'); + enqueue q 5; + enqueue q 6; + blit_transfer ~src:q ~dst:q' ~len:3 (); + assert (to_list q' = [4; 3; 4; 5]); + assert (to_list q = [6]) + ;; + + let copy = copy + let dequeue = 
dequeue + let filter = filter + let filteri = filteri + let filter_inplace = filter_inplace + let filteri_inplace = filteri_inplace + let concat_map = concat_map + let concat_mapi = concat_mapi + let filter_map = filter_map + let filter_mapi = filter_mapi + let counti = counti + let existsi = existsi + let for_alli = for_alli + let iter = iter + let iteri = iteri + let foldi = foldi + let findi = findi + let find_mapi = find_mapi + + let%test_module "Linked_queue bisimulation" = + (module struct + module type Queue_intf = sig + type 'a t [@@deriving sexp_of] + + val create : unit -> 'a t + val enqueue : 'a t -> 'a -> unit + val dequeue : 'a t -> 'a option + val to_array : 'a t -> 'a array + val fold : 'a t -> init:'b -> f:('b -> 'a -> 'b) -> 'b + val foldi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b) -> 'b + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(int -> 'a -> unit) -> unit + val length : 'a t -> int + val clear : 'a t -> unit + val concat_map : 'a t -> f:( 'a -> 'b list) -> 'b t + val concat_mapi : 'a t -> f:(int -> 'a -> 'b list) -> 'b t + val filter_map : 'a t -> f:( 'a -> 'b option) -> 'b t + val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + val filter : 'a t -> f:( 'a -> bool) -> 'a t + val filteri : 'a t -> f:(int -> 'a -> bool) -> 'a t + val filter_inplace : 'a t -> f:( 'a -> bool) -> unit + val filteri_inplace : 'a t -> f:(int -> 'a -> bool) -> unit + val map : 'a t -> f:( 'a -> 'b) -> 'b t + val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + val counti : 'a t -> f:(int -> 'a -> bool) -> int + val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option + val transfer : src:'a t -> dst:'a t -> unit + val copy : 'a t -> 'a t + end + + module That_queue : Queue_intf = Linked_queue + + module This_queue : Queue_intf = struct + include Queue + let 
create () = create () + let transfer ~src ~dst = blit_transfer ~src ~dst () + end + + let this_to_string this_t = + Sexp.to_string (this_t |> [%sexp_of: int This_queue.t]) + ;; + + let that_to_string that_t = + Sexp.to_string (that_t |> [%sexp_of: int That_queue.t]) + ;; + + let array_string arr = + Sexp.to_string (arr |> [%sexp_of: int array]) + ;; + + let create () = (This_queue.create (), That_queue.create ()) + + + let enqueue (t_a, t_b) v = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + This_queue.enqueue t_a v; + That_queue.enqueue t_b v; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if end_a <> end_b + then failwithf "enqueue transition failure of: %s -> %s vs. %s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let iter (t_a, t_b) = + let r_a, r_b = ref 0, ref 0 in + This_queue.iter t_a ~f:(fun x -> r_a := !r_a + x); + That_queue.iter t_b ~f:(fun x -> r_b := !r_b + x); + if !r_a <> !r_b + then failwithf "error in iter: %s (from %s) <> %s (from %s)" + (Int.to_string !r_a) + (this_to_string t_a) + (Int.to_string !r_b) + (that_to_string t_b) + () + ;; + + let iteri (t_a, t_b) = + let r_a, r_b = ref 0, ref 0 in + This_queue.iteri t_a ~f:(fun i x -> r_a := !r_a + x lxor i); + That_queue.iteri t_b ~f:(fun i x -> r_b := !r_b + x lxor i); + if !r_a <> !r_b + then failwithf "error in iteri: %s (from %s) <> %s (from %s)" + (Int.to_string !r_a) + (this_to_string t_a) + (Int.to_string !r_b) + (that_to_string t_b) + () + ;; + + let dequeue (t_a, t_b) = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + let a, b = This_queue.dequeue t_a, That_queue.dequeue t_b in + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if a <> b || end_a <> end_b + then failwithf "error in dequeue: %s (%s -> %s) <> %s (%s -> %s)" + (Option.value ~default:"None" (Option.map a 
~f:Int.to_string)) + (array_string start_a) + (array_string end_a) + (Option.value ~default:"None" (Option.map b ~f:Int.to_string)) + (array_string start_b) + (array_string end_b) + () + ;; + + let clear (t_a, t_b) = + This_queue.clear t_a; + That_queue.clear t_b; + ;; + + let is_even x = (x land 1) = 0 + + let filter (t_a, t_b) = + let t_a' = This_queue.filter t_a ~f:is_even in + let t_b' = That_queue.filter t_b ~f:is_even in + if This_queue.to_array t_a' <> That_queue.to_array t_b' + then failwithf "error in filter: %s -> %s vs. %s -> %s" + (this_to_string t_a) + (this_to_string t_a') + (that_to_string t_b) + (that_to_string t_b') + () + ;; + + let filteri (t_a, t_b) = + let t_a' = This_queue.filteri t_a ~f:(fun i j -> is_even i = is_even j) in + let t_b' = That_queue.filteri t_b ~f:(fun i j -> is_even i = is_even j) in + if This_queue.to_array t_a' <> That_queue.to_array t_b' + then failwithf "error in filteri: %s -> %s vs. %s -> %s" + (this_to_string t_a) + (this_to_string t_a') + (that_to_string t_b) + (that_to_string t_b') + () + ;; + + let filter_inplace (t_a, t_b) = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + This_queue.filter_inplace t_a ~f:is_even; + That_queue.filter_inplace t_b ~f:is_even; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if end_a <> end_b + then failwithf "error in filter_inplace: %s -> %s vs. %s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let filteri_inplace (t_a, t_b) = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + let f i x = is_even i = is_even x in + This_queue.filteri_inplace t_a ~f; + That_queue.filteri_inplace t_b ~f; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if end_a <> end_b + then failwithf "error in filteri_inplace: %s -> %s vs. 
%s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let concat_map (t_a, t_b) = + let f x = [x; x + 1; x + 2] in + let t_a' = This_queue.concat_map t_a ~f in + let t_b' = That_queue.concat_map t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in concat_map: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let concat_mapi (t_a, t_b) = + let f i x = [x; x + 1; x + 2; x + i] in + let t_a' = This_queue.concat_mapi t_a ~f in + let t_b' = That_queue.concat_mapi t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in concat_mapi: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let filter_map (t_a, t_b) = + let f x = if is_even x then None else Some (x + 1) in + let t_a' = This_queue.filter_map t_a ~f in + let t_b' = That_queue.filter_map t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in filter_map: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let filter_mapi (t_a, t_b) = + let f i x = if is_even i = is_even x then None else Some (x + 1 + i) in + let t_a' = This_queue.filter_mapi t_a ~f in + let t_b' = That_queue.filter_mapi t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in filter_mapi: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let map (t_a, t_b) = + let f x = x * 7 in + let t_a' = This_queue.map t_a ~f in + let t_b' = That_queue.map t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in map: %s (for %s) <> %s (for %s)" + 
(this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let mapi (t_a, t_b) = + let f i x = (x + 3) lxor i in + let t_a' = This_queue.mapi t_a ~f in + let t_b' = That_queue.mapi t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in mapi: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let counti (t_a, t_b) = + let f i x = i < 7 && (i % 7 = x % 7) in + let a' = This_queue.counti t_a ~f in + let b' = That_queue.counti t_b ~f in + if a' <> b' + then failwithf "error in counti: %d (for %s) <> %d (for %s)" + (a') + (this_to_string t_a) + (b') + (that_to_string t_b) + () + ;; + + let existsi (t_a, t_b) = + let f i x = i < 7 && (i % 7 = x % 7) in + let a' = This_queue.existsi t_a ~f in + let b' = That_queue.existsi t_b ~f in + if a' <> b' + then failwithf "error in existsi: %b (for %s) <> %b (for %s)" + (a') + (this_to_string t_a) + (b') + (that_to_string t_b) + () + ;; + + let for_alli (t_a, t_b) = + let f i x = i >= 7 || (i % 7 <> x % 7) in + let a' = This_queue.for_alli t_a ~f in + let b' = That_queue.for_alli t_b ~f in + if a' <> b' + then failwithf "error in for_alli: %b (for %s) <> %b (for %s)" + (a') + (this_to_string t_a) + (b') + (that_to_string t_b) + () + ;; + + let findi (t_a, t_b) = + let f i x = i < 7 && (i % 7 = x % 7) in + let a' = This_queue.findi t_a ~f in + let b' = That_queue.findi t_b ~f in + if a' <> b' + then failwithf "error in findi: %s (for %s) <> %s (for %s)" + (Sexp.to_string ([%sexp_of: (int * int) option] a')) + (this_to_string t_a) + (Sexp.to_string ([%sexp_of: (int * int) option] b')) + (that_to_string t_b) + () + ;; + + let find_mapi (t_a, t_b) = + let f i x = if i < 7 && (i % 7 = x % 7) then Some (i + x) else None in + let a' = This_queue.find_mapi t_a ~f in + let b' = That_queue.find_mapi t_b ~f in + if a' <> b' + then failwithf "error in find_mapi: %s (for 
%s) <> %s (for %s)" + (Sexp.to_string ([%sexp_of: int option] a')) + (this_to_string t_a) + (Sexp.to_string ([%sexp_of: int option] b')) + (that_to_string t_b) + () + ;; + + let copy (t_a, t_b) = + let copy_a = This_queue.copy t_a in + let copy_b = That_queue.copy t_b in + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + let end_a = This_queue.to_array copy_a in + let end_b = That_queue.to_array copy_b in + if end_a <> end_b + then failwithf "error in copy: %s -> %s vs. %s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let transfer (t_a, t_b) = + let dst_a = This_queue.create () in + let dst_b = That_queue.create () in + (* sometimes puts some elements in the destination queues *) + if Random.bool () + then begin + List.iter [ 1; 2; 3; 4; 5 ] ~f:(fun elem -> + This_queue.enqueue dst_a elem; + That_queue.enqueue dst_b elem); + end; + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + This_queue.transfer ~src:t_a ~dst:dst_a; + That_queue.transfer ~src:t_b ~dst:dst_b; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + let end_a' = This_queue.to_array dst_a in + let end_b' = That_queue.to_array dst_b in + if end_a' <> end_b' || end_a <> end_b + then failwithf "error in transfer: %s -> (%s, %s) vs. 
%s -> (%s, %s)" + (array_string start_a) + (array_string end_a) + (array_string end_a') + (array_string start_b) + (array_string end_b) + (array_string end_b') (* final contents of dst_b, not the source queue a second time *) + () + ;; + + let fold_check (t_a, t_b) = + let make_list fold t = + fold t ~init:[] ~f:(fun acc x -> x :: acc) + in + let this_l = make_list This_queue.fold t_a in + let that_l = make_list That_queue.fold t_b in + if this_l <> that_l + then failwithf "error in fold: %s (from %s) <> %s (from %s)" + (Sexp.to_string (this_l |> [%sexp_of: int list])) + (this_to_string t_a) + (Sexp.to_string (that_l |> [%sexp_of: int list])) + (that_to_string t_b) + () + ;; + + let foldi_check (t_a, t_b) = + let make_list foldi t = + foldi t ~init:[] ~f:(fun i acc x -> (i,x) :: acc) + in + let this_l = make_list This_queue.foldi t_a in + let that_l = make_list That_queue.foldi t_b in + if this_l <> that_l + then failwithf "error in foldi: %s (from %s) <> %s (from %s)" + (Sexp.to_string (this_l |> [%sexp_of: (int * int) list])) + (this_to_string t_a) + (Sexp.to_string (that_l |> [%sexp_of: (int * int) list])) + (that_to_string t_b) + () + ;; + + let length_check (t_a, t_b) = + let this_len = This_queue.length t_a in + let that_len = That_queue.length t_b in + if this_len <> that_len + then failwithf "error in length: %i (for %s) <> %i (for %s)" + this_len (this_to_string t_a) + that_len (that_to_string t_b) + () + ;; + + let%test_unit _ = + let t = create () in + let rec loop ~all_ops ~non_empty_ops = + if all_ops <= 0 && non_empty_ops <= 0 + then begin + let (t_a, t_b) = t in + let arr_a = This_queue.to_array t_a in + let arr_b = That_queue.to_array t_b in + if arr_a <> arr_b + then failwithf "queue final states not equal: %s vs. 
%s" + (array_string arr_a) + (array_string arr_b) + () + end else begin + let queue_was_empty = This_queue.length (fst t) = 0 in + let r = Random.int 195 in + begin + if r < 60 + then enqueue t (Random.int 10_000) + else if r < 65 + then dequeue t + else if r < 70 + then clear t + else if r < 80 + then iter t + else if r < 85 + then iteri t + else if r < 90 + then fold_check t + else if r < 95 + then foldi_check t + else if r < 100 + then filter t + else if r < 105 + then filteri t + else if r < 110 + then concat_map t + else if r < 115 + then concat_mapi t + else if r < 120 + then transfer t + else if r < 130 + then filter_map t + else if r < 135 + then filter_mapi t + else if r < 140 + then copy t + else if r < 150 + then filter_inplace t + else if r < 155 + then for_alli t + else if r < 160 + then existsi t + else if r < 165 + then counti t + else if r < 170 + then findi t + else if r < 175 + then find_mapi t + else if r < 180 + then map t + else if r < 185 + then mapi t + else if r < 190 + then filteri_inplace t + else if r < 195 + then length_check t + else failwith "Impossible: We did [Random.int 195] above" + end; + loop + ~all_ops:(all_ops - 1) + ~non_empty_ops:(if queue_was_empty then non_empty_ops else non_empty_ops - 1) + end + in + loop ~all_ops:30_000 ~non_empty_ops:20_000 + ;; + end) + + let binary_search = binary_search + let binary_search_segmented = binary_search_segmented + + let%test_unit "modification-during-iteration" = + let x = `A 0 in + let t = of_list [x; x] in + let f (`A n) = ignore n; clear t in + assert (does_raise (fun () -> iter t ~f)) + ;; + + let%test_unit "more-modification-during-iteration" = + let nested_iter_okay = ref false in + let t = of_list [ `iter; `clear ] in + assert (does_raise (fun () -> + iter t ~f:(function + | `iter -> iter t ~f:ignore; nested_iter_okay := true + | `clear -> clear t))); + assert !nested_iter_okay + ;; + + let%test_unit "modification-during-filter" = + let reached_unreachable = ref false in + let t = 
of_list [`clear; `unreachable] in + let f x = + match x with + | `clear -> clear t; false + | `unreachable -> reached_unreachable := true; false + in + assert (does_raise (fun () -> filter t ~f)); + assert (not !reached_unreachable) + ;; + + let%test_unit "modification-during-filter-inplace" = + let reached_unreachable = ref false in + let t = of_list [`drop_this; `enqueue_new_element; `unreachable] in + let f x = + begin match x with + | `drop_this | `new_element -> () + | `enqueue_new_element -> enqueue t `new_element + | `unreachable -> reached_unreachable := true + end; + false + in + assert (does_raise (fun () -> filter_inplace t ~f)); + (* even though we said to drop the first element, the aborted call to [filter_inplace] + shouldn't have made that change *) + assert (peek_exn t = `drop_this); + assert (not !reached_unreachable) + ;; + + let%test_unit "filter-inplace-during-iteration" = + let reached_unreachable = ref false in + let t = of_list [`filter_inplace; `unreachable] in + let f x = + match x with + | `filter_inplace -> filter_inplace t ~f:(fun _ -> false) + | `unreachable -> reached_unreachable := true + in + assert (does_raise (fun () -> iter t ~f)); + assert (not !reached_unreachable) + ;; + + module Stable = struct + module V1 = Stable.V1 + include Stable_unit_test.Make (struct + type nonrec t = int V1.t [@@deriving sexp, bin_io, compare] + let equal = [%compare.equal: t] + let tests = + let manipulated = Queue.of_list [0;3;6;1] in + ignore (Queue.dequeue_exn manipulated : int); + ignore (Queue.dequeue_exn manipulated : int); + Queue.enqueue manipulated 4; + [ Queue.of_list [], "()", "\000" + ; Queue.of_list [1; 2; 6; 4], "(1 2 6 4)", "\004\001\002\006\004" + ; manipulated, "(6 1 4)", "\003\006\001\004" + ] + end) + end + end + (* This signature is here to remind us to update the unit tests whenever we + change [Core_queue]. 
*) + : module type of Queue)) diff --git a/test/test_queue.mli b/test/test_queue.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_queue.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_random.ml b/test/test_random.ml new file mode 100644 index 0000000..bb88758 --- /dev/null +++ b/test/test_random.ml @@ -0,0 +1,185 @@ +open! Import +open! Random + +module State = struct + include State + + let%test_unit "random int above 2^30" [@tags "64-bits-only"] = + let state = make [| 1 ; 2 ; 3 ; 4 ; 5 |] in + for _ = 1 to 100 do + let bound = Int.shift_left 1 40 in + let n = int state bound in + if n < 0 || n >= bound then + failwith (Printf.sprintf "random result %d out of bounds (0,%d)" n (bound-1)) + done + ;; +end + +external random_seed: unit -> Caml.Obj.t = "caml_sys_random_seed";; +let%test_unit _ = + (* test that the return type of "caml_sys_random_seed" is what we expect *) + let module Obj = Caml.Obj in + let obj = random_seed () in + assert (Obj.is_block obj); + assert (Obj.tag obj = Obj.tag (Obj.repr [| 13 |])); + for i = 0 to Obj.size obj - 1 do + assert (Obj.is_int (Obj.field obj i)); + done +;; + +module type T = sig + type t [@@deriving compare, sexp_of] +end +;; + +(* We test that [count] trials of [generate ()] all produce values between [min, max], and + generate at least one value between [lo, hi]. 
*) +let test (type t) here m count generate ~min ~max ~check_range:(lo, hi) = + let (module T : T with type t = t) = m in + let between t ~lower_bound ~upper_bound = + T.compare t lower_bound >= 0 && + T.compare t upper_bound <= 0 + in + let generated = + List.init count ~f:(fun _ -> generate ()) + |> List.dedup_and_sort ~compare:T.compare + in + require here + (List.for_all generated ~f:(fun t -> + between t ~lower_bound:min ~upper_bound:max)) + ~if_false_then_print_s: + (lazy [%message + "generated values outside of bounds" + (min : T.t) + (max : T.t) + (generated : T.t list)]); + require here + (List.exists generated ~f:(fun t -> + between t ~lower_bound:lo ~upper_bound:hi)) + ~if_false_then_print_s: + (lazy [%message + "did not generate value inside range" + (lo : T.t) + (hi : T.t) + (generated : T.t list)]); +;; + +let%expect_test "float" = + test [%here] (module Float) 1_000 (fun () -> float 100.) + ~min:0. ~max:100. ~check_range:(10.,20.); + [%expect {||}]; +;; + +let%expect_test "float_range" = + test [%here] (module Float) 1_000 (fun () -> float_range (-100.) 100.) + ~min:(-100.) ~max:100. 
~check_range:(-20.,-10.); + [%expect {||}]; +;; + +let%expect_test "int" = + test [%here] (module Int) 1_000 (fun () -> int 100) + ~min:0 ~max:99 ~check_range:(10,20); + [%expect {||}]; +;; + +let%expect_test "int_incl" = + test [%here] (module Int) 1_000 (fun () -> int_incl (-100) 100) + ~min:(-100) ~max:100 ~check_range:(-20,-10); + [%expect {||}]; + test [%here] (module Int) 1_000 (fun () -> int_incl 0 Int.max_value) + ~min:0 ~max:Int.max_value ~check_range:(0, Int.max_value / 100); + [%expect {||}]; + test [%here] (module Int) 1_000 (fun () -> int_incl Int.min_value Int.max_value) + ~min:Int.min_value ~max:Int.max_value + ~check_range:(Int.min_value / 100, Int.max_value / 100); + [%expect {||}]; +;; + +let%expect_test "int32" = + test [%here] (module Int32) 1_000 (fun () -> int32 100l) + ~min:0l ~max:99l ~check_range:(10l,20l); + [%expect {||}]; +;; + +let%expect_test "int32_incl" = + test [%here] (module Int32) 1_000 (fun () -> int32_incl (-100l) 100l) + ~min:(-100l) ~max:100l ~check_range:(-20l,-10l); + [%expect {||}]; + test [%here] (module Int32) 1_000 (fun () -> int32_incl 0l Int32.max_value) + ~min:0l ~max:Int32.max_value + ~check_range:(0l, Int32.( / ) Int32.max_value 100l); + [%expect {||}]; + test [%here] (module Int32) 1_000 (fun () -> int32_incl Int32.min_value Int32.max_value) + ~min:Int32.min_value ~max:Int32.max_value + ~check_range:(Int32.( / ) Int32.min_value 100l, Int32.( / ) Int32.max_value 100l); + [%expect {||}]; +;; + +let%expect_test "int64" = + test [%here] (module Int64) 1_000 (fun () -> int64 100L) + ~min:0L ~max:99L ~check_range:(10L,20L); + [%expect {||}]; +;; + +let%expect_test "int64_incl" = + test [%here] (module Int64) 1_000 (fun () -> int64_incl (-100L) 100L) + ~min:(-100L) ~max:100L ~check_range:(-20L,-10L); + [%expect {||}]; + test [%here] (module Int64) 1_000 (fun () -> int64_incl 0L Int64.max_value) + ~min:0L ~max:Int64.max_value + ~check_range:(0L, Int64.( / ) Int64.max_value 100L); + [%expect {||}]; + test [%here] (module 
Int64) 1_000 (fun () -> int64_incl Int64.min_value Int64.max_value) + ~min:Int64.min_value ~max:Int64.max_value + ~check_range:(Int64.( / ) Int64.min_value 100L, Int64.( / ) Int64.max_value 100L); + [%expect {||}]; +;; + +let%expect_test "nativeint" = + test [%here] (module Nativeint) 1_000 (fun () -> nativeint 100n) + ~min:0n ~max:99n ~check_range:(10n,20n); + [%expect {||}]; +;; + +let%expect_test "nativeint_incl" = + test [%here] (module Nativeint) 1_000 (fun () -> nativeint_incl (-100n) 100n) + ~min:(-100n) ~max:100n ~check_range:(-20n,-10n); + [%expect {||}]; + test [%here] (module Nativeint) 1_000 (fun () -> nativeint_incl 0n Nativeint.max_value) + ~min:0n ~max:Nativeint.max_value + ~check_range:(0n, Nativeint.( / ) Nativeint.max_value 100n); + [%expect {||}]; + test [%here] (module Nativeint) 1_000 (fun () -> + nativeint_incl Nativeint.min_value Nativeint.max_value) + ~min:Nativeint.min_value ~max:Nativeint.max_value + ~check_range:(Nativeint.( / ) Nativeint.min_value 100n, + Nativeint.( / ) Nativeint.max_value 100n); + [%expect {||}]; +;; + +(* The int63 functions come from [Int63] rather than [Random], but we test them here + along with the others anyway. 
*) + +let%expect_test "int63" = + let i = Int63.of_int in + test [%here] (module Int63) 1_000 (fun () -> Int63.random (i 100)) + ~min:(i 0) ~max:(i 99) ~check_range:(i 10,i 20); + [%expect {||}]; +;; + +let%expect_test "int63_incl" = + let i = Int63.of_int in + test [%here] (module Int63) 1_000 (fun () -> Int63.random_incl (i (-100)) (i 100)) + ~min:(i (-100)) ~max:(i 100) ~check_range:(i (-20),i (-10)); + [%expect {||}]; + test [%here] (module Int63) 1_000 (fun () -> Int63.random_incl (i 0) Int63.max_value) + ~min:(i 0) ~max:Int63.max_value + ~check_range:(i 0, Int63.( / ) Int63.max_value (i 100)); + [%expect {||}]; + test [%here] (module Int63) 1_000 (fun () -> + Int63.random_incl Int63.min_value Int63.max_value) + ~min:Int63.min_value ~max:Int63.max_value + ~check_range:(Int63.( / ) Int63.min_value (i 100), + Int63.( / ) Int63.max_value (i 100)); + [%expect {||}]; +;; diff --git a/test/test_random.mli b/test/test_random.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_random.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_ref.ml b/test/test_ref.ml new file mode 100644 index 0000000..298d903 --- /dev/null +++ b/test/test_ref.ml @@ -0,0 +1,21 @@ +open! Core_kernel +open Ref + +let%test_unit "[set_temporarily] without raise" = + let r = ref 0 in + [%test_result: int] ~expect:1 (set_temporarily r 1 ~f:(fun () -> !r)); + [%test_result: int] ~expect:0 !r; +;; + +let%test_unit "[set_temporarily] with raise" = + let r = ref 0 in + try + never_returns (set_temporarily r 1 ~f:(fun () -> failwith "")); + with _ -> [%test_result: int] ~expect:0 !r +;; + +let%test_unit "[set_temporarily] where [f] sets the ref" = + let r = ref 0 in + set_temporarily r 1 ~f:(fun () -> r := 2); + [%test_result: int] ~expect:0 !r; +;; diff --git a/test/test_ref.mli b/test/test_ref.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_ref.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. 
*) diff --git a/test/test_sequence.ml b/test/test_sequence.ml new file mode 100644 index 0000000..0887194 --- /dev/null +++ b/test/test_sequence.ml @@ -0,0 +1,488 @@ +open! Import +open! Sequence + +let%test_unit "of_lazy" = + let t = range 0 100 in + [%test_result: int list] + (to_list (of_lazy (lazy t))) + ~expect:(to_list t) + +let%test_unit _ = + let seq_of_seqs = + unfold ~init:0 ~f:(fun i -> + Some (unfold ~init:i ~f:(fun j -> Some ((i, j), j + 1)), + i + 1)) + in + [%test_result: (int * int) list] + (to_list (take (interleave seq_of_seqs) 10)) + ~expect:[ 0,0 + ; 0,1 ; 1,1 + ; 0,2 ; 1,2 ; 2,2 + ; 0,3 ; 1,3 ; 2,3 ; 3,3 + ] + +let%expect_test "round_robin vs interleave" = + let list_of_lists = + [ [1; 10; 100; 1000] + ; [2; 20; 200] + ; [3; 30] + ; [4] + ] + in + let list_of_seqs = List.map list_of_lists ~f:of_list in + let seq_of_seqs = of_list list_of_seqs in + print_s [%sexp (to_list (round_robin list_of_seqs) : int list)]; + [%expect {| (1 2 3 4 10 20 30 100 200 1_000) |}]; + print_s [%sexp (to_list (interleave seq_of_seqs) : int list)]; + [%expect {| (1 10 2 100 20 3 1_000 200 30 4) |}]; +;; + +let%test_unit _ = + let evens = unfold ~init:0 ~f:(fun i -> Some (i, i + 2)) in + let vowels = cycle_list_exn ['a';'e';'i';'o';'u'] in + [%test_result: (int * char) list] + (to_list (take (interleaved_cartesian_product evens vowels) 10)) + ~expect:[ 0,'a' + ; 0,'e' ; 2,'a' + ; 0,'i' ; 2,'e' ; 4,'a' + ; 0,'o' ; 2,'i' ; 4,'e' ; 6,'a' + ] + +let%test_module "Sequence.merge*" = + (module struct + let%test_unit _ = + [%test_eq: (int, int) Merge_with_duplicates_element.t list] + (to_list + (merge_with_duplicates + (of_list [ 1; 2; ]) + (of_list [ 2; 3; ]) + (* Can't use Core_int.compare because it would be a dependency cycle. 
*) + ~compare:Int.compare)) + [ Left 1; Both (2, 2); Right 3; ] + + let%test_unit _ = + [%test_eq: (int, int) Merge_with_duplicates_element.t list] + (to_list + (merge_with_duplicates + (of_list [ 2; 1; ]) + (of_list [ 2; 3; ]) + ~compare:Int.compare)) + [ Both (2, 2); Left 1; Right 3; ] + + let%test_unit _ = + [%test_eq: (int * string) list] + (to_list + (merge + (of_list [ (0, "A"); (1, "A"); ]) + (of_list [ (1, "B"); (2, "B"); ]) + ~compare:(fun a b -> [%compare: int] (fst a) (fst b)))) + [ (0, "A"); (1, "A"); (2, "B"); ] + end) + +let%test _ = fold ~f:(+) ~init:0 (of_list [1; 2; 3; 4; 5]) = 15 +let%test _ = fold ~f:(+) ~init:0 (of_list []) = 0 + +let%test_unit _ = + let test_equal l = [%test_result: int list] (to_list (of_list l)) ~expect:l in + test_equal []; + test_equal [1; 2; 3; 4; 5] +(* The test for longer list is after range *) + +let%test_unit _ = [%test_result: int list] (to_list (range 0 5)) ~expect:[0;1;2;3;4] +let%test_unit _ = [%test_result: int list] (to_list (range ~stop:`inclusive 0 5)) ~expect:[0;1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (range ~start:`exclusive 0 5)) ~expect:[1;2;3;4] +let%test_unit _ = [%test_result: int list] (to_list (range ~stride:(-2) 5 1)) ~expect:[5;3] + +(* Test for to_list *) +let%test_unit _ = [%test_result: int list] (to_list (range 0 5000)) ~expect:(List.range 0 5000) + +(* Functions used for testing by comparing to List implementation*) +let test_to_list s f g = + [%test_result: int list] (to_list (f s)) ~expect:(g (to_list s)) + +(* For testing, we create a sequence which is equal to 1;2;3;4;5, but + with a more interesting structure inside*) + +let s12345 = map ~f:(fun x -> x / 2) (filter ~f:(fun x -> x % 2 = 0) + (of_list [1;2;3;4;5;6;7;8;9;10])) + +let sempty = filter ~f:(fun x -> x < 0) (of_list [1;2;3;4]) + +let test f g = test_to_list s12345 f g; test_to_list sempty f g + +let%test_unit _ = + [%test_result: int list] (to_list s12345) ~expect:[1; 2; 3; 4; 5]; + [%test_result: int list] 
(to_list sempty) ~expect:[] + +let%test_unit _ = + [%test_result: int list] + (to_list (unfold_with s12345 ~init:1 + ~f:(fun s _ -> + if s % 2 = 0 then + Skip (s+1) + else if s = 5 then + Done + else + Yield(s, s+1)))) + ~expect:[1;3] + +let test_delay init = + unfold_with_and_finish ~init + ~running_step:(fun prev next -> + Yield (prev, next)) + ~inner_finished:(fun x -> Some x) + ~finishing_step:(fun prev -> + match prev with + | None -> Done + | Some prev -> Yield (prev, None)) + +let%test_unit _ = + [%test_result: int list] + (to_list (test_delay 0 s12345)) + ~expect:[0; 1; 2; 3; 4; 5] + +let%test_unit _ = + [%test_result: int list] + (to_list (test_delay 0 sempty)) + ~expect:[0] + +let%test_unit _ = [%test_result: int list] (to_list s12345) ~expect:[1; 2; 3; 4; 5] + +let%test_unit _ = test + (map ~f:(fun i -> -i)) + (List.map ~f:(fun i -> -i)) + +let%test_unit _ = test + (mapi ~f:(fun i j -> j - 2 *i)) + (List.mapi ~f:(fun i j -> j - 2 *i)) + +let%test_unit _ = test + (filter ~f:(fun i -> i % 2 = 0)) + (List.filter ~f:(fun i -> i % 2 = 0)) + +let%test _ = length s12345 = 5 && length sempty = 0 + +let%test_unit _ = + [%test_result: int option] (find s12345 ~f:(fun x -> x = 3)) ~expect:(Some 3); + [%test_result: int option] (find s12345 ~f:(fun x -> x = 7)) ~expect:None + +let%test_unit _ = + [%test_result: string option] + (find_map s12345 ~f:(fun x -> if x = 3 then Some "a" else None)) + ~expect:(Some "a"); + [%test_result: string option] + (find_map s12345 ~f:(fun x -> if x = 7 then Some "a" else None)) + ~expect:None + +let%test_unit _ = + [%test_result: string option] + (find_mapi s12345 ~f:(fun _ x -> if x = 3 then Some "a" else None)) + ~expect:(Some "a") +let%test_unit _ = + [%test_result: string option] + (find_mapi s12345 ~f:(fun _ x -> if x = 7 then Some "a" else None)) + ~expect:None +let%test_unit _ = + [%test_result: (int * int) option] + (find_mapi s12345 ~f:(fun i x -> if i + x >= 6 then Some (i,x) else None)) + ~expect:(Some (3,4)) + +let%test _ 
= for_all sempty ~f:(fun _ -> false) +let%test _ = for_all s12345 ~f:(fun x -> x > 0) +let%test _ = not (for_all s12345 ~f:(fun x -> x < 5)) + +let%test _ = for_alli sempty ~f:(fun _ _ -> false) +let%test _ = for_alli s12345 ~f:(fun _ x -> x > 0) +let%test _ = not (for_alli s12345 ~f:(fun _ x -> x < 5)) +let%test _ = for_alli s12345 ~f:(fun i x -> x = i+1) + +let%test _ = not (exists sempty ~f:(fun _ -> assert false)) +let%test _ = exists s12345 ~f:(fun x -> x = 5) +let%test _ = not (exists s12345 ~f:(fun x -> x = 0)) + +let%test _ = not (existsi sempty ~f:(fun _ _ -> assert false)) +let%test _ = existsi s12345 ~f:(fun _ x -> x = 5) +let%test _ = not (existsi s12345 ~f:(fun _ x -> x = 0)) +let%test _ = not (existsi s12345 ~f:(fun i x -> x <> i+1)) + +let%test_unit _ = + let l = ref [] in + iter s12345 ~f:(fun x -> l := x::!l); + [%test_result: int list] !l ~expect:[5;4;3;2;1] + +let%test _ = is_empty sempty +let%test _ = not (is_empty (of_list [1])) + +let%test _ = mem s12345 1 ~equal:Int.equal +let%test _ = not (mem s12345 6 ~equal:Int.equal) + +let%test_unit _ = [%test_result: int list] (to_list empty) ~expect:[] + +let%test_unit _ = + [%test_result: int list] + (to_list (bind sempty ~f:(fun _ -> s12345))) + ~expect:[] +let%test_unit _ = + [%test_result: int list] + (to_list (bind s12345 ~f:(fun _ -> sempty))) + ~expect:[] +let%test_unit _ = + [%test_result: int list] + (to_list (bind s12345 ~f:(fun x -> of_list [x;-x]))) + ~expect:[1;-1;2;-2;3;-3;4;-4;5;-5] + +let%test_unit _ = [%test_result: int list] (to_list (return 1)) ~expect:[1] + +let%test_unit _ = [%test_result: int option] (nth s12345 3) ~expect:(Some 4) +let%test_unit _ = [%test_result: int option] (nth s12345 5) ~expect:None + +let%test_unit _ = [%test_result: int option] (hd s12345) ~expect:(Some 1) +let%test_unit _ = [%test_result: int option] (hd sempty) ~expect:None + +let%test_unit _ = [%test_result: int t option] (tl sempty) ~expect:None +let%test_unit _ = match tl s12345 with + | Some l -> 
[%test_result: int list] (to_list l) ~expect:[2;3;4;5] + | None -> failwith "expected Some" + +let%test_unit _ = [%test_result: (int * int t) option] (next sempty) ~expect:None +let%test_unit _ = match next s12345 with + | Some (hd,tl) -> + [%test_result: int] hd ~expect:1; + [%test_result: int list] (to_list tl) ~expect:[2;3;4;5] + | None -> failwith "expected Some" + +let%test_unit _ = + [%test_result: int list] + (to_list (filter_opt (of_list [None; Some 1; None ;Some 2; Some 3]))) + ~expect:[1;2;3] + +let%test_unit _ = + let (l,r) = split_n s12345 2 in + [%test_result: int list] l ~expect:[1;2]; + [%test_result: int list] (to_list r) ~expect:[3;4;5] + +let%test_unit _ = [%test_result: int list list] (to_list (chunks_exn s12345 2)) ~expect:[[1;2];[3;4];[5]] + +let%test_unit _ = [%test_result: int list] (to_list (append s12345 s12345)) ~expect:[1;2;3;4;5;1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (append sempty s12345)) ~expect:[1;2;3;4;5] + +let%test_unit _ = + [%test_result: (int * int) list] + (to_list (zip s12345 sempty)) + ~expect:[] +let%test_unit _ = + [%test_result: (int * int) list] + (to_list (zip s12345 (of_list [6;5;4;3;2;1]))) + ~expect:[1,6;2,5;3,4;4,3;5,2] +let%test_unit _ = + [%test_result: (int * string) list] + (to_list (zip s12345 (of_list ["a"]))) + ~expect:[1,"a"] + +let%test_unit _ = + [%test_result: (int * int) option] + (find_consecutive_duplicate s12345 ~equal:(=)) + ~expect:None +let%test_unit _ = + [%test_result: (int * int) option] + (find_consecutive_duplicate (of_list [1;2;2;3;4;4;5]) ~equal:(=)) + ~expect:(Some (2,2)) + +let%test_unit _ = + [%test_result: int list] + (to_list + (remove_consecutive_duplicates ~equal:(=) (of_list [1;2;2;3;3;3;3;4;4;5;6;6;7]))) + ~expect:[1;2;3;4;5;6;7] +let%test_unit _ = + [%test_result: int list] + (to_list + (remove_consecutive_duplicates ~equal:(=) s12345)) + ~expect:[1;2;3;4;5] + +let%test_unit _ = + [%test_result: int list] + (to_list (remove_consecutive_duplicates 
~equal:(fun _ _ -> true) s12345)) + ~expect:[1] + +let%test_unit _ = [%test_result: int list] (to_list (init (-1) ~f:(fun _ -> assert false))) ~expect:[] +let%test_unit _ = [%test_result: int list] (to_list (init 5 ~f:Fn.id)) ~expect:[0; 1; 2; 3; 4] + +let%test_unit _ = [%test_result: int list] (to_list (sub s12345 ~pos:4 ~len:10)) ~expect:[5] +let%test_unit _ = [%test_result: int list] (to_list (sub s12345 ~pos:1 ~len:2)) ~expect:[2;3] +let%test_unit _ = [%test_result: int list] (to_list (sub s12345 ~pos:0 ~len:0)) ~expect:[] + +let%test_unit _ = [%test_result: int list] (to_list (take s12345 2)) ~expect:[1;2] +let%test_unit _ = [%test_result: int list] (to_list (take s12345 0)) ~expect:[] +let%test_unit _ = [%test_result: int list] (to_list (take s12345 9)) ~expect:[1;2;3;4;5] + +let%test_unit _ = [%test_result: int list] (to_list (drop s12345 2)) ~expect:[3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop s12345 0)) ~expect:[1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop s12345 9)) ~expect:[] + +let%test_unit _ = [%test_result: int list] (to_list (take_while ~f:(fun x -> x < 3) s12345)) ~expect:[1;2] + +let%test_unit _ = [%test_result: int list] (to_list (drop_while ~f:(fun x -> x < 3) s12345)) ~expect:[3;4;5] + +let%test_unit _ = + [%test_result: int list] + (to_list (shift_right (shift_right s12345 0) (-1))) + ~expect:[-1;0;1;2;3;4;5] + +let%test_unit _ = [%test_result: char list] (to_list (intersperse ~sep:'a' (of_list []))) ~expect:[] +let%test_unit _ = [%test_result: char list] (to_list (intersperse ~sep:'a' (of_list ['b']))) ~expect:['b'] +let%test_unit _ = [%test_result: int list] (to_list (intersperse ~sep:(-1) (take s12345 1))) ~expect:[1] +let%test_unit _ = [%test_result: int list] (to_list (intersperse ~sep:0 s12345)) ~expect:[1;0;2;0;3;0;4;0;5] + +let%test_unit _ = [%test_result: int list] (to_list (take (repeat 1) 3)) ~expect:[1;1;1] + +let%test_unit _ = + [%test_result: int list] + (to_list (take 
(cycle_list_exn [1;2;3;4;5]) 7)) + ~expect:[1;2;3;4;5;1;2] + +let%test_unit _ = require_does_raise [%here] (fun () -> cycle_list_exn []) + +let%test_unit _ = + [%test_result: (char * int) list] + (to_list (cartesian_product (of_list ['a';'b']) s12345)) + ~expect:['a',1;'a',2;'a',3;'a',4;'a',5; + 'b',1;'b',2;'b',3;'b',4;'b',5] + +let%test_unit _ = + [%test_result: float] + (delayed_fold s12345 ~init:0.0 + ~f:(fun a i ~k -> + if Float.(<=) a 5.0 then + k (a +. (Float.of_int i)) else + a) + ~finish:(fun _ -> assert false)) + ~expect:6.0 + +let%expect_test "fold_m" = + let module Simple_monad = struct + type 'a t = + | Return of 'a + | Step of 'a t + [@@deriving sexp_of] + + let return a = Return a + + let rec bind t ~f = + match t with + | Return a -> f a + | Step t -> Step (bind t ~f) + ;; + + let step = Step (Return ()) + end in + fold_m ~bind:Simple_monad.bind ~return:Simple_monad.return s12345 + ~init:[] + ~f:(fun acc n -> + Simple_monad.bind Simple_monad.step ~f:(fun () -> + Simple_monad.return (n :: acc))) + |> printf !"%{sexp: int list Simple_monad.t}\n"; + [%expect {| (Step (Step (Step (Step (Step (Return (5 4 3 2 1))))))) |}] +;; + +let%expect_test "iter_m" = + iter_m ~bind:Generator.bind ~return:Generator.return s12345 ~f:Generator.yield + |> Generator.run + |> printf !"%{sexp: int t}\n"; + [%expect {| (1 2 3 4 5) |}] +;; + +let%test _ = + let num_computations = ref 0 in + let t = memoize (unfold ~init:() ~f:(fun () -> Int.incr num_computations; None)) in + iter t ~f:Fn.id; + iter t ~f:Fn.id; + !num_computations = 1 + +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 0)) ~expect:[1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 2)) ~expect:[3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 5)) ~expect:[] +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 8)) ~expect:[] + +let compare_tests = + [ [1; 2; 3] , [1; 2; 3] , 0 + ; [1; 2; 3] , [] , 1 + 
; [] , [1; 2; 3] , -1 + ; [1; 2] , [1; 2; 3] , -1 + ; [1; 2; 3] , [1; 2] , 1 + ; [1; 3; 2] , [1; 2; 3] , 1 + ; [1; 2; 3] , [1; 3; 2] , -1 ] + +(* this test has to use base OCaml library functions to avoid circular dependencies *) +let%test _ = + List.for_all + ~f:Fn.id + (List.map + ~f:(fun (l1, l2, expected_res) -> + compare Int.compare (of_list l1) (of_list l2) = expected_res) + compare_tests) + +let%test_unit _ = + [%test_result: int list] + (folding_map (of_list [1;2;3;4]) ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) |> to_list) + ~expect:[1;3;6;10] +let%test_unit _ = + [%test_result: bool] + (folding_map empty ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) |> is_empty) + ~expect:true +let%test_unit _ = + [%test_result: int list] + (folding_mapi (of_list [1;2;3;4]) ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) |> to_list) + ~expect:[0;2;8;20] +let%test_unit _ = + [%test_result: bool] + (folding_mapi empty ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) |> is_empty) + ~expect:true + +let%expect_test _ = + let xs = init 3 ~f:Fn.id |> Generator.of_sequence in + let ( @ ) xs ys = Generator.bind xs ~f:(fun () -> ys) in + (xs @ xs @ xs @ xs @ xs) + |> Generator.run + |> [%sexp_of: int t] + |> print_s; + [%expect {| + (0 1 2 0 1 2 0 1 2 0 1 2 0 1 2) |}] +;; + +let%test_module "group" = + (module struct + let%test _ = + of_list [1; 2; 3; 4] + |> group ~break:(fun _ x -> Int.equal x 3) + |> [%compare.equal: int list t] (of_list [[1; 2]; [3; 4]]) + ;; + + let%test _ = + group empty ~break:(fun _ -> assert false) + |> [%compare.equal: unit list t] empty + ;; + + let mis = of_list ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] + ;; + + let equal_letters = + of_list [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] + ;; + + let single_letters = of_list [['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i']] + ;; + + let%test _ = + group ~break:Char.(<>) mis + |> [%compare.equal: char list t] equal_letters + ;; + + let%test _ = + group ~break:(fun _ _ -> 
false) mis + |> [%compare.equal: char list t] single_letters + ;; + end) diff --git a/test/test_sequence.mli b/test/test_sequence.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sequence.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_set.ml b/test/test_set.ml new file mode 100644 index 0000000..2129f9b --- /dev/null +++ b/test/test_set.ml @@ -0,0 +1,28 @@ +open! Import +open! Set + +type int_set = Set.M(Int).t [@@deriving compare, hash, sexp] + +let%test _ = + invariants + (of_increasing_iterator_unchecked (module Int) ~len:20 ~f:Fn.id) + +let%test _ = + invariants + (Poly.of_increasing_iterator_unchecked ~len:20 ~f:Fn.id) + +module Poly = struct + let%test _ = + length Poly.empty = 0 + ;; + + let%test _ = + Poly.equal (Poly.of_list []) Poly.empty + ;; + + let%test _ = + let a = Poly.of_list [1; 1] in + let b = Poly.of_list ["a"] in + length a = length b + ;; +end diff --git a/test/test_sexp.ml b/test/test_sexp.ml new file mode 100644 index 0000000..7014760 --- /dev/null +++ b/test/test_sexp.ml @@ -0,0 +1,40 @@ +open! 
Import + +let%expect_test "[sexp_array]" = + let module M = struct + type t = { x : int sexp_array } [@@deriving sexp_of] + end in + List.iter [ [| |]; [| 13 |] ] ~f:(fun x -> print_s [%sexp ({ x } : M.t)]); + [%expect {| + () + ((x (13))) |}]; +;; + +let%expect_test "[sexp_list]" = + let module M = struct + type t = { x : int sexp_list } [@@deriving sexp_of] + end in + List.iter [ [ ]; [ 13 ] ] ~f:(fun x -> print_s [%sexp ({ x } : M.t)]); + [%expect {| + () + ((x (13))) |}]; +;; + +let%expect_test "[sexp_opaque]" = + let module M = struct + type t = { x : int sexp_opaque } [@@deriving sexp_of] + end in + print_s [%sexp ({ x = 13 } : M.t)]; + [%expect {| + ((x )) |}]; +;; + +let%expect_test "[sexp_option]" = + let module M = struct + type t = { x : int sexp_option } [@@deriving sexp_of] + end in + List.iter [ None; Some 13 ] ~f:(fun x -> print_s [%sexp ({ x } : M.t)]); + [%expect {| + () + ((x 13)) |}]; +;; diff --git a/test/test_sexp.mli b/test/test_sexp.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sexp.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_sign.ml b/test/test_sign.ml new file mode 100644 index 0000000..83471a6 --- /dev/null +++ b/test/test_sign.ml @@ -0,0 +1,17 @@ +open! Import +open! Sign + +let%test "of_int" = + of_int 37 = Pos && of_int (-22) = Neg && of_int 0 = Zero + +let%test_unit "( * )" = + List.cartesian_product all all + |> List.iter ~f:(fun (s1, s2) -> + [%test_result: int] + (to_int (s1 * s2)) + ~expect:(Int.( * ) (to_int s1) (to_int s2))) + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module Sign) all; + [%expect {| |}]; +;; diff --git a/test/test_sign.mli b/test/test_sign.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sign.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. 
*) diff --git a/test/test_sign_or_nan.ml b/test/test_sign_or_nan.ml new file mode 100644 index 0000000..cb30856 --- /dev/null +++ b/test/test_sign_or_nan.ml @@ -0,0 +1,10 @@ +open! Import +open! Sign_or_nan + +let%test "of_int" = + of_int 37 = Pos && of_int (-22) = Neg && of_int 0 = Zero + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module Sign_or_nan) all; + [%expect {| |}]; +;; diff --git a/test/test_sign_or_nan.mli b/test/test_sign_or_nan.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sign_or_nan.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_stack.ml b/test/test_stack.ml new file mode 100644 index 0000000..a3df0af --- /dev/null +++ b/test/test_stack.ml @@ -0,0 +1,235 @@ +open! Base +open! Import +open! Stack + +module Debug (Stack : S) : S with type 'a t = 'a Stack.t = struct + + open Stack + + type nonrec 'a t = 'a t + + let invariant = invariant + + let check_and_return t = invariant ignore t; t + + let debug t f = + let result = Result.try_with f in + invariant ignore t; + Result.ok_exn result; + ;; + + (* The return-type annotations are to prevent an error where we don't supply all the + arguments to the function, and thus wouldn't be checking the invariant after fully + applying the function. 
*) + let clear t : unit = debug t (fun () -> clear t) + let copy t : _ t = check_and_return (debug t (fun () -> copy t)) + let count t ~f : int = debug t (fun () -> count t ~f) + let sum m t ~f = debug t (fun () -> sum m t ~f) + let create () : _ t = check_and_return (create ()) + let exists t ~f : bool = debug t (fun () -> exists t ~f) + let find t ~f : _ option = debug t (fun () -> find t ~f) + let find_map t ~f : _ option = debug t (fun () -> find_map t ~f) + let fold (type a) t ~init ~f : a = debug t (fun () -> fold t ~init ~f) + let for_all t ~f : bool = debug t (fun () -> for_all t ~f) + let is_empty t : bool = debug t (fun () -> is_empty t) + let iter t ~f : unit = debug t (fun () -> iter t ~f) + let length t : int = debug t (fun () -> length t) + let mem t a ~equal : bool = debug t (fun () -> mem t a ~equal) + let of_list l : _ t = check_and_return (of_list l) + let pop t : _ option = debug t (fun () -> pop t) + let pop_exn (type a) t : a = debug t (fun () -> pop_exn t) + let push t a : unit = debug t (fun () -> push t a) + let sexp_of_t sexp_of_a t : Sexp.t = debug t (fun () -> [%sexp_of: a t] t) + let singleton x : _ t = check_and_return (singleton x) + let t_of_sexp a_of_sexp sexp : _ t = check_and_return ([%of_sexp: a t] sexp) + let to_array t : _ array = debug t (fun () -> to_array t) + let to_list t : _ list = debug t (fun () -> to_list t) + let top t : _ option = debug t (fun () -> top t) + let top_exn (type a) t : a = debug t (fun () -> top_exn t) + let until_empty t f : unit = debug t (fun () -> until_empty t f) + let min_elt t ~compare : _ option = debug t (fun () -> min_elt t ~compare) + let max_elt t ~compare : _ option = debug t (fun () -> max_elt t ~compare) + let fold_result t ~init ~f = debug t (fun () -> fold_result t ~init ~f) + let fold_until t ~init ~f = debug t (fun () -> fold_until t ~init ~f) +end + +module Test (Stack : S) + (* This signature is here to remind us to add a unit test whenever we add something to + the stack interface. 
*) + : S with type 'a t = 'a Stack.t = struct + + open Stack + + type nonrec 'a t = 'a t + + include Test_container.Test_S1 (Stack) + + let invariant = invariant + + let create = create + let is_empty = is_empty + let top_exn = top_exn + let pop_exn = pop_exn + let pop = pop + let top = top + let singleton = singleton + let%test_unit _ = + let empty = create () in + invariant ignore empty; + invariant (fun b -> assert b) (of_list [true]); + assert (is_empty empty); + let t = create () in + push t 0; + assert (not (is_empty t)); + assert (Exn.does_raise (fun () -> top_exn empty)); + let t = create () in + push t 0; + [%test_result: int] (top_exn t) ~expect:0; + assert (Exn.does_raise (fun () -> pop_exn empty)); + let t = create () in + push t 0; + [%test_result: int] (pop_exn t) ~expect:0; + assert (is_none (pop empty)); + assert (is_some (pop (of_list [0]))); + assert (is_none (top empty)); + assert (is_some (top (of_list [0]))); + assert (is_some (top (singleton 0))); + assert (is_some (pop (singleton 0))); + assert (let t = singleton 0 in + ignore (pop_exn t : int); + is_none (top t)); + ;; + + let min_elt = min_elt + let max_elt = max_elt + + let%test_unit _ = + let empty = create () in + [%test_result: _ option] (min_elt ~compare:Int.compare empty) ~expect:None; + [%test_result: _ option] (max_elt ~compare:Int.compare empty) ~expect:None; + [%test_result: int] (sum (module Int) ~f:Fn.id empty) ~expect:0 + ;; + + let push = push + let copy = copy + let until_empty = until_empty + + let%test_unit _ = + let t = + let t = create () in + push t 0; + push t 1; + push t 2; + t + in + [%test_result: bool] (is_empty t) ~expect:false; + [%test_result: int] (length t) ~expect:3; + [%test_result: int option] (top t) ~expect:(Some 2); + [%test_result: int] (top_exn t) ~expect:2; + [%test_result: int option] (min_elt ~compare:Int.compare t) ~expect:(Some 0); + [%test_result: int option] (max_elt ~compare:Int.compare t) ~expect:(Some 2); + [%test_result: int] (sum (module 
Int) ~f:Fn.id t) ~expect:3; + let t' = copy t in + [%test_result: int] (pop_exn t') ~expect:2; + [%test_result: int] (pop_exn t') ~expect:1; + [%test_result: int] (pop_exn t') ~expect:0; + [%test_result: int] (length t') ~expect:0; + [%test_result: bool] (is_empty t') ~expect:true; + let t' = copy t in + [%test_result: int option] (pop t') ~expect:(Some 2); + [%test_result: int option] (pop t') ~expect:(Some 1); + [%test_result: int option] (pop t') ~expect:(Some 0); + [%test_result: int] (length t') ~expect:0; + [%test_result: bool] (is_empty t') ~expect:true; + (* test that t was not modified by pops applied to copies *) + [%test_result: int] (length t) ~expect:3; + [%test_result: int] (top_exn t) ~expect:2; + [%test_result: int list] (to_list t) ~expect:[2; 1; 0]; + [%test_result: int array] (to_array t) ~expect:[|2; 1; 0|]; + [%test_result: int] (length t) ~expect:3; + [%test_result: int] (top_exn t) ~expect:2; + let t' = copy t in + let n = ref 0 in + until_empty t' (fun x -> n := !n + x); + [%test_result: int] !n ~expect:3; + [%test_result: bool] (is_empty t') ~expect:true; + [%test_result: int] (length t') ~expect:0 + ;; + + let%test_unit _ = + let t = create () in + [%test_result: bool] (is_empty t) ~expect:true; + [%test_result: int] (length t) ~expect:0; + [%test_result: _ list] (to_list t) ~expect:[]; + [%test_result: _ option] (pop t) ~expect:None; + push t 13; + [%test_result: bool] (is_empty t) ~expect:false; + [%test_result: int] (length t) ~expect:1; + [%test_result: int option] (min_elt ~compare:Int.compare t) ~expect:(Some 13); + [%test_result: int option] (max_elt ~compare:Int.compare t) ~expect:(Some 13); + [%test_result: int] (sum (module Int) ~f:Fn.id t) ~expect:13; + [%test_result: int] (pop_exn t) ~expect:13; + [%test_result: bool] (is_empty t) ~expect:true; + [%test_result: int] (length t) ~expect:0; + push t 13; + push t 14; + [%test_result: bool] (is_empty t) ~expect:false; + [%test_result: int] (length t) ~expect:2; + [%test_result: int 
list] (to_list t) ~expect:[14; 13]; + [%test_result: int option] (min_elt ~compare:Int.compare t) ~expect:(Some 13); + [%test_result: int option] (max_elt ~compare:Int.compare t) ~expect:(Some 14); + [%test_result: int] (sum (module Int) ~f:Fn.id t) ~expect:27; + [%test_result: bool] (is_some (pop t)) ~expect:true; + [%test_result: bool] (is_some (pop t)) ~expect:true + ;; + + let of_list = of_list + + let%test_unit _ = + for n = 0 to 5 do + let l = List.init n ~f:Fn.id in + [%test_result: int list] (to_list (of_list l)) ~expect:l + done + ;; + + let clear = clear + + let%test_unit _ = + for n = 0 to 5 do + let t = of_list (List.init n ~f:Fn.id) in + clear t; + assert (is_empty t); + push t 13; + [%test_result: int] (length t) ~expect:1 + done + ;; + + let%test_unit "float test" = + let s = create () in + push s 1.0; + push s 2.0; + push s 3.0 + +end + +include Test_container.Test_S1 (Stack) + +include Test (Debug (Stack)) + +let capacity = capacity +let set_capacity = set_capacity + +let%test_unit _ = + let t = create () in + [%test_result: int] (capacity t) ~expect:0; + set_capacity t (-1); + [%test_result: int] (capacity t) ~expect:0; + set_capacity t 10; + [%test_result: int] (capacity t) ~expect:10; + set_capacity t 0; + [%test_result: int] (capacity t) ~expect:0; + push t (); + set_capacity t 0; + [%test_result: int] (length t) ~expect:1; + [%test_pred: int] (fun c -> c >= 1) (capacity t) +;; diff --git a/test/test_stack.mli b/test/test_stack.mli new file mode 100644 index 0000000..2f12c18 --- /dev/null +++ b/test/test_stack.mli @@ -0,0 +1,5 @@ +open! Base +open! 
Import + +module Debug (S : Stack.S) : Stack.S with type 'a t = 'a S.t +module Test (S : Stack.S) : sig end diff --git a/test/test_stdlib_shadowing.mlt b/test/test_stdlib_shadowing.mlt new file mode 100644 index 0000000..0212225 --- /dev/null +++ b/test/test_stdlib_shadowing.mlt @@ -0,0 +1,41 @@ +(* Additional shadowing tests, to make sure the [@@deprecated] attributes are properly + transported in [Base] *) +open Base + +let () = seek_in stdin 0 +[%%expect{| +Line _, characters 9-16: +Error (Warning 3): deprecated: Base.seek_in +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [Stdio.In_channel.seek] instead. +Line _, characters 17-22: +Error (Warning 3): deprecated: Base.stdin +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [Stdio.stdin] instead. +|}] + +let _ = StringLabels.make 10 'x' +[%%expect{| +Line _, characters 8-25: +Error (Warning 3): deprecated: module Base.StringLabels +[2016-09] this element comes from the stdlib distributed with OCaml. +Referring to the stdlib directly is discouraged by Base. You should either +use the equivalent functionality offered by Base, or if you really want to +refer to the stdlib, use Caml.StringLabels instead +|}] + +let _ = ( == ) +[%%expect{| +Line _, characters 8-14: +Error (Warning 3): deprecated: Base.== +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [phys_equal] instead. +|}] + +let _ = ( != ) +[%%expect{| +Line _, characters 8-14: +Error (Warning 3): deprecated: Base.!= +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [not (phys_equal ...)] instead. +|}] diff --git a/test/test_string.ml b/test/test_string.ml new file mode 100644 index 0000000..40a5e25 --- /dev/null +++ b/test/test_string.ml @@ -0,0 +1,678 @@ +open! Import +open! 
String + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module String) [ ""; "a"; "foo" ]; + [%expect {| |}]; +;; + +let%test_module "Caseless Suffix/Prefix" = + (module struct + let%test _ = Caseless.is_suffix "OCaml" ~suffix:"AmL" + let%test _ = Caseless.is_suffix "OCaml" ~suffix:"ocAmL" + let%test _ = Caseless.is_suffix "a@!$b" ~suffix:"a@!$B" + let%test _ = not (Caseless.is_suffix "a@!$b" ~suffix:"C@!$B") + let%test _ = not (Caseless.is_suffix "aa" ~suffix:"aaa") + let%test _ = Caseless.is_prefix "OCaml" ~prefix:"oc" + let%test _ = Caseless.is_prefix "OCaml" ~prefix:"ocAmL" + let%test _ = Caseless.is_prefix "a@!$b" ~prefix:"a@!$B" + let%test _ = not (Caseless.is_prefix "a@!$b" ~prefix:"a@!$C") + let%test _ = not (Caseless.is_prefix "aa" ~prefix:"aaa") + end) + +let%test_module "Caseless Comparable" = + (module struct + (* examples from docs *) + let%test _ = Caseless.equal "OCaml" "ocaml" + let%test _ = Caseless.("apple" < "Banana") + + let%test _ = Caseless.("aa" < "aaa") + let%test _ = Int.(<>) (Caseless.compare "apple" "Banana") (compare "apple" "Banana") + let%test _ = Caseless.equal "XxX" "xXx" + let%test _ = Caseless.("XxX" < "xXxX") + let%test _ = Caseless.("XxXx" > "xXx") + + let%test _ = List.is_sorted ~compare:Caseless.compare ["Apples"; "bananas"; "Carrots"] + + let%expect_test _ = + let x = Sys.opaque_identity "one string" in + let y = Sys.opaque_identity "another" in + require_no_allocation [%here] (fun () -> + ignore (Sys.opaque_identity (Caseless.equal x y) : bool)); + [%expect {||}]; + ;; + end) + +let%test_module "Caseless Hashable" = + (module struct + let%test _ = + Int.(<>) (hash "x") (hash "X") + && Int.(=) (Caseless.hash "x") (Caseless.hash "X") + let%test _ = Int.(=) (Caseless.hash "OCaml") (Caseless.hash "ocaml") + let%test _ = Int.(<>) (Caseless.hash "aaa") (Caseless.hash "aaaa") + let%test _ = Int.(<>) (Caseless.hash "aaa") (Caseless.hash "aab") + let%test _ = + let tbl = Hashtbl.create 
(module Caseless) in + Hashtbl.add_exn tbl ~key:"x" ~data:7; + [%compare.equal: int option] (Hashtbl.find tbl "X") (Some 7) + end) +(* contains: the optional ~pos and ~len arguments restrict the searched window. *) +let%test _ = not (contains "" 'a') +let%test _ = contains "a" 'a' +let%test _ = not (contains "a" 'b') +let%test _ = contains "ab" 'a' +let%test _ = contains "ab" 'b' +let%test _ = not (contains "ab" 'c') +let%test _ = not (contains "abcd" 'b' ~pos:1 ~len:0) +let%test _ = contains "abcd" 'b' ~pos:1 ~len:1 +let%test _ = contains "abcd" 'c' ~pos:1 ~len:2 +let%test _ = not (contains "abcd" 'd' ~pos:1 ~len:2) +let%test _ = contains "abcd" 'd' ~pos:1 +let%test _ = not (contains "abcd" 'a' ~pos:1) +(* Search_pattern tests; create is checked against the brute-force slow_create defined below. *) +let%test_module "Search_pattern" = + (module struct + open Search_pattern + + let%test_module "Search_pattern.create" = + (module struct + let prefix s n = sub s ~pos:0 ~len:n + let suffix s n = sub s ~pos:(length s - n) ~len:n + + let slow_create pattern = + (* Compute the longest prefix-suffix array from definition, O(n^3) *) + let n = length pattern in + let kmp_arr = Array.create ~len:n (-1) in + for i = 0 to n - 1 do + let x = prefix pattern (i + 1) in + for j = 0 to i do + if String.equal (prefix x j) (suffix x j) then + kmp_arr.(i) <- j + done + done; + (pattern, kmp_arr) + ;; + + let sexp_of_int = Base.Not_exposed_properly.Sexp_conv.sexp_of_int + (* Requires that both create and slow_create serialise to the expected (pattern, array) pair. *) + let test_both (s, a) = + let create_s = create s |> [%sexp_of: t ] in + let slow_create_s = slow_create s |> [%sexp_of: string * int array] in + let expected = [%sexp ((s, a) : string * int array)] in + require [%here] (Sexp.equal create_s expected && Sexp.equal slow_create_s expected) + ~if_false_then_print_s:(lazy ( + [%message "not equal" + (create_s : Sexp.t) + (slow_create_s : Sexp.t) + (expected : Sexp.t)])) + ;; + (* Requires only that create and slow_create agree with each other. *) + let cmp_both s = + let create_s = create s |> [%sexp_of: t ] in + let slow_create_s = slow_create s |> [%sexp_of: string * int array] in + require [%here] (Sexp.equal create_s slow_create_s) + ~if_false_then_print_s:(lazy ( + [%message "not equal" + (create_s : Sexp.t) + 
(slow_create_s : Sexp.t)])) + ;; + (* Fixed patterns paired with their expected longest prefix-suffix arrays. *) + let%expect_test _ = + test_both ("", [| |]) + let%expect_test _ = + test_both ("ababab", [|0; 0; 1; 2; 3; 4|]) + let%expect_test _ = + test_both ("abaCabaD", [|0; 0; 1; 0; 1; 2; 3; 0|]) + let%expect_test _ = + test_both ("abaCabaDabaCabaCabaDabaCabaEabab", + [|0; 0; 1; 0; 1; 2; 3; 0; 1; 2; 3; 4; 5; 6; 7; 4; 5; 6; 7; 8; + 9; 10; 11; 12; 13; 14; 15; 0; 1; 2; 3; 2|]) + (* [x k] is empty for negative k; otherwise x (k - 1), then the character with code 65 + k, then x (k - 1) again. *) + let rec x k = + if Int.(<) k 0 then "" else + let b = x (k - 1) in + b ^ (make 1 (Caml.Char.unsafe_chr (65 + k))) ^ b + ;; + + let%expect_test _ = + cmp_both (x 10) + let%expect_test _ = + cmp_both ((x 5) ^ "E" ^ (x 4) ^ "D" ^ (x 3) ^ "B" ^ (x 2) ^ "C" ^ (x 3)) + end) + (* From here [=] compares int options, i.e. the results of index. *) + let (=) = [%compare.equal: int option] + let%test _ = index (create "") ~in_:"abababac" = Some 0 + let%test _ = index ~pos:(-1) (create "") ~in_:"abababac" = None + let%test _ = index ~pos:1 (create "") ~in_:"abababac" = Some 1 + let%test _ = index ~pos:7 (create "") ~in_:"abababac" = Some 7 + let%test _ = index ~pos:8 (create "") ~in_:"abababac" = Some 8 + let%test _ = index ~pos:9 (create "") ~in_:"abababac" = None + let%test _ = index (create "abababaca") ~in_:"abababac" = None + let%test _ = index (create "abababac") ~in_:"abababac" = Some 0 + let%test _ = index ~pos:0 (create "abababac") ~in_:"abababac" = Some 0 + let%test _ = index (create "abac") ~in_:"abababac" = Some 4 + let%test _ = index ~pos:4 (create "abac") ~in_:"abababac" = Some 4 + let%test _ = index ~pos:5 (create "abac") ~in_:"abababac" = None + let%test _ = index ~pos:5 (create "abac") ~in_:"abababaca" = None + let%test _ = index ~pos:5 (create "baca") ~in_:"abababaca" = Some 5 + let%test _ = index ~pos:(-1) (create "a") ~in_:"abc" = None + let%test _ = index ~pos:2 (create "a") ~in_:"abc" = None + let%test _ = index ~pos:2 (create "c") ~in_:"abc" = Some 2 + let%test _ = index ~pos:3 (create "c") ~in_:"abc" = None + (* From here [=] compares booleans, i.e. the results of matches. *) + let (=) = [%compare.equal: bool] + let%test _ = matches (create "") "abababac" = true + let%test _ = matches 
(create "abababaca") "abababac" = false + let%test _ = matches (create "abababac") "abababac" = true + let%test _ = matches (create "abac") "abababac" = true + let%test _ = matches (create "abac") "abababaca" = true + let%test _ = matches (create "baca") "abababaca" = true + let%test _ = matches (create "a") "abc" = true + let%test _ = matches (create "c") "abc" = true + (* From here [=] compares int lists: index_all with and without overlapping matches. *) + let (=) = [%compare.equal: int list] + let%test _ = index_all (create "") ~may_overlap:false ~in_:"abcd" = [0; 1; 2; 3; 4] + let%test _ = index_all (create "") ~may_overlap:true ~in_:"abcd" = [0; 1; 2; 3; 4] + let%test _ = index_all (create "abab") ~may_overlap:false ~in_:"abababab" = [0; 4] + let%test _ = index_all (create "abab") ~may_overlap:true ~in_:"abababab" = [0; 2; 4] + let%test _ = index_all (create "abab") ~may_overlap:false ~in_:"ababababab" = [0; 4] + let%test _ = index_all (create "abab") ~may_overlap:true ~in_:"ababababab" = [0; 2; 4; 6] + let%test _ = index_all (create "aaa") ~may_overlap:false ~in_:"aaaaBaaaaaa" = [0; 5; 8] + let%test _ = index_all (create "aaa") ~may_overlap:true ~in_:"aaaaBaaaaaa" = [0; 1; 5; 6; 7; 8] + (* From here [=] compares strings: replace_first rewrites one match, replace_all every non-overlapping match. *) + let (=) = [%compare.equal: string] + let%test _ = replace_first (create "abab") ~in_:"abababab" ~with_:"" = "abab" + let%test _ = replace_first (create "abab") ~in_:"abacabab" ~with_:"" = "abac" + let%test _ = replace_first (create "abab") ~in_:"ababacab" ~with_:"A" = "Aacab" + let%test _ = replace_first (create "abab") ~in_:"acabababab" ~with_:"A" = "acAabab" + let%test _ = replace_first (create "ababab") ~in_:"acabababab" ~with_:"A" = "acAab" + let%test _ = replace_first (create "abab") ~in_:"abababab" ~with_:"abababab" = "abababababab" + + let%test _ = replace_all (create "abab") ~in_:"abababab" ~with_:"" = "" + let%test _ = replace_all (create "abab") ~in_:"abacabab" ~with_:"" = "abac" + let%test _ = replace_all (create "abab") ~in_:"acabababab" ~with_:"A" = "acAA" + let%test _ = replace_all (create "ababab") ~in_:"acabababab" ~with_:"A" = "acAab" + 
let%test _ = replace_all (create "abaC") ~in_:"abaCabaDCababaCabaCaba" ~with_:"x" = "xabaDCabxxaba" + let%test _ = replace_all (create "a") ~in_:"aa" ~with_:"aaa" = "aaaaaa" + let%test _ = replace_all (create "") ~in_:"abcdeefff" ~with_:"X1" = "X1aX1bX1cX1dX1eX1eX1fX1fX1fX1" + + (* a doc comment in core_string.mli gives this as an example *) + let%test _ = replace_all (create "bc") ~in_:"aabbcc" ~with_:"cb" = "aabcbc" + end) +(* rev reverses the characters of a string. *) +let%test _ = rev "" = "";; +let%test _ = rev "a" = "a";; +let%test _ = rev "ab" = "ba";; +let%test _ = rev "abc" = "cba";; +(* split_lines on fixed inputs: one trailing newline does not produce an extra empty line. *) +let%test_unit _ = + List.iter ~f:(fun (t, expect) -> + let actual = split_lines t in + if not ([%compare.equal: string list] actual expect) + then raise_s [%message "split_lines bug" + (t : t) (actual : t list) (expect : t list)]) + [ "" , []; + "\n" , [""]; + "a" , ["a"]; + "a\n" , ["a"]; + "a\nb" , ["a"; "b"]; + "a\nb\n" , ["a"; "b"]; + "a\n\n" , ["a"; "" ]; + "a\n\nb" , ["a"; "" ; "b"]; + ] +;; +(* Generated split_lines check: inputs are built by prepending lines and LF or CRLF separators, three levels deep, together with the expected result. *) +let%test_unit _ = + let lines = [ ""; "a"; "bc" ] in + let newlines = [ "\n"; "\r\n" ] in + let rec loop n expect to_concat = + if Int.(=) n 0 then begin + let input = concat to_concat in + let actual = Or_error.try_with (fun () -> split_lines input) in + if not ([%compare.equal: t list Or_error.t] actual (Ok expect)) + then raise_s [%message "split_lines bug" + (input : t) + (actual : t list Or_error.t) + (expect : t list)] + end else begin + loop (n - 1) expect to_concat; + List.iter lines ~f:(fun t -> + let loop to_concat = loop (n - 1) (t :: expect) (t :: to_concat) in + if not (is_empty t) && List.is_empty to_concat then loop []; + List.iter newlines ~f:(fun newline -> loop (newline :: to_concat))); + end + in + loop 3 [] [] +;; +(* Operations that leave the string unchanged are expected to return the very same (physically equal) value. *) +let%test_unit _ = + let s = init 10 ~f:(Char.of_int_exn) in + assert (phys_equal s (sub s ~pos:0 ~len:(String.length s))); + assert (phys_equal s (prefix s (String.length s))); + assert (phys_equal s (suffix s (String.length s))); + assert (phys_equal s (concat [s])); + assert (phys_equal 
s (tr s ~target:'\255' ~replacement:'\000')) +(* tr_multi is compared with the straightforward gold_standard reference on fixed examples and on generated inputs. *) +let%test_module "tr_multi" = (module struct + let gold_standard ~target ~replacement string = + map string ~f:(fun char -> + match rindex target char with + | None -> char + | Some i -> get replacement (Int.min i (length replacement - 1))) + (* Test case record plus a quickcheck generator; generated cases carry no expected value. *) + module Test = struct + type nonrec t = + { target : t + ; replacement : t + ; string : t + ; expected : t sexp_option } + [@@deriving sexp_of] + + let quickcheck_generator = + let open Base_quickcheck.Generator in + let open Base_quickcheck.Generator.Let_syntax in + let%bind size = size in + let%bind target_len = int_log_uniform_inclusive 1 255 in + let%bind target = string_with_length ~length:target_len in + let%bind replacement_len = int_inclusive 1 target_len in + let%bind replacement = string_with_length ~length:replacement_len in + let%bind string_length = int_inclusive 0 size in + let%map string = string_with_length ~length:string_length in + { target; replacement; string; expected = None } + + let quickcheck_shrinker = Base_quickcheck.Shrinker.atomic + end + (* Hand-written cases, each given as (target, replacement, input, expected output). *) + let examples = + [ "" , "" , "abcdefg", "abcdefg" + ; "" , "a" , "abcdefg", "abcdefg" + ; "aaaa", "abcd", "abcdefg", "dbcdefg" + ; "abcd", "bcde", "abcdefg", "bcdeefg" + ; "abcd", "bcde", "" , "" + ; "abcd", "_" , "abcdefg", "____efg" + ; "abcd", "b_" , "abcdefg", "b___efg" + ; "a" , "dcba", "abcdefg", "dbcdefg" + ; "ab" , "dcba", "abcdefg", "dccdefg" + ] + |> List.map ~f:(fun (target, replacement, string, expected) -> + { Test.target; replacement; string; expected = Some expected }) + + let%test_unit _ = + Base_quickcheck.Test.run_exn (module Test) ~examples + ~f:(fun ({ target; replacement; string; expected } : Test.t) -> + (* test implementation behavior against gold standard *) + let impl_result = unstage (tr_multi ~target ~replacement) string in + let gold_result = gold_standard ~target ~replacement string in + [%test_result:t] ~expect:gold_result impl_result; + (* test against expected result, if one is 
provided (non-random examples) *) + Option.iter expected ~f:(fun expected -> + [%test_result:t] ~expect:expected impl_result); + (* test for returning input if the string is unchanged *) + if equal string impl_result + then assert (phys_equal string impl_result)) +end) +(* lfindi: leftmost index whose character satisfies f, searching from ~pos when it is given. *) +let%test_unit _ = [%test_result: int option] (lfindi "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (lfindi ~pos:0 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (lfindi ~pos:1 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 2) +let%test_unit _ = [%test_result: int option] (lfindi "bob" ~f:(fun _ -> Char.(=) 'x')) ~expect:None +(* find_map returns the first Some produced by f; f is never called on the empty string. *) +let%test_unit _ = [%test_result: char option] + (find_map "fop" ~f:(fun c -> if Char.(c >= 'o') then Some c else None)) + ~expect:(Some 'o') +let%test_unit _ = [%test_result: _ option] (find_map "bar" ~f:(fun _ -> None)) ~expect:None +let%test_unit _ = [%test_result: _ option] (find_map "" ~f:(fun _ -> assert false)) ~expect:None +(* rfindi: rightmost index whose character satisfies f, searching down from ~pos when it is given. *) +let%test_unit _ = [%test_result: int option] (rfindi "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 2) +let%test_unit _ = [%test_result: int option] (rfindi ~pos:2 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 2) +let%test_unit _ = [%test_result: int option] (rfindi ~pos:1 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (rfindi "bob" ~f:(fun _ -> Char.(=) 'x')) ~expect:None +(* strip: trims both ends; ~drop selects which characters may be trimmed, and trimming stops at the first character that is kept. *) +let%test_unit _ = [%test_result: string] (strip " foo bar \n") ~expect:"foo bar" +let%test_unit _ = [%test_result: string] (strip ~drop:(Char.(=) '"') "\" foo bar ") ~expect:" foo bar " +let%test_unit _ = [%test_result: string] (strip ~drop:(Char.(=) '"') " \" foo bar ") ~expect:" \" foo bar " + +let%test_unit _ = [%test_result: bool] ~expect:false (exists "" ~f:(fun _ -> assert false)) +let%test_unit _ = [%test_result: bool] ~expect:false (exists "abc" ~f:(Fn.const false)) +let%test_unit _ = [%test_result: bool] 
~expect:true (exists "abc" ~f:(Fn.const true)) +let%test_unit _ = [%test_result: bool] ~expect:true (exists "abc" ~f:(function + 'a' -> false | 'b' -> true | _ -> assert false)) +(* for_all mirrors exists: true on the empty string without calling f, and it stops at the first false. *) +let%test_unit _ = [%test_result: bool] ~expect:true (for_all "" ~f:(fun _ -> assert false)) +let%test_unit _ = [%test_result: bool] ~expect:true (for_all "abc" ~f:(Fn.const true)) +let%test_unit _ = [%test_result: bool] ~expect:false (for_all "abc" ~f:(Fn.const false)) +let%test_unit _ = [%test_result: bool] ~expect:false (for_all "abc" ~f:(function + 'a' -> true | 'b' -> false | _ -> assert false)) +(* foldi passes the index along with the accumulator and character, visiting characters left to right. *) +let%test_unit _ = [%test_result: (int * char) list] + (foldi "hello" ~init:[] ~f:(fun i acc ch -> (i,ch)::acc)) + ~expect:(List.rev [0,'h';1,'e';2,'l';3,'l';4,'o']) +(* filter keeps the characters satisfying f; when all are kept the input itself is returned, after f has been called once per character. *) +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(Char.(<>) 'h')) ~expect:"ello" +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(Char.(<>) 'l')) ~expect:"heo" +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(fun _ -> false)) ~expect:"" +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(fun _ -> true)) ~expect:"hello" +let%test_unit _ = + let s = "hello" in + [%test_result: bool] ~expect:true (phys_equal (filter s ~f:(fun _ -> true)) s) +let%test_unit _ = + let s = "abc" in + let r = ref 0 in + assert (phys_equal s (filter s ~f:(fun _ -> Int.incr r; true))); + assert (Int.(=) !r (String.length s)) +;; +(* The Base_hash_string external must agree with Caml.Hashtbl.hash, and with the ppx_hash result when ints have more than 31 bits. *) +let%test_module "Hash" = + (module struct + external hash : string -> int = "Base_hash_string" [@@noalloc] + + let%test_unit _ = + List.iter ~f:(fun string -> + assert (Int.(=) (hash string) (Caml.Hashtbl.hash string)); + (* with 31-bit integers, the hash computed by ppx_hash overflows so it doesn't match + polymorphic hash exactly. *) + if Int.(>) Int.num_bits 31 then + assert (Int.(=) (hash string) ([%hash: string] string)) + ) + [ "Oh Gloria inmarcesible! Oh jubilo inmortal!" 
+ ; "Oh say can you see, by the dawn's early light" + ; "Hahahaha\200" + ] + ;; + end) +(* of_char_list builds a string from a list of characters. *) +let%test _ = of_char_list ['a';'b';'c'] = "abc" +let%test _ = of_char_list [] = "" +(* String.mem is required not to allocate; opaque_identity keeps the arguments from being treated as constants. *) +let%expect_test "mem does not allocate" = + let string = Sys.opaque_identity "abracadabra" in + let char = Sys.opaque_identity 'd' in + require_no_allocation [%here] (fun () -> + ignore (String.mem string char : bool)); + [%expect {||}]; +;; +(* is_substring_at: true or false for positions from 0 up to the length, Invalid_argument for positions outside that range. *) +let%expect_test "is_substring_at" = + let string = "lorem ipsum dolor sit amet" in + let test pos substring = + match is_substring_at string ~pos ~substring with + | bool -> print_s [%sexp (bool : bool)] + | exception exn -> print_s [%message "raised" ~_:(exn : exn)] + in + test 0 "lorem"; + [%expect {| true |}]; + test 1 "lorem"; + [%expect {| false |}]; + test 6 "ipsum"; + [%expect {| true |}]; + test 5 "ipsum"; + [%expect {| false |}]; + test 22 "amet"; + [%expect {| true |}]; + test 23 "amet"; + [%expect {| false |}]; + test 22 "amet and some other stuff"; + [%expect {| false |}]; + test 0 ""; + [%expect {| true |}]; + test 10 ""; + [%expect {| true |}]; + test 26 ""; + [%expect {| true |}]; + test 100 ""; + [%expect {| + (raised ( + Invalid_argument + "String.is_substring_at: invalid index 100 for string of length 26")) |}]; + test (-1) ""; + [%expect {| + (raised ( + Invalid_argument + "String.is_substring_at: invalid index -1 for string of length 26")) |}]; +;; +(* Tests for the Escaping submodule. *) +let%test_module "Escaping" = + (module struct + open Escaping + (* escape_gen: each escapeworthy character is replaced by the escape character followed by its mapped character. *) + let%test_module "escape_gen" = + (module struct + let escape = unstage + (escape_gen_exn + ~escapeworthy_map:[('%','p');('^','c')] ~escape_char:'_') + + let%test _ = escape "" = "" + let%test _ = escape "foo" = "foo" + let%test _ = escape "_" = "__" + let%test _ = escape "foo%bar" = "foo_pbar" + let%test _ = escape "^foo%" = "_cfoo_p" + + let escape2 = unstage + (escape_gen_exn + ~escapeworthy_map:[('_','.');('%','p');('^','c')] ~escape_char:'_') + + let%test _ = escape2 "_." = "_.." + let%test _ = escape2 "_" = "_." 
+ let%test _ = escape2 "foo%_bar" = "foo_p_.bar" + let%test _ = escape2 "_foo%" = "_.foo_p" + (* The map must be one-to-one: creation raises when two keys share a value or when a key is repeated. *) + let checks_for_one_to_one escapeworthy_map = + Exn.does_raise (fun () -> escape_gen_exn ~escapeworthy_map ~escape_char:'_') + + let%test _ = checks_for_one_to_one [('%','p');('^','c');('$','c')] + let%test _ = checks_for_one_to_one [('%','p');('^','c');('%','d')] + end) + (* unescape_gen: inverse of escape_gen for the same map and escape character. *) + let%test_module "unescape_gen" = + (module struct + let unescape = + unstage + (unescape_gen_exn ~escapeworthy_map:['%','p';'^','c'] ~escape_char:'_') + + let%test _ = unescape "" = "" (* empty input, mirroring the first escape_gen case; this line used to duplicate the doubled-escape case below *) + let%test _ = unescape "foo" = "foo" + let%test _ = unescape "__" = "_" + let%test _ = unescape "foo_pbar" = "foo%bar" + let%test _ = unescape "_cfoo_p" = "^foo%" + + let unescape2 = + unstage + (unescape_gen_exn ~escapeworthy_map:['_','.';'%','p';'^','c'] ~escape_char:'_') + + (* this one is ill-formed, just ignore the escape_char without escaped char *) + let%test _ = unescape2 "_" = "" + let%test _ = unescape2 "a_" = "a" + + let%test _ = unescape2 "__" = "_" + let%test _ = unescape2 "_.." = "_." + let%test _ = unescape2 "_." 
= "_" + let%test _ = unescape2 "foo_p_.bar" = "foo%_bar" + let%test _ = unescape2 "_.foo_p" = "_foo%" + + (* generate [n] random strings and check that escaping followed by unescaping gives back the original string *) + let random_test ~escapeworthy_map ~escape_char n = + let escape = + unstage (escape_gen_exn ~escapeworthy_map ~escape_char) + in + let unescape = + unstage (unescape_gen_exn ~escapeworthy_map ~escape_char) + in + let test str = + let escaped = escape str in + let unescaped = unescape escaped in + if str <> unescaped then + failwith ( + Printf.sprintf + "string: %s\nescaped string: %s\nunescaped string: %s" + str escaped unescaped) + in + let random_char = (* picks a random printable character *) + let print_chars = + List.range (Char.to_int Char.min_value) (Char.to_int Char.max_value + 1) + |> List.filter_map ~f:Char.of_int + |> List.filter ~f:Char.is_print + |> Array.of_list + in + fun () -> Array.random_element_exn print_chars + in + let escapeworthy_chars = + List.map escapeworthy_map ~f:fst |> Array.of_list + in + try + for _ = 0 to n - 1 do + let str = + List.init (Random.int 50) ~f:(fun _ -> + let p = Random.int 100 in (* roughly 10% escape characters, 15% escapeworthy characters, the rest random printable ones *) + if Int.(p < 10) then + escape_char + else if Int.(p < 25) then + Array.random_element_exn escapeworthy_chars + else + random_char () + ) + |> of_char_list + in + test str + done; + true + with e -> + raise e (* re-raises unchanged, so this handler has no effect on the outcome *) + + let%test _ = random_test 1000 ~escapeworthy_map:['%','p';'^','c'] ~escape_char:'_' + let%test _ = random_test 1000 ~escapeworthy_map:['_','.';'%','p';'^','c'] ~escape_char:'_' + end) + (* escape with an explicit escapeworthy list puts the escape character in front of each listed character. *) + let%test_module "escape" = + (module struct + let escape = unstage (escape ~escape_char:'_' ~escapeworthy:['_'; '%'; '^']) + let%test _ = escape "foo" = "foo" + let%test _ = escape "_" = "__" + let%test _ = escape "foo%bar" = "foo_%bar" + let%test _ = escape "^foo%" = "_^foo_%" + end) + (* unescape removes each escaping character and keeps the character it guarded. *) + let%test_module "unescape" = + (module struct + let unescape = unstage (unescape ~escape_char:'_') + let%test _ = unescape "foo" = "foo" + let%test _ = unescape "__" = "_" + let%test _ = unescape "foo_%bar" = 
"foo%bar" + let%test _ = unescape "_^foo_%" = "^foo%" + end) + (* is_char_escaping: whether the character at the index is an escape character that is not itself escaped. *) + let%test_module "is_char_escaping" = + (module struct + let is = is_char_escaping ~escape_char:'_' + let%test_unit _ = [%test_result: bool] (is "___" 0) ~expect:true + let%test_unit _ = [%test_result: bool] (is "___" 1) ~expect:false + let%test_unit _ = [%test_result: bool] (is "___" 2) ~expect:true + (* considered escaping, though there's nothing to escape *) + + let%test_unit _ = [%test_result: bool] (is "a_b__c" 0) ~expect:false + let%test_unit _ = [%test_result: bool] (is "a_b__c" 1) ~expect:true + let%test_unit _ = [%test_result: bool] (is "a_b__c" 2) ~expect:false + let%test_unit _ = [%test_result: bool] (is "a_b__c" 3) ~expect:true + let%test_unit _ = [%test_result: bool] (is "a_b__c" 4) ~expect:false + let%test_unit _ = [%test_result: bool] (is "a_b__c" 5) ~expect:false + end) + (* is_char_escaped: whether the character at the index directly follows an escaping character. *) + let%test_module "is_char_escaped" = + (module struct + let is = is_char_escaped ~escape_char:'_' + let%test_unit _ = [%test_result: bool] (is "___" 2) ~expect:false + let%test_unit _ = [%test_result: bool] (is "x" 0) ~expect:false + let%test_unit _ = [%test_result: bool] (is "_x" 1) ~expect:true + let%test_unit _ = [%test_result: bool] (is "sadflkas____sfff" 12) ~expect:false + let%test_unit _ = [%test_result: bool] (is "s_____s" 6) ~expect:true + end) + (* is_char_literal: true only for characters that are neither escaping nor escaped. *) + let%test_module "is_char_literal" = + (module struct + let is_char_literal = is_char_literal ~escape_char:'_' + let%test_unit _ = [%test_result: bool] (is_char_literal "123456" 4) ~expect:true + let%test_unit _ = [%test_result: bool] (is_char_literal "12345_6" 6) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "12345_6" 5) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "123__456" 4) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "123456__" 7) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "__123456" 1) ~expect:false + let%test_unit _ = [%test_result: bool] 
(is_char_literal "__123456" 0) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "__123456" 2) ~expect:true + end) + (* index_from only reports occurrences that are neither escaped nor acting as the escaping character. *) + let%test_module "index_from" = + (module struct + let f = index_from ~escape_char:'_' + let%test_unit _ = [%test_result: int option] (f "__" 0 '_') ~expect:None + let%test_unit _ = [%test_result: int option] (f "_.." 0 '.') ~expect:(Some 2) + let%test_unit _ = [%test_result: int option] (f "1273456_7789" 3 '7') ~expect:(Some 9) + let%test_unit _ = [%test_result: int option] (f "1273_7456_7789" 3 '7') ~expect:(Some 11) + let%test_unit _ = [%test_result: int option] (f "1273_7456_7789" 3 'z') ~expect:None + end) + (* rindex_from and rindex apply the same rule while searching from right to left. *) + let%test_module "rindex" = + (module struct + let f = rindex_from ~escape_char:'_' + let%test_unit _ = [%test_result: int option] (f "__" 0 '_') ~expect:None + let%test_unit _ = [%test_result: int option] (f "123456_37839" 9 '3') ~expect:(Some 2) + let%test_unit _ = [%test_result: int option] (f "123_2321" 6 '2') ~expect:(Some 6) + let%test_unit _ = [%test_result: int option] (f "123_2321" 5 '2') ~expect:(Some 1) + + let%test_unit _ = [%test_result: int option] (rindex "" ~escape_char:'_' 'x') ~expect:None + let%test_unit _ = [%test_result: int option] (rindex "a_a" ~escape_char:'_' 'a') ~expect:(Some 0) + end) + (* split only breaks on separators that are not escaped; escape characters stay in the resulting pieces. *) + let%test_module "split" = + (module struct + let split = split ~escape_char:'_' ~on:',' + let%test_unit _ = [%test_result: string list] (split "foo,bar,baz") ~expect:["foo"; "bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_,bar,baz") ~expect:["foo_,bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_,bar_,baz") ~expect:["foo_,bar_,baz"] + let%test_unit _ = [%test_result: string list] (split "foo__,bar,baz") ~expect:["foo__"; "bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo,bar,baz_,") ~expect:["foo"; "bar"; "baz_,"] + let%test_unit _ = [%test_result: string list] (split "foo,bar_,baz_,,") ~expect:["foo"; "bar_,baz_,"; ""] + end) + + 
let%test_module "split_on_chars" = + (module struct + let split = split_on_chars ~escape_char:'_' ~on:[',';':'] + let%test_unit _ = [%test_result: string list] (split "foo,bar:baz") ~expect:["foo"; "bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_,bar,baz") ~expect:["foo_,bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_:bar_,baz") ~expect:["foo_:bar_,baz"] + let%test_unit _ = [%test_result: string list] (split "foo,bar,baz_,") ~expect:["foo"; "bar"; "baz_,"] + let%test_unit _ = [%test_result: string list] (split "foo:bar_,baz_,,") ~expect:["foo"; "bar_,baz_,"; ""] + end) + (* lsplit2 and rsplit2 split at an unescaped separator; the inputs below contain at most one, so both directions agree, and the _exn variants raise when there is none. *) + let%test_module "split2" = + (module struct + let escape_char = '_' + let on = ',' + let%test_unit _ = [%test_result: (string * string) option] (lsplit2 ~escape_char ~on "foo_,bar,baz_,0") ~expect:(Some ("foo_,bar", "baz_,0")) + let%test_unit _ = [%test_result: (string * string) option] (rsplit2 ~escape_char ~on "foo_,bar,baz_,0") ~expect:(Some ("foo_,bar", "baz_,0")) + let%test_unit _ = [%test_result: string * string] (lsplit2_exn ~escape_char ~on "foo_,bar,baz_,0") ~expect:("foo_,bar", "baz_,0") + let%test_unit _ = [%test_result: string * string] (rsplit2_exn ~escape_char ~on "foo_,bar,baz_,0") ~expect:("foo_,bar", "baz_,0") + let%test_unit _ = [%test_result: (string * string) option] (lsplit2 ~escape_char ~on "foo_,bar") ~expect:None + let%test_unit _ = [%test_result: (string * string) option] (rsplit2 ~escape_char ~on "foo_,bar") ~expect:None + let%test _ = Exn.does_raise (fun () -> lsplit2_exn ~escape_char ~on "foo_,bar") + let%test _ = Exn.does_raise (fun () -> rsplit2_exn ~escape_char ~on "foo_,bar") + end) + (* strip_literal family: stripping stops at escaped characters and at the escaping character (see the expected values below). *) + let%test _ = strip_literal ~escape_char:' ' " foo bar \n" = " foo bar \n" + let%test _ = strip_literal ~escape_char:' ' " foo bar \n\n" = " foo bar \n" + let%test _ = strip_literal ~escape_char:'\n' " foo bar \n" = "foo bar \n" + + let%test _ = lstrip_literal ~escape_char:' ' " foo bar \n\n" = " foo bar \n\n" + let%test _ = 
rstrip_literal ~escape_char:' ' " foo bar \n\n" = " foo bar \n" + let%test _ = lstrip_literal ~escape_char:'\n' " foo bar \n" = "foo bar \n" + let%test _ = rstrip_literal ~escape_char:'\n' " foo bar \n" = " foo bar \n" + (* With ~drop, the predicate decides which characters may be stripped; escaped characters still stop the stripping. *) + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'\\' "foo boar" = " " + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'\\' "fooboar" = "" + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'o' "foo boar" = "oo boa" + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'a' "foo boar" = " boar" + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'b' "foo boar" = " bo" + (* The one-sided variants strip only the left or only the right end. *) + let%test _ = lstrip_literal ~drop:(Char.is_alpha) ~escape_char:'o' "foo boar" = "oo boar" + let%test _ = rstrip_literal ~drop:(Char.is_alpha) ~escape_char:'o' "foo boar" = "foo boa" + let%test _ = lstrip_literal ~drop:(Char.is_alpha) ~escape_char:'b' "foo boar" = " boar" + let%test _ = rstrip_literal ~drop:(Char.is_alpha) ~escape_char:'b' "foo boar" = "foo bo" + end) diff --git a/test/test_string.mli b/test/test_string.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_string.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_type_equal.ml b/test/test_type_equal.ml new file mode 100644 index 0000000..64b3b78 --- /dev/null +++ b/test/test_type_equal.ml @@ -0,0 +1,73 @@ +open! Import +open! 
Type_equal +(* An Id converted to a sexp prints as the name it was created with. *) +let%expect_test "[Id.sexp_of_t]" = + let id = Id.create ~name:"some-type-id" [%sexp_of: unit] in + print_s [%sexp (id : _ Id.t)]; + [%expect {| + some-type-id |}]; +;; +(* Separately created ids are never the same: same is false, same_witness is None and same_witness_exn raises. *) +let%test_module "Type_equal.Id" = + (module struct + open Type_equal.Id + let t1 = create ~name:"t1" [%sexp_of: _] + let t2 = create ~name:"t2" [%sexp_of: _] + + let%test _ = same t1 t1 + let%test _ = not (same t1 t2) + + let%test _ = Option.is_some (same_witness t1 t1) + let%test _ = Option.is_none (same_witness t1 t2) + + let%test_unit _ = ignore (same_witness_exn t1 t1 : (_, _) Type_equal.equal) + let%test _ = Result.is_error (Result.try_with (fun () -> same_witness_exn t1 t2)) + end) + +(* This test shows that we need [conv] even though [Type_equal.T] is exposed. *) +let%test_module "Type_equal" = + (module struct + open Type_equal + + let id = Id.create ~name:"int" [%sexp_of: int] + (* A and B both hide int behind an abstract type while exposing the one shared id. *) + module A : sig + type t + val id : t Id.t + end = struct + type t = int + let id = id + end + + module B : sig + type t + val id : t Id.t + end = struct + type t = int + let id = id + end + (* Converts through the equality witness obtained from the two ids. *) + let _a_to_b (a : A.t) = + let eq = Id.same_witness_exn A.id B.id in + (conv eq a : B.t) + ;; + + (* the following is rejected by the compiler *) + (* let _a_to_b (a : A.t) = + * let T = Id.same_witness_exn A.id B.id in + * (a : B.t) + *) + + module C = struct + type 'a t + end + + module Liftc = Lift (C) + (* Lift carries the witness through the type constructor C.t. *) + let _ac_to_bc (ac : A.t C.t) = + let eq = Liftc.lift (Id.same_witness_exn A.id B.id) in + (conv eq ac : B.t C.t) + ;; + end) + + diff --git a/test/test_type_equal.mli b/test/test_type_equal.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_type_equal.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_uchar.ml b/test/test_uchar.ml new file mode 100644 index 0000000..f3ec288 --- /dev/null +++ b/test/test_uchar.ml @@ -0,0 +1,81 @@ +open! 
Import + +let min_int = Int.min_value +let max_int = Int.max_value +(* [raises f v] is true when applying f to v raises an exception. *) +let raises f v = Exn.does_raise (fun () -> f v) + +let%test_module "test_constants" = + (module struct + let%test _ = Uchar.(to_scalar min_value) = 0x0000 + let%test _ = Uchar.(to_scalar max_value) = 0x10FFFF + end) +(* succ_exn jumps over the gap between 0xD7FF and 0xE000 and raises at max_value. *) +let%test_module "test_succ_exn" = + (module struct + let%test _ = raises Uchar.succ_exn Uchar.max_value + let%test _ = Uchar.(to_scalar (succ_exn min_value)) = 0x0001 + let%test _ = Uchar.(to_scalar (succ_exn (of_scalar_exn 0xD7FF))) = 0xE000 + let%test _ = Uchar.(to_scalar (succ_exn (of_scalar_exn 0xE000))) = 0xE001 + end) +(* pred_exn jumps over the same gap in the other direction and raises at min_value. *) +let%test_module "test_pred_exn" = + (module struct + let%test _ = raises Uchar.pred_exn Uchar.min_value + let%test _ = Uchar.(to_scalar (pred_exn (of_scalar_exn 0xD7FF))) = 0xD7FE + let%test _ = Uchar.(to_scalar (pred_exn (of_scalar_exn 0xE000))) = 0xD7FF + let%test _ = Uchar.(to_scalar (pred_exn max_value)) = 0x10FFFE + end) +(* Boundary checks: 0x0000-0xD7FF and 0xE000-0x10FFFF are accepted, everything else tested here is rejected. *) +let%test_module "test_int_is_scalar" = + (module struct + let%test _ = not (Uchar.int_is_scalar (-1)) + let%test _ = Uchar.int_is_scalar 0x0000 + let%test _ = Uchar.int_is_scalar 0xD7FF + let%test _ = not (Uchar.int_is_scalar 0xD800) + let%test _ = not (Uchar.int_is_scalar 0xDFFF) + let%test _ = Uchar.int_is_scalar 0xE000 + let%test _ = Uchar.int_is_scalar 0x10FFFF + let%test _ = not (Uchar.int_is_scalar 0x110000) + let%test _ = not (Uchar.int_is_scalar min_int) + let%test _ = not (Uchar.int_is_scalar max_int) + end) +(* Largest scalar value for which is_char holds in the tests below. *) +let char_max = Uchar.of_scalar_exn 0x00FF + +let%test_module "test_is_char" = + (module struct + let%test _ = Uchar.(is_char Uchar.min_value) + let%test _ = Uchar.(is_char char_max) + let%test _ = Uchar.(not (is_char (of_scalar_exn 0x0100))) + let%test _ = not (Uchar.is_char Uchar.max_value) + end) + +let%test_module "test_of_char" = + (module struct + let%test _ = Uchar.(equal (of_char '\xFF') char_max) + let%test _ = Uchar.(equal (of_char '\x00') min_value) + end) + +let%test_module "test_to_char_exn" = + (module 
struct + let%test _ = Char.equal Uchar.(to_char_exn min_value) '\x00' + let%test _ = Char.equal Uchar.(to_char_exn char_max) '\xFF' + let%test _ = raises Uchar.to_char_exn (Uchar.succ_exn char_max) + let%test _ = raises Uchar.to_char_exn Uchar.max_value + end) + +let%test_module "test_equal" = + (module struct + let%test _ = Uchar.(equal min_value min_value) + let%test _ = Uchar.(equal max_value max_value) + let%test _ = not Uchar.(equal min_value max_value) + end) + +let%test_module "test_compare" = + (module struct + let%test _ = Uchar.(compare min_value min_value) = 0 + let%test _ = Uchar.(compare max_value max_value) = 0 + let%test _ = Uchar.(compare min_value max_value) = (-1) + let%test _ = Uchar.(compare max_value min_value) = 1 + end) diff --git a/test/test_uchar.mli b/test/test_uchar.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_uchar.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_uniform_array.ml b/test/test_uniform_array.ml new file mode 100644 index 0000000..dab159c --- /dev/null +++ b/test/test_uniform_array.ml @@ -0,0 +1,13 @@ +open! Import +open Uniform_array + +module Sequence = struct + type nonrec 'a t = 'a t + type 'a z = 'a + let length = length + let get = get + let set = set + let create_bool ~len = create ~len false +end + +include Base_for_tests.Test_blit.Test1(Sequence)(Uniform_array) diff --git a/test/test_validate.ml b/test/test_validate.ml new file mode 100644 index 0000000..f88ace0 --- /dev/null +++ b/test/test_validate.ml @@ -0,0 +1,95 @@ +open! Import +open! 
Validate + +let print t = + List.iter (errors t) ~f:Caml.print_endline +;; + +let%expect_test "Validate.all" = + print + (all [ + (fun _ -> fail "a"); + (fun _ -> pass); + (fun _ -> fail "b"); + (fun _ -> pass); + (fun _ -> fail "c"); + ] + ()); + [%expect {| + ("" a) + ("" b) + ("" c) + |}] +;; + +let%expect_test _ = + print (first_failure pass (fail "foo")); + [%expect {| ("" foo) |}] +;; + +let%expect_test _ = + print (first_failure (fail "foo") (fail "bar")); + [%expect {| ("" foo) |}] +;; + +let two_errors = of_list [fail "foo"; fail "bar"] + +let%expect_test _ = + print (first_failure two_errors (fail "snoo")); + [%expect {| + ("" foo) + ("" bar) + |}] +;; + +let%expect_test _ = + print (first_failure (fail "snoo") two_errors); + [%expect {| ("" snoo) |}] +;; + +let%expect_test _ = + let v () = + if true + then + failwith "This unit validation raises"; + Validate.pass + in + print (protect v ()); + [%expect {| + ("" + ("Exception raised during validation" + (Failure "This unit validation raises"))) |}] +;; + +let%expect_test "try_with" = + let v () = + failwith "this function raises" + in + print (try_with v); + [%expect {| + ("" ("Exception raised during validation" (Failure "this function raises"))) |}] +;; + +type t = { x : bool } [@@deriving fields] + +let%expect_test "typical use of Validate.field_direct_folder doesn't allocate on success" = + let validate_x = Staged.unstage (Validate.field_direct_folder Validate.pass_bool) in + let validate t = + Fields.Direct.fold t ~init:[] ~x:validate_x + |> Validate.of_list + |> Validate.result + in + let t = { x = true } in + require_no_allocation [%here] (fun () -> ignore (validate t : unit Or_error.t)); +;; + +let%expect_test "Validate.all doesn't allocate on success" = + let checks = List.init 5 ~f:(Fn.const Validate.pass_bool) in + require_no_allocation [%here] (fun () -> ignore (Validate.all checks true : Validate.t)); +;; + +let%expect_test "Validate.combine doesn't allocate on success" = + 
require_no_allocation [%here] (fun () -> + ignore (Validate.combine Validate.pass Validate.pass : Validate.t)); +;; + diff --git a/test/test_validate.mli b/test/test_validate.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_validate.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_with_return.ml b/test/test_with_return.ml new file mode 100644 index 0000000..be7423b --- /dev/null +++ b/test/test_with_return.ml @@ -0,0 +1,54 @@ +open! Import +open! With_return + +let test_loop loop_limit jump_out = + with_return (fun { return } -> + for i = 0 to loop_limit do begin + if i = jump_out then return (`Jumped_out i); + end done; + `Normal) +;; + +let ( = ) = Poly.equal + +let%test _ = test_loop 5 10 = `Normal +let%test _ = test_loop 10 5 = `Jumped_out 5 +let%test _ = test_loop 5 5 = `Jumped_out 5 + +let test_nested outer inner = + with_return (fun { return = return_outer } -> + if outer = `Outer_jump then return_outer `Outer_jump; + let inner_res = + with_return (fun { return = return_inner } -> + if inner = `Inner_jump_out_completely then return_outer `Inner_jump; + if inner = `Inner_jump then return_inner `Inner_jump; + `Inner_normal) + in + if outer = `Jump_with_inner then return_outer (`Outer_later_jump inner_res); + `Outer_normal inner_res) +;; + +let%test _ = test_nested `Outer_jump `Inner_jump = `Outer_jump +let%test _ = test_nested `Outer_jump `Inner_jump_out_completely = `Outer_jump +let%test _ = test_nested `Outer_jump `Foo = `Outer_jump + +let%test _ = test_nested `Jump_with_inner `Inner_jump_out_completely = `Inner_jump +let%test _ = test_nested `Jump_with_inner `Inner_jump = `Outer_later_jump `Inner_jump +let%test _ = test_nested `Jump_with_inner `Foo = `Outer_later_jump `Inner_normal + +let%test _ = test_nested `Foo `Inner_jump_out_completely = `Inner_jump +let%test _ = test_nested `Foo `Inner_jump = `Outer_normal `Inner_jump +let%test _ = test_nested `Foo `Foo = `Outer_normal `Inner_normal + 
+let test_loop loop_limit jump_out = + with_return_option (fun { return } -> + for i = 0 to loop_limit do begin + if i = jump_out then return (`Jumped_out i); + end done) +;; + +let ( = ) = Poly.equal + +let%test _ = test_loop 5 10 = None +let%test _ = test_loop 10 5 = Some (`Jumped_out 5) +let%test _ = test_loop 5 5 = Some (`Jumped_out 5) diff --git a/test/test_with_return.mli b/test/test_with_return.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_with_return.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_word_size.ml b/test/test_word_size.ml new file mode 100644 index 0000000..1275de1 --- /dev/null +++ b/test/test_word_size.ml @@ -0,0 +1,11 @@ +open! Import +open! Word_size + +let%expect_test _ = + print_s [%message (W32 : t)]; + [%expect {| + (W32 W32) |}]; + print_s [%message (W64 : t)]; + [%expect {| + (W64 W64) |}]; +;; diff --git a/test/test_word_size.mli b/test/test_word_size.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_word_size.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/validate_fields_folder.mlt b/test/validate_fields_folder.mlt new file mode 100644 index 0000000..2a9a01f --- /dev/null +++ b/test/validate_fields_folder.mlt @@ -0,0 +1,127 @@ +open! Core_kernel + +(* Regression tests to ensure that [Validate.field], [Validate.field_folder], and + [Validate.field_direct_folder] continue to work with private record types. 
*) + +module Fold_with_private + (M : sig + type t = private + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = Validate.field_folder t f in + Fields.fold ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end + +[%%expect {| +|}] + +module Fold_regular + (M : sig + type t = + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = Validate.field_folder t f in + Fields.fold ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end + +[%%expect {| +|}] + +module Fold_direct_private + (M : sig + type t = private + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = unstage (Validate.field_direct_folder f) in + Fields.Direct.fold t ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end + +[%%expect {| +|}] + +module Fold_direct_regular + (M : sig + type t = + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = unstage (Validate.field_direct_folder f) in + Fields.Direct.fold t ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end +[%%expect{| +|}] + +module Validate_field_private + (M : sig + type t = private + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w check acc field = Validate.field t field check :: acc in + Fields.fold ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end +[%%expect{| +|}] + +module Validate_field + (M : sig + type t = + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w check acc field = Validate.field t field check :: acc in + Fields.fold ~init:[] + 
~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end +[%%expect{| +|}] -- cgit v1.2.3